diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index a289b29bcde..9bbcbddfaa7 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -58,6 +58,7 @@ public class DMLOptions { public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed public boolean statsNGramsUseLineage = true; // If N-Grams use lineage for data-dependent tracking + public boolean applyGeneratedRewrites = false; // If generated rewrites should be applied public boolean fedStats = false; // Whether to record and print the federated statistics public int fedStatsCount = 10; // Default federated statistics count public boolean memStats = false; // max memory statistics @@ -246,6 +247,8 @@ else if (lineageType.equalsIgnoreCase("debugger")) } } + dmlOptions.applyGeneratedRewrites = line.hasOption("applyGeneratedRewrites"); + dmlOptions.fedStats = line.hasOption("fedStats"); if (dmlOptions.fedStats) { String fedStatsCount = line.getOptionValue("fedStats"); @@ -372,6 +375,7 @@ private static Options createCLIOptions() { Option ngramsOpt = OptionBuilder//.withArgName("ngrams") .withDescription("monitors and reports the most occurring n-grams; -ngrams ") .hasOptionalArgs(2).create("ngrams"); + Option applyGeneratedRewritesOpt = OptionBuilder.withArgName("applyGeneratedRewrites").withDescription("if automatically generated rewrites should be applied").create("applyGeneratedRewrites"); Option fedStatsOpt = OptionBuilder.withArgName("count") .withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter is 10 unless overridden; default off") .hasOptionalArg().create("fedStats"); @@ -434,6 +438,7 @@ private static Options createCLIOptions() { options.addOption(cleanOpt); options.addOption(statsOpt); options.addOption(ngramsOpt); + options.addOption(applyGeneratedRewritesOpt); options.addOption(fedStatsOpt); options.addOption(memOpt); options.addOption(explainOpt); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index d6853891e24..5777128396a 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -35,6 +35,8 @@ import java.util.Date; import java.util.Map; import java.util.Scanner; +import java.util.function.BiConsumer; +import java.util.function.Function; import org.apache.commons.cli.AlreadySelectedException; import org.apache.commons.cli.HelpFormatter; @@ -106,6 +108,7 @@ public class DMLScript public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams; // Set if N-Grams use lineage for data-dependent tracking public static boolean STATISTICS_NGRAMS_USE_LINEAGE = DMLOptions.defaultOptions.statsNGramsUseLineage; + public static boolean APPLY_GENERATED_REWRITES = DMLOptions.defaultOptions.applyGeneratedRewrites; // Set statistics maximum wrap length public static int STATISTICS_MAX_WRAP_LEN = 30; // Enable/disable to print federated statistics @@ -168,6 +171,9 @@ public class DMLScript public static String _uuid = IDHandler.createDistributedUniqueID(); private static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); + public static Function preHopInterceptor = null; // Intercepts HOPs before they are rewritten + public static Function hopInterceptor = null; // Intercepts HOPs after they are rewritten + /////////////////////////////// // public external interface //////// @@ -261,6 +267,7 @@ public static boolean executeScript( String[] args ) STATISTICS_NGRAMS = dmlOptions.statsNGrams; STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes; STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams; + APPLY_GENERATED_REWRITES = dmlOptions.applyGeneratedRewrites; FED_STATISTICS = dmlOptions.fedStats; FED_STATISTICS_COUNT = dmlOptions.fedStatsCount; JMLC_MEM_STATISTICS = dmlOptions.memStats; @@ -456,9 +463,15 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map params = new HashMap<>(); + params.put(DataExpression.RAND_ROWS, rows); + params.put(DataExpression.RAND_COLS, cols); + params.put(DataExpression.RAND_MIN, val); + params.put(DataExpression.RAND_MAX, val); + params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); + params.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0)); + params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); + params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); + + //note internal refresh size information + Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params); + datagen.setBlocksize(1000); + //copyLineNumbers(rowInput, datagen); + + if( value==0 ) + datagen.setNnz(0); + + return datagen; + } public static Hop createDataGenOp( Hop rowInput, Hop colInput, double value ) { @@ -661,6 +685,84 @@ public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean ou bop.refreshSizeInformation(); return bop; } + + // To fix issues with createBinary, which does not always correctly set value types (e.g. INT-MATRIX+FLOAT-SCALAR -> bop(+)::INT) + public static BinaryOp createAutoGeneratedBinary(Hop input1, Hop input2, OpOp2 op) { + Hop mainInput = input1.getDataType().isMatrix() ? input1 : + input2.getDataType().isMatrix() ? input2 : input1; + BinaryOp bop = new BinaryOp(mainInput.getName(), getImplicitDataType(input1, input2), + getImplicitValueType(input1, input2), op, input1, input2); + //cleanup value type for relational operations + if( bop.isPPredOperation() && bop.getDataType().isScalar() ) + bop.setValueType(ValueType.BOOLEAN); + bop.setOuterVectorOperation(false); + bop.setBlocksize(mainInput.getBlocksize()); + copyLineNumbers(mainInput, bop); + bop.refreshSizeInformation(); + return bop; + } + + public static DataType getImplicitDataType(Hop... inputs) { + for (int i = 0; i < inputs.length; i++) + if (inputs[i].getDataType().isMatrix()) + return inputs[i].getDataType(); + + return inputs[0].getDataType(); + } + + public static ValueType getImplicitValueType(Hop... inputs) { + ValueType out = null; + for (int i = 0; i < inputs.length; i++) { + switch (inputs[i].getValueType()) { + case FP64: + return inputs[i].getValueType(); + case FP32: + out = inputs[i].getValueType(); + break; + case INT64: + out = implicitValueType(out, ValueType.INT64); + break; + case INT32: + out = implicitValueType(out, ValueType.INT32); + break; + case BOOLEAN: + out = implicitValueType(out, ValueType.BOOLEAN); + break; + } + } + + return out == null ? inputs[0].getValueType() : out; + } + + private static ValueType implicitValueType(ValueType type1, ValueType type2) { + int rank1 = getTypeRank(type1); + int rank2 = getTypeRank(type2); + + if (rank1 == Integer.MIN_VALUE && rank2 == Integer.MIN_VALUE) + return null; + + return rank1 > rank2 ? type1 : type2; + } + + private static int getTypeRank(ValueType vt) { + if (vt == null) + return Integer.MIN_VALUE; + + switch (vt) { + case FP64: + return 5; + case FP32: + return 4; + case INT64: + return 3; + case INT32: + return 2; + case BOOLEAN: + return 1; + } + + return Integer.MIN_VALUE; + } public static AggUnaryOp createSum( Hop input ) { return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index b08d836efe5..357f2860bc4 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -27,6 +27,9 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.rewriter.generated.GeneratedRewriteClass; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.ForStatement; import org.apache.sysds.parser.ForStatementBlock; @@ -83,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse + if ( DMLScript.APPLY_GENERATED_REWRITES ) { + _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass())); + } if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) @@ -124,6 +130,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) if ( DMLScript.USE_ACCELERATOR ){ _dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites } + if ( DMLScript.APPLY_GENERATED_REWRITES ) { + _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass())); + } if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) { _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 ) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java b/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java new file mode 100644 index 00000000000..bfe9a1a880f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.HashMap; +import java.util.Optional; +import java.util.UUID; +import java.util.function.Function; + +/** + * This class is used to propagate dimension information. + * Each instruction that produces a matrix must be implemented here. + */ +public class MetaPropagator implements Function { + private final RuleContext ctx; + + public MetaPropagator(RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterStatement apply(RewriterStatement root) { + RewriterAssertions assertions = root.getAssertions(ctx); + MutableObject out = new MutableObject<>(root); + HashMap literalMap = new HashMap<>(); + + root.forEachPostOrderWithDuplicates((el, parent, pIdx) -> { + RewriterStatement toSet = propagateDims(el, parent, pIdx, assertions); + + if (toSet != null && toSet != el) { + el = toSet; + if (parent == null) + out.setValue(toSet); + else + parent.getOperands().set(pIdx, toSet); + } + + // Assert + if (el.getResultingDataType(ctx).startsWith("MATRIX") + && (el.getNCol() == null || el.getNRow() == null)) + throw new IllegalArgumentException("Some properties have not been set by the meta propagator: " + el.toString(ctx) + " :: " + el.getResultingDataType(ctx)); + + + // Eliminate common literals + if (el.isLiteral()) { + RewriterStatement existingLiteral = literalMap.get(el.getLiteral()); + + if (existingLiteral != null) { + if (parent == null) + out.setValue(existingLiteral); + else + parent.getOperands().set(pIdx, existingLiteral); + } else { + literalMap.put(el.getLiteral(), el); + } + } + + validate(el); + }); + + return out.getValue(); + } + + private RewriterStatement propagateDims(RewriterStatement root, RewriterStatement parent, int pIdx, RewriterAssertions assertions) { + if (root.getResultingDataType(ctx) == null) + throw new IllegalArgumentException("Null type: " + root.toParsableString(ctx)); + if (!root.getResultingDataType(ctx).startsWith("MATRIX")) { + if (root.isInstruction()) { + String ti = root.trueTypedInstruction(ctx); + RewriterStatement ret = null; + + switch (ti) { + case "ncol(MATRIX)": + ret = (RewriterStatement)root.getOperands().get(0).getMeta("ncol"); + break; + case "nrow(MATRIX)": + ret = (RewriterStatement)root.getOperands().get(0).getMeta("nrow"); + break; + } + + if (ret == null) + return null; + + RewriterStatement asserted = assertions != null ? assertions.getAssertionStatement(ret, parent) : null; + + if (asserted == null) + return ret; + + return asserted; + } + return null; + } + + Object colAccess; + Object rowAccess; + + if (root.getOperands() == null || root.getOperands().isEmpty()) { + RewriterStatement ncol = root.getNCol(); + + if (ncol == null) { + root.unsafePutMeta("ncol", new RewriterInstruction().withInstruction("ncol").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx)); + } + + RewriterStatement nrow = root.getNRow(); + + if (nrow == null) { + root.unsafePutMeta("nrow", new RewriterInstruction().withInstruction("nrow").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx)); + } + + return null; + } + + if (root.isInstruction()) { + Optional firstMatrixStatement = root.getOperands().stream().filter(el -> el.getResultingDataType(ctx).startsWith("MATRIX")).findFirst(); + switch(root.trueInstruction()) { + // Handle generators + case "rand": + root.unsafePutMeta("nrow", root.getOperands().get(0)); + root.unsafePutMeta("ncol", root.getOperands().get(1)); + return null; + case "as.matrix": + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "argList": + // We assume argLists always occur if the matrix properties don't change + root.unsafePutMeta("nrow", firstMatrixStatement.get().getMeta("nrow")); + root.unsafePutMeta("ncol", firstMatrixStatement.get().getMeta("ncol")); + return null; + case "_map": + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + return null; + case "+": + case "-": + case "*": + case "inv": + case "==": + case "!=": + case "&": + case "|": + case "<": + case ">": + case "abs": + case "round": + case "exp": + case "^": + if (firstMatrixStatement.isEmpty()) + throw new IllegalArgumentException(root.toString(ctx) + " has empty args!"); + root.unsafePutMeta("nrow", firstMatrixStatement.get().getMeta("nrow")); + root.unsafePutMeta("ncol", firstMatrixStatement.get().getMeta("ncol")); + return null; + case "cast.MATRIX": + String mDT = root.getChild(0).getResultingDataType(ctx); + if (mDT.equals("BOOL") || mDT.equals("INT") || mDT.equals("FLOAT")) { + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + } + case "log_nz": + case "log": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + } + + switch(root.trueTypedInstruction(ctx)) { + case "t(MATRIX)": + colAccess = root.getOperands().get(0).getMeta("ncol"); + rowAccess = root.getOperands().get(0).getMeta("nrow"); + root.unsafePutMeta("ncol", rowAccess); + root.unsafePutMeta("nrow", colAccess); + return null; + case "_m(INT,INT,FLOAT)": + case "_m(INT,INT,BOOL)": + case "_m(INT,INT,INT)": + if (root.getOperands().get(0).isInstruction() + && root.getOperands().get(0).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getOperands().get(1)); + } else { + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + } + + if (root.getOperands().get(1).isInstruction() + && root.getOperands().get(1).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) { + root.unsafePutMeta("ncol", root.getOperands().get(1).getOperands().get(1)); + } else { + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + } + return null; + case "%*%(MATRIX,MATRIX)": + rowAccess = root.getOperands().get(0).getMeta("nrow"); + colAccess = root.getOperands().get(1).getMeta("ncol"); + root.unsafePutMeta("nrow", rowAccess); + root.unsafePutMeta("ncol", colAccess); + return null; + case "diag(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "[](MATRIX,INT,INT,INT,INT)": + Long[] ints = new Long[4]; + + for (int i = 1; i < 5; i++) + if (root.getChild(i).isLiteral()) + if (root.getChild(i).getLiteral() instanceof Integer) + ints[i-1] = (Long)root.getChild(i).getLiteral(); + + if (ints[0] != null && ints[1] != null) { + String literalString = Long.toString(ints[1] - ints[0] + 1); + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse(literalString, ctx, "LITERAL_INT:" + literalString), ctx)); + } else { + HashMap subStmts = new HashMap<>(); + subStmts.put("i1", root.getOperands().get(2)); + subStmts.put("i0", root.getOperands().get(1)); + + if (ints[0] != null) { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, " + (1 - ints[0]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[0])), ctx)); + } else if (ints[1] != null) { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[1] + 1) + ", -(i0)))", ctx, subStmts, "LITERAL_INT:" + (ints[1] + 1)), ctx)); + } else { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, -(i0), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx)); + } + } + + if (ints[2] != null && ints[3] != null) { + root.unsafePutMeta("ncol", ints[3] - ints[2] + 1); + } else { + HashMap subStmts = new HashMap<>(); + subStmts.put("i3", root.getOperands().get(4)); + subStmts.put("i2", root.getOperands().get(3)); + if (ints[2] != null) { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, " + (1 - ints[2]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[2])), ctx)); + } else if (ints[3] != null) { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[3] + 1) + ", -(i2)))", ctx, subStmts, "LITERAL_INT:" + (ints[3] + 1)), ctx)); + } else { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, -(i2), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx)); + } + } + + return null; + case "rowSums(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "colSums(MATRIX)": + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "cast.MATRIX(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "RBind(MATRIX,MATRIX)": + HashMap mstmts = new HashMap<>(); + mstmts.put("row1", (RewriterStatement)root.getOperands().get(0).getMeta("nrow")); + mstmts.put("row2", (RewriterStatement)root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("nrow", RewriterUtils.parse("+(argList(row1, row2))", ctx, mstmts)); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "CBind(MATRIX,MATRIX)": + mstmts = new HashMap<>(); + mstmts.put("col1", (RewriterStatement)root.getOperands().get(0).getMeta("ncol")); + mstmts.put("col2", (RewriterStatement)root.getOperands().get(1).getMeta("ncol")); + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", RewriterUtils.parse("+(argList(col1, col2))", ctx, mstmts)); + return null; + + // Fused ops + case "1-*(MATRIX,MATRIX)": + case "log_nz(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "const(MATRIX,FLOAT)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "rowVec(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", RewriterStatement.literal(ctx, 1L)); + return null; + case "colVec(MATRIX)": + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L)); + return null; + case "cellMat(MATRIX)": + root.unsafePutMeta("ncol", RewriterStatement.literal(ctx, 1L)); + root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L)); + return null; + case "rev(MATRIX)": + case "replace(MATRIX,FLOAT,FLOAT)": + case "sumSq(MATRIX)": + case "+*(MATRIX,FLOAT,MATRIX)": + case "-*(MATRIX,FLOAT,MATRIX)": + case "*2(MATRIX)": + case "sq(MATRIX)": + case "!(MATRIX)": + root.unsafePutMeta("nrow", root.getChild(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getChild(0).getMeta("ncol")); + return null; + } + + RewriterInstruction instr = (RewriterInstruction) root; + + if (instr.getProperties(ctx).contains("ElementWiseInstruction")) { + if (root.getOperands().get(0).getResultingDataType(ctx).startsWith("MATRIX")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + } else { + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + } + + return null; + } + + if (instr.getProperties(ctx).contains("ElementWiseUnary.FLOAT")) { + if (root.getOperands().get(0).getResultingDataType(ctx).startsWith("MATRIX")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + } else { + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + } + + return null; + } + + throw new NotImplementedException("Unknown instruction: " + instr.trueTypedInstruction(ctx) + "\n" + instr.toParsableString(ctx)); + } + + return null; + } + + private void validate(RewriterStatement stmt) { + if (stmt.isInstruction()) { + if (stmt.trueInstruction().equals("_idx") && (stmt.getMeta("ownerId") == null || stmt.getMeta("idxId") == null)) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.trueInstruction().equals("_m") && stmt.getMeta("ownerId") == null) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.getResultingDataType(ctx) == null) + throw new IllegalArgumentException(stmt.toString(ctx)); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java new file mode 100644 index 00000000000..f1cc25fa095 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.List; +import java.util.Random; + +public class RewriterContextSettings { + + public static final List ALL_TYPES = List.of("FLOAT", "INT", "BOOL", "MATRIX"); + public static final List SCALARS = List.of("FLOAT", "INT", "BOOL"); + + public static String getDefaultContextString() { + StringBuilder builder = new StringBuilder(); + ALL_TYPES.forEach(t -> { + builder.append("argList(" + t + ")::" + t + "...\n"); + builder.append("argList(" + t + "...)::" + t + "...\n"); + }); // This is a meta function that can take any number of arguments + + builder.append("CBind(MATRIX,MATRIX)::MATRIX\n"); // This instruction is not really supported + builder.append("RBind(MATRIX,MATRIX)::MATRIX\n"); // This instruction is not really supported + + builder.append("sum(MATRIX)::FLOAT\n"); + builder.append("rowSums(MATRIX)::MATRIX\n"); + builder.append("colSums(MATRIX)::MATRIX\n"); + + builder.append("max(MATRIX)::FLOAT\n"); // Support for min/max is limited + builder.append("min(MATRIX)::FLOAT\n"); // Support for min/max is limited + + builder.append("%*%(MATRIX,MATRIX)::MATRIX\n"); + + builder.append("rev(MATRIX)::MATRIX\n"); + builder.append("t(MATRIX)::MATRIX\n"); + + RewriterUtils.buildBinaryPermutations(List.of("INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("BinaryScalarInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl ElementWiseInstruction\n"); + }); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("ElementWiseInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl ElementWiseSumExpandableInstruction\n"); + builder.append("impl /\n"); + builder.append("impl max\n"); + builder.append("impl min\n"); + builder.append("impl ^\n"); + builder.append("impl >\n"); + builder.append("impl <\n"); + builder.append("impl >=\n"); + builder.append("impl <=\n"); + builder.append("impl ==\n"); + builder.append("impl |\n"); + builder.append("impl &\n"); + builder.append("impl /\n"); + builder.append("impl !=\n"); + }); + + builder.append("ElementWiseInstruction(MATRIX...)::MATRIX\n"); + builder.append("impl ElementWiseSumExpandableInstruction\n"); + builder.append("impl /\n"); + builder.append("impl max\n"); + builder.append("impl min\n"); + builder.append("impl ^\n"); + builder.append("impl >\n"); + builder.append("impl <\n"); + builder.append("impl >=\n"); + builder.append("impl <=\n"); + builder.append("impl ==\n"); + builder.append("impl |\n"); + builder.append("impl &\n"); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("ElementWiseSumExpandableInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); // Any instruction that allows op(sum(A*), sum(B*)) = sum(op(A, B)) + builder.append("impl ElementWiseAdditiveInstruction\n"); + builder.append("impl *\n"); + + builder.append("ElementWiseAdditiveInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl +\n"); + builder.append("impl -\n"); + }); + + builder.append("ElementWiseAdditiveInstruction(MATRIX...)::MATRIX\n"); + builder.append("impl +\n"); + //builder.append("impl -\n"); + + + ALL_TYPES.forEach(t -> { + builder.append("UnaryElementWiseOperator(" + t + ")::" + t + "\n"); + builder.append("impl -\n"); + builder.append("impl abs\n"); + builder.append("impl !\n"); + builder.append("impl round\n"); + }); + + builder.append("rowSelect(MATRIX,INT,INT)::MATRIX\n"); + builder.append("colSelect(MATRIX,INT,INT)::MATRIX\n"); + builder.append("min(INT,INT)::INT\n"); + builder.append("max(INT,INT)::INT\n"); + + builder.append("index(MATRIX,INT,INT,INT,INT)::MATRIX\n"); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("FusableBinaryOperator(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl +\n"); + builder.append("impl *\n"); + }); + + List.of("MATRIX", "INT", "FLOAT", "BOOL").forEach(t -> { + builder.append("FusedOperator(" + t + "...)::" + t + "\n"); + builder.append("impl +\n"); + builder.append("impl *\n"); + }); + + builder.append("ncol(MATRIX)::INT\n"); + builder.append("nrow(MATRIX)::INT\n"); + builder.append("length(MATRIX)::INT\n"); + + RewriterUtils.buildBinaryAlgebraInstructions(builder, "+", List.of("INT", "FLOAT", "BOOL", "MATRIX")); + RewriterUtils.buildBinaryAlgebraInstructions(builder, "*", List.of("INT", "FLOAT", "BOOL", "MATRIX")); + RewriterUtils.buildBinaryAlgebraInstructions(builder, "^", ALL_TYPES); + ALL_TYPES.forEach(t -> builder.append("-(" + t + ")::" + t + "\n")); + ALL_TYPES.forEach(t -> builder.append("inv(" + t + ")::" + t + "\n")); + + + builder.append("as.matrix(INT)::MATRIX\n"); + builder.append("as.matrix(FLOAT)::MATRIX\n"); + builder.append("as.matrix(BOOL)::MATRIX\n"); + builder.append("as.scalar(MATRIX)::FLOAT\n"); + builder.append("as.scalar(FLOAT)::FLOAT\n"); + builder.append("as.float(INT)::FLOAT\n"); + builder.append("as.float(BOOL)::FLOAT\n"); + builder.append("as.int(BOOL)::INT\n"); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (tFrom, tTo) -> { + builder.append("cast." + tTo + "(" + tFrom + ")::" + tTo + "\n"); + }); + + builder.append("rand(INT,INT,FLOAT,FLOAT)::MATRIX\n"); // Args: rows, cols, min, max + builder.append("rand(INT,INT)::FLOAT\n"); // Just to make it possible to say that random is dependent on both matrix indices + builder.append("rand(INT...)::FLOAT\n"); + builder.append("matrix(INT,INT,INT)::MATRIX\n"); + + builder.append("trace(MATRIX)::FLOAT\n"); + + // Boole algebra + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX", "FLOAT", "INT", "BOOL"), (t1, t2) -> { + String ret = t1.equals("MATRIX") || t2.equals("MATRIX") ? "MATRIX" : "BOOL"; + builder.append("==(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("!=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("<(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("<=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append(">(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append(">=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("&(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("|(" + t1 + "," + t2 + ")::" + ret + "\n"); + }); + + List.of("MATRIX", "FLOAT", "INT", "BOOL").forEach(t -> { + builder.append("!(" + t + ")::" + (t.equals("MATRIX") ? "MATRIX" : "BOOL") + "\n"); + }); + + // Expressions that will be rewritten to an equivalent expression + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + builder.append("-(" + t1+ "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("/(" + t1+ "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + }); + + // Unary ops + ALL_TYPES.forEach(t -> { + builder.append("ElementWiseUnary.FLOAT(" + t + ")::" + (t.equals("MATRIX") ? "MATRIX" : "FLOAT") + "\n"); + builder.append("impl sqrt\n"); + builder.append("impl exp\n"); + builder.append("impl log\n"); + builder.append("impl inv\n"); + }); + + builder.append("[](MATRIX,INT,INT)::FLOAT\n"); + builder.append("[](MATRIX,INT,INT,INT,INT)::MATRIX\n"); + builder.append("diag(MATRIX)::MATRIX\n"); + builder.append("replace(MATRIX,FLOAT,FLOAT)::MATRIX\n"); // This is not supported + builder.append("_nnz(MATRIX)::INT\n"); + builder.append("sumSq(MATRIX)::FLOAT\n"); + builder.append("sq(MATRIX)::MATRIX\n"); + builder.append("+*(MATRIX,FLOAT,MATRIX)::MATRIX\n"); + builder.append("-*(MATRIX,FLOAT,MATRIX)::MATRIX\n"); + builder.append("*2(MATRIX)::MATRIX\n"); + + for (String t : SCALARS) { + for (String t2 : SCALARS) + builder.append("ifelse(BOOL," + t + "," + t2 + ")::" + RewriterUtils.convertibleType(t, t2) + "\n"); + } + + + List.of("INT", "FLOAT", "BOOL").forEach(t -> { + String newType = t.equals("BOOL") ? "INT" : t; + builder.append("sum(" + t + "...)::" + newType + "\n"); + builder.append("sum(" + t + "*)::" + newType + "\n"); + builder.append("sum(" + t + ")::" + newType + "\n"); + + builder.append("min(" + t + "...)::" + t + "\n"); + builder.append("min(" + t + "*)::" + t + "\n"); + builder.append("min(" + t + ")::" + t + "\n"); + + builder.append("max(" + t + "...)::" + t + "\n"); + builder.append("max(" + t + "*)::" + t + "\n"); + builder.append("max(" + t + ")::" + t + "\n"); + }); + + // Some fused operators + builder.append("1-*(MATRIX,MATRIX)::MATRIX\n"); // OpOp2.MINUS1_MULT + builder.append("log_nz(MATRIX)::MATRIX\n"); // OpOp1.LOG_NZ + SCALARS.forEach(t -> { + builder.append("log(MATRIX," + t + ")::MATRIX\n"); + builder.append("log_nz(MATRIX," + t + ")::MATRIX\n"); + }); + + builder.append("const(MATRIX,FLOAT)::MATRIX\n"); + + builder.append("rowVec(MATRIX)::MATRIX\n"); + builder.append("colVec(MATRIX)::MATRIX\n"); + builder.append("cellMat(MATRIX)::MATRIX\n"); + + builder.append("_m(INT,INT,FLOAT)::MATRIX\n"); + builder.append("_m(INT,INT,BOOL)::MATRIX\n"); + builder.append("_m(INT,INT,INT)::MATRIX\n"); + List.of("FLOAT", "INT", "BOOL").forEach(t -> { + builder.append("_idxExpr(INT," + t + ")::" + t + "*\n"); + builder.append("_idxExpr(INT," + t + "*)::" + t + "*\n"); + builder.append("_idxExpr(INT...," + t + ")::" + t + "*\n"); + builder.append("_idxExpr(INT...," + t + "*)::" + t + "*\n"); + }); + builder.append("_idx(INT,INT)::INT\n"); + + ALL_TYPES.forEach(t -> builder.append("_EClass(" + t + "...)::" + t + "\n")); + ALL_TYPES.forEach(t -> builder.append("_backRef." + t + "()::" + t + "\n")); + + for (String s : SCALARS) + builder.append("literal." + s + "()::" + s + "\n"); + + return builder.toString(); + } + public static RuleContext getDefaultContext() { + String ctxString = getDefaultContextString(); + + RuleContext ctx = RuleContext.createContext(ctxString); + + ctx.customStringRepr.put("rand(INT,INT,FLOAT,FLOAT)", (stmt, mctx) -> { + List ops = stmt.getOperands(); + return "rand(rows=(" + ops.get(0) + "), cols=(" + ops.get(1) + "), min=(" + ops.get(2) + "), max=(" + ops.get(3) + "))"; + }); + ctx.customStringRepr.put("rand(INT,INT,INT,INT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + ctx.customStringRepr.put("rand(INT,INT,FLOAT,INT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + ctx.customStringRepr.put("rand(INT,INT,INT,FLOAT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + + RewriterUtils.putAsDefaultBinaryPrintable(List.of("<", "<=", ">", ">=", "==", "!=", "&", "|"), List.of("INT", "FLOAT", "BOOL", "MATRIX"), ctx.customStringRepr); + + RewriterUtils.putAsBinaryPrintable("*", List.of("INT", "FLOAT", "MATRIX", "BOOL"), ctx.customStringRepr, RewriterUtils.binaryStringRepr(" * ")); + RewriterUtils.putAsBinaryPrintable("+", List.of("INT", "FLOAT", "MATRIX", "BOOL"), ctx.customStringRepr, RewriterUtils.binaryStringRepr(" + ")); + + ctx.customStringRepr.put("%*%(MATRIX,MATRIX)", RewriterUtils.binaryStringRepr(" %*% ")); + ctx.customStringRepr.put("&&(INT,INT)", RewriterUtils.binaryStringRepr(" && ")); + ctx.customStringRepr.put("index(MATRIX,INT,INT,INT,INT)", (stmt, ctx2) -> { + String out; + RewriterInstruction mInstr = (RewriterInstruction) stmt; + List ops = mInstr.getOperands(); + RewriterStatement op1 = ops.get(0); + + if (op1 instanceof RewriterDataType) + out = op1.toString(ctx2); + else + out = "(" + op1.toString(ctx2) + ")"; + + out += "[" + ops.get(1).toString(ctx2) + " : " + ops.get(2).toString(ctx2) + ", " + ops.get(3).toString(ctx2) + " : " + ops.get(4).toString(ctx2) + "]"; + return out; + }); + + return ctx; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java new file mode 100644 index 00000000000..5e42f7dbd63 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class RewriterDataType extends RewriterStatement { + private String id; + private String type; + private Object literal = null; + private boolean consolidated = false; + private int hashCode; + private RewriterStatement ncol; + private RewriterStatement nrow; + + @Override + protected void compress(RewriterAssertions assertions) { + if (literal != null) + id = null; + + if (meta != null) { + if (type.equals("MATRIX")) { + nrow = getNRow(); + ncol = getNCol(); + + if (assertions != null) { + RewriterStatement mAss1 = assertions.getAssertionStatement(nrow, null); + RewriterStatement mAss2 = assertions.getAssertionStatement(ncol, null); + + if (mAss1 != null) + nrow = mAss1; + + if (mAss2 != null) + ncol = mAss2; + } + } + } + } + + @Override + public RewriterStatement getNRow() { + if (nrow != null) + return nrow; + + return super.getNRow(); + } + + public void setNRow(RewriterStatement stmt) { + nrow = stmt; + } + + @Override + public RewriterStatement getNCol() { + if (ncol != null) + return ncol; + + return super.getNCol(); + } + + public void setNCol(RewriterStatement stmt) { + ncol = stmt; + } + + @Override + public String getId() { + return id; + } + + @Override + public String getResultingDataType(final RuleContext ctx) { + return type; + } + + @Override + public void refreshReturnType(final RuleContext ctx) {} + + @Override + public boolean isLiteral() { + return literal != null && !(literal instanceof List); + } + + @Override + public Object getLiteral() { + return literal; + } + + @Override + public long intLiteral(boolean cast) { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral() ? 1 : 0; + + if (cast && getLiteral() instanceof Double) { + double val = floatLiteral(); + return (long)val; + } + + return (long)getLiteral(); + } + + @Override + public double floatLiteral() { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral() ? 1 : 0; + if (getLiteral() instanceof Long) + return Double.valueOf((Long)getLiteral()); + return (double)getLiteral(); + } + + @Override + public boolean boolLiteral() { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral(); + if (getLiteral() instanceof Long) + return (long)getLiteral() == 0L; + return (double)getLiteral() == 0.0D; + } + + @Override + public void setLiteral(Object literal) { + if (consolidated) + throw new IllegalArgumentException(); + + this.literal = literal; + } + + @Override + public RewriterStatement getLiteralStatement() { + return this; + } + + @Override + public boolean isArgumentList() { + return false; + } + + @Override + public List getArgumentList() { + return null; + } + + @Override + public boolean isInstruction() { + return false; + } + + @Override + public boolean isEClass() { + return false; + } + + @Override + public String trueInstruction() { + return null; + } + + @Override + public String trueTypedInstruction(RuleContext ctx) { + return null; + } + + @Override + public String trueTypedInstruction(boolean allowImplicitConversions, RuleContext ctx) { + return null; + } + + @Override + public RewriterStatement consolidate(final RuleContext ctx) { + if (consolidated) + return this; + + if (!isLiteral() && (id == null || id.isEmpty())) + throw new IllegalArgumentException("The id of a data type cannot be empty"); + if (type == null ||type.isEmpty()) + throw new IllegalArgumentException("The type of a data type cannot be empty"); + + if (isLiteral()) + hashCode = Objects.hash(-1, -1, type, literal); + else + hashCode = Objects.hash(rid, refCtr, type); + return this; + } + + @Override + public int recomputeHashCodes(boolean recursively, final RuleContext ctx) { + if (isLiteral()) + hashCode = Objects.hash(-1, -1, type, literal); + else + hashCode = Objects.hash(rid, refCtr, type); + return hashCode; + } + + @Override + public int structuralHashCode() { + return hashCode; + } + + @Override + public RewriterStatement rename(String id) { + this.id = id; + return this; + } + + @Override + public int hashCode() { + if (isLiteral()) + return hashCode; + + return super.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (isLiteral()) + return o instanceof RewriterDataType && getLiteral().equals(((RewriterDataType)o).getLiteral()); + return super.equals(o); + } + + @Override + public int computeIds(int id) { + if (!isLiteral()) + return super.computeIds(id); + + rid = -1; + return id; + } + + @Override + public void computeRefCtrs() { + refCtr = -1; + } + + @Override + public boolean isConsolidated() { + return consolidated; + } + + @Override + public boolean match(final MatcherContext mCtx) { + RewriterStatement stmt = mCtx.currentStatement; + RuleContext ctx = mCtx.ctx; + String dType = stmt.getResultingDataType(ctx); + + if (!(stmt instanceof RewriterDataType) && !mCtx.statementsCanBeVariables) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + if (!dType.equals(type)) { + if (!mCtx.allowImplicitTypeConversions || !RewriterUtils.isImplicitlyConvertible(dType, type)) { + if (!mCtx.allowTypeHierarchy) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + Set types = ctx.typeHierarchy.get(dType); + if (types == null || !types.contains(type)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + } + + if (mCtx.literalsCanBeVariables) { + if (isLiteral()) { + if (!mCtx.ignoreLiteralValues && (!stmt.isLiteral() || !RewriterUtils.compareLiterals(this, (RewriterDataType)stmt, mCtx.allowImplicitTypeConversions))) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + } else { + if (isLiteral() != stmt.isLiteral()) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + if (!mCtx.ignoreLiteralValues && isLiteral() && !RewriterUtils.compareLiterals(this, (RewriterDataType)stmt, mCtx.allowImplicitTypeConversions)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + // If matrix, check if the dimensions + if (!mCtx.statementsCanBeVariables && dType.equals("MATRIX")) { + RewriterStatement ncolEquiv = getNCol(); + RewriterStatement nrowEquiv = getNRow(); + + if (ncolEquiv != null && nrowEquiv != null) { + if (!mCtx.wasVisited(this)) { + mCtx.dontVisitAgain(this); + RewriterStatement ncolEquivThat = stmt.getNCol(); + RewriterStatement nrowEquivThat = stmt.getNRow(); + + RewriterAssertions assertionsThis = mCtx.getOldAssertionsThis(); + RewriterAssertions assertionsThat = mCtx.getOldAssertionsThat(); + + if (assertionsThis != null) { + RewriterStatement ncolAssertion = assertionsThis.getAssertionStatement(ncolEquiv, null); + + RewriterStatement nrowAssertion = assertionsThis.getAssertionStatement(nrowEquiv, null); + ncolEquiv = ncolAssertion == null ? ncolEquiv : ncolAssertion; + nrowEquiv = nrowAssertion == null ? nrowEquiv : nrowAssertion; + } + + if (assertionsThat != null) { + RewriterStatement ncolAssertionThat = assertionsThat.getAssertionStatement(ncolEquivThat, null); + + RewriterStatement nrowAssertionThat = assertionsThat.getAssertionStatement(nrowEquivThat, null); + ncolEquivThat = ncolAssertionThat == null ? ncolEquiv : ncolAssertionThat; + nrowEquivThat = nrowAssertionThat == null ? nrowEquiv : nrowAssertionThat; + } + + // Now, match those statements + mCtx.currentStatement = ncolEquivThat; + if (!ncolEquiv.match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("MismatchNcolEquiv: " + ncolEquiv + " <=> " + ncolEquivThat); + return false; + } + mCtx.currentStatement = nrowEquivThat; + if (!nrowEquiv.match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("MismatchNrowEquiv: " + nrowEquiv + " <=> " + nrowEquivThat); + return false; + } + } + } + } + + RewriterStatement assoc = mCtx.getDependencyMap().get(this); + if (assoc == null) { + if (!mCtx.allowDuplicatePointers && mCtx.getDependencyMap().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + if (mCtx.isDebug()) + System.out.println("MismatchAssocNull: " + stmt); + return false; // Then the statement variable is already associated with another variable + } + mCtx.getDependencyMap().put(this, stmt); + return true; + } else if (assoc.equals(stmt)) { + return true; + } + + if (mCtx.isDebug()) + System.out.println("MismatchAssoc: " + stmt + " <=> " + assoc); + + mCtx.setFirstMismatch(this, stmt); + return false; + } + + @Override + public RewriterStatement clone() { + return new RewriterDataType().as(id).ofType(type); + } + + @Override + public RewriterStatement copyNode() { + return new RewriterDataType().as(id).ofType(type).asLiteral(literal); + } + + @Override + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx) { + RewriterStatement mCpy = copiedObjects.get(this); + if (mCpy != null) + return mCpy; + mCpy = injector.apply(this, parent, pIdx); + if (mCpy != null) { + // Then change the reference to the injected object + copiedObjects.put(this, mCpy); + return mCpy; + } + + RewriterDataType mCopy = new RewriterDataType(); + mCopy.id = id; + mCopy.type = type; + if (literal != null && literal instanceof List) { + final ArrayList mList = new ArrayList<>(((List)literal).size()); + mCopy.literal = mList; + ((List) literal).forEach(el -> { + if (el instanceof RewriterStatement) + mList.add(((RewriterStatement)el).nestedCopyOrInject(copiedObjects, injector)); + }); + } else + mCopy.literal = literal; + mCopy.consolidated = consolidated; + mCopy.hashCode = hashCode; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + copiedObjects.put(this, mCopy); + mCopy.nestedCopyOrInjectMetaStatements(copiedObjects, injector); + + return mCopy; + } + + @Override + public RewriterStatement simplify(final RuleContext ctx) { + return this; + } + + public String getType() { + return type; + } + + @Override + public RewriterDataType as(String id) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.id = id; + return this; + } + + public RewriterDataType ofType(String type) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.type = type; + return this; + } + + public RewriterDataType asLiteral(Object literal) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.literal = literal; + return this; + } + + @Override + public int toParsableString(StringBuilder sb, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx) { + String mType = type; + String varStr = id; + + if (isLiteral()) { + mType = "LITERAL_" + type; + varStr = getLiteral().toString(); + + if (getLiteral() instanceof Boolean) + varStr = varStr.toUpperCase(); + } + + Set varSet = vars.get(mType); + + if (varSet == null) { + varSet = new HashSet<>(); + vars.put(mType, varSet); + } + + varSet.add(varStr); + sb.append(varStr); + + return maxRefId; + } + + @Override + public String toString(final RuleContext ctx) { + if (!isLiteral()) + return getId() + "::" + getResultingDataType(ctx) + "[" + hashCode() + "]"; + + if (getLiteral() instanceof Boolean) + return getLiteral().toString().toUpperCase(); + + return getLiteral().toString() + "::" + getResultingDataType(ctx) + "[" + hashCode() + "]"; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java new file mode 100644 index 00000000000..7c38e4dc80d --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +public class RewriterDatabase { + + private ConcurrentHashMap db = new ConcurrentHashMap<>(); + + public void clear() { + db.clear(); + } + + public boolean containsEntry(RewriterStatement instr) { + return db.containsKey(instr); + } + + public boolean insertEntry(final RuleContext ctx, RewriterStatement stmt) { + return db.putIfAbsent(new RewriterStatementEntry(ctx, stmt), stmt) == null; + } + + public RewriterStatement find(final RuleContext ctx, RewriterStatement stmt) { + return db.get(new RewriterStatementEntry(ctx, stmt)); + } + + public RewriterStatement insertOrReturn(final RuleContext ctx, RewriterStatement stmt) { + return db.putIfAbsent(new RewriterStatementEntry(ctx, stmt), stmt); + } + + public void forEach(Consumer consumer) { + db.values().forEach(consumer); + } + + public void parForEach(Consumer consumer) { + db.values().parallelStream().forEach(consumer); + } + + public int size() {return db.size(); } + + public void serialize(BufferedWriter writer, final RuleContext ctx) throws IOException { + for (RewriterStatement entry : db.values()) { + writer.write("\n::STMT\n"); + writer.write(entry.toParsableString(ctx, true)); + } + } + + public void deserialize(BufferedReader reader, final RuleContext ctx) throws IOException { + List strBuffer = new ArrayList<>(); + + String line; + while ((line = reader.readLine()) != null) { + if (line.isBlank()) + continue; + + if (line.startsWith("::STMT")) { + if (strBuffer.isEmpty()) + continue; + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insertEntry(ctx, stmt); + strBuffer.clear(); + } catch (Exception e) { + System.err.println("An error occurred while parsing the string:\n" + String.join("\n", strBuffer)); + strBuffer.clear(); + e.printStackTrace(); + } + } else { + strBuffer.add(line); + } + } + + if (!strBuffer.isEmpty()) { + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insertEntry(ctx, stmt); + } catch (Exception e) { + e.printStackTrace(); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java new file mode 100644 index 00000000000..a134ecb893f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +public class RewriterEquivalenceDatabase { + private ConcurrentHashMap db = new ConcurrentHashMap<>(); + + public void clear() { + db.clear(); + } + + public boolean containsEntry(RewriterStatement instr) { + return db.containsKey(instr); + } + + public DBEntry insert(final RuleContext ctx, RewriterStatement canonicalForm, RewriterStatement equivalence) { + return db.compute(new RewriterStatementEntry(ctx, canonicalForm), (k, v) -> { + if (v == null) + return new DBEntry(canonicalForm, equivalence); + + v.insertEquivalence(equivalence); + return v; + }); + } + + public DBEntry find(final RuleContext ctx, RewriterStatement canonicalForm) { + return db.get(new RewriterStatementEntry(ctx, canonicalForm)); + } + + public void forEach(Consumer consumer) { + db.values().forEach(consumer); + } + + public void parForEach(Consumer consumer) { + db.values().parallelStream().forEach(consumer); + } + + public int size() {return db.size(); } + + @Deprecated + public void serialize(BufferedWriter writer, final RuleContext ctx) throws IOException { + for (DBEntry entry : db.values()) { + writer.write("\n::STMT\n"); + writer.write(entry.canonicalForm.toParsableString(ctx, true)); + } + } + + @Deprecated + public void deserialize(BufferedReader reader, final RuleContext ctx) throws IOException { + List strBuffer = new ArrayList<>(); + + String line; + while ((line = reader.readLine()) != null) { + if (line.isBlank()) + continue; + + if (line.startsWith("::STMT")) { + if (strBuffer.isEmpty()) + continue; + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insert(ctx, stmt, null); + strBuffer.clear(); + } catch (Exception e) { + System.err.println("An error occurred while parsing the string:\n" + String.join("\n", strBuffer)); + strBuffer.clear(); + e.printStackTrace(); + } + } else { + strBuffer.add(line); + } + } + + if (!strBuffer.isEmpty()) { + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insert(ctx, stmt, null); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + public static class DBEntry { + public final RewriterStatement canonicalForm; + public final List equivalences; + + public DBEntry(RewriterStatement canonicalForm, RewriterStatement firstEquivalence) { + this.canonicalForm = canonicalForm; + this.equivalences = new ArrayList<>(3); + + if (firstEquivalence != null) + this.equivalences.add(firstEquivalence); + } + + public void insertEquivalence(RewriterStatement equivalence) { + equivalences.add(equivalence); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java new file mode 100644 index 00000000000..074a6d1952f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.collections.list.SynchronizedList; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; +import scala.Tuple4; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterFramework { + + // To test the framework + public static void main(String[] args) { + String dbPath = "./src/test/resources/rewriterframework/expressions.db"; + RewriterFramework rwf = new RewriterFramework(dbPath); + rwf.init(true,true); + rwf.dataDrivenSearch(1000); + rwf.systematicSearch(3); + //rwf.randomSearch(4, 4, 5000); + rwf.createRules(true); + rwf.removeInvalidRules(); + // Note that unconditional rules are not 'static' rules. + // It is a set of equivalences that have a single optimal expression + System.out.println(rwf.getUnconditionalRuleSet()); + //rwf.removeInapplicableRules(); + //System.out.println(rwf.getUnconditionalRuleSet().toJavaCode("GeneratedRewriteClass", true)); + + /*RewriterRuleSet rs = loadRuleSet(rPath); + saveJavaCode(sPath, rs, "GeneratedRewriteClass", true);*/ + } + + + private RuleContext ctx; + private Function converter; + private RewriterDatabase db; + private String dbFile; + + private int BATCH_SIZE = 1000; + private int MAX_COST_SAMPLES = 50; + + private RewriterEquivalenceDatabase equivalenceDB; + private List foundEquivalences; + private boolean pruneNovelExpressions = false; + + private RewriterRuleCreator unconditionalRuleCreator; + private RewriterRuleSet conditionalRuleSet; + + public RewriterFramework(String dbFile) { + this.dbFile = dbFile; + } + + private void setupDataDrivenSearch() { + if (db != null && db.size() > 0) + return; // Then a database has already been loaded + + try(BufferedReader reader = new BufferedReader(new FileReader(dbFile))) { + db.deserialize(reader, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Initializes the rewriter framework + * @param allowInversionCanonicalization if the conversion from a/c => a*(c^-1) should be applied (during canonicalization) + * @param pruneNovelExpressions if only equivalence groups should be stored, where at least one expression was in the data-set + */ + public void init(boolean allowInversionCanonicalization, boolean pruneNovelExpressions) { + ctx = RewriterUtils.buildDefaultContext(); + converter = RewriterUtils.buildCanonicalFormConverter(ctx, allowInversionCanonicalization, false); + db = new RewriterDatabase(); + equivalenceDB = new RewriterEquivalenceDatabase(); + foundEquivalences = new ArrayList<>(); + this.pruneNovelExpressions = pruneNovelExpressions; + } + + public RuleContext getContext() { + return ctx; + } + + /** + * Performs a data-driven search where existing expressions and their subexpressions are considered + * @param exprPruningThreshold the maximum number of generated subexpressions (to avoid exploding numbers of subgraphs for big graphs) + */ + public void dataDrivenSearch(int exprPruningThreshold) { + setupDataDrivenSearch(); // Load the expression DB + + int size = db.size(); + RewriterDatabase exactExprDB = new RewriterDatabase(); + + MutableInt ctr = new MutableInt(0); + MutableInt failures = new MutableInt(0); + MutableInt generatedExpressions = new MutableInt(0); + MutableInt evaluatedExpressions = new MutableInt(0); + MutableInt totalCanonicalizationMillis = new MutableInt(0); + db.parForEach(expr -> { + if (ctr.incrementAndGet() % 10 == 0) + System.out.println("Done: " + ctr.intValue() + " / " + size); + + List subExprs = RewriterSearchUtils.generateSubtrees(expr, ctx, exprPruningThreshold); + if (subExprs.size() > exprPruningThreshold) + System.out.println("Critical number of subtrees: " + subExprs.size()); + if (subExprs.size() > 2 * exprPruningThreshold) { + System.out.println("Skipping subtrees..."); + subExprs = List.of(expr); + } + long evaluationCtr = 0; + long mCanonicalizationMillis = 0; + + for (RewriterStatement subExpr : subExprs) { + try { + if (!exactExprDB.insertEntry(ctx, subExpr)) + continue; + + evaluationCtr++; + + // Duplicate the statement as we do not want to canonicalize the original statement + long startMillis = System.currentTimeMillis(); + RewriterStatement canonicalForm = converter.apply(subExpr); + mCanonicalizationMillis += System.currentTimeMillis() - startMillis; + + synchronized (this) { + RewriterEquivalenceDatabase.DBEntry entry = equivalenceDB.insert(ctx, canonicalForm, subExpr); + + // Now, we use common variables + if (entry.equivalences.size() > 1) { + RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(subExpr, entry.equivalences.get(0), canonicalForm, entry.canonicalForm, ctx)._1; + entry.equivalences.set(entry.equivalences.size()-1, commonForm); + } + + if (entry.equivalences.size() == 2) + foundEquivalences.add(entry); + } + } catch (Exception e) { + try { + System.err.println("Error from expression: " + subExpr.toParsableString(ctx)); + } catch (Exception e2) { + } + e.printStackTrace(); + failures.incrementAndGet(); + } + } + + generatedExpressions.addAndGet(subExprs.size()); + evaluatedExpressions.addAndGet(evaluationCtr); + totalCanonicalizationMillis.addAndGet(mCanonicalizationMillis); + }); + } + + /** + * Performs a systematic search + * @param maxDepth the maximum number of (virtual) operands + */ + public void systematicSearch(int maxDepth) { + systematicSearch(0, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxDepth), true, false); + } + + /** + * Performs a systematic search + * @param maxDepth the maximum number of (virtual) operands + * @param includeDuplicateReferences if the search space should be extended to contain a shared variable (e.g. +(A,B) => [+(A,B), +(A,A)]) + */ + public void systematicSearch(int maxDepth, boolean includeDuplicateReferences) { + systematicSearch(0, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxDepth), includeDuplicateReferences, false); + } + + /** + * Performs a systematic search + * @param fromIdx the start index + * @param toIdx the end index + * @param includeDuplicateReferences if the search space should be extended to contain a shared variable (e.g. +(A,B) => [+(A,B), +(A,A)]) + * @param includeRowColVectors if row-vectors and col-vectors should be included in the search (note that the data-driven approach does not support this) + */ + public void systematicSearch(int fromIdx, int toIdx, boolean includeDuplicateReferences, boolean includeRowColVectors) { + int diff = toIdx - fromIdx; + int maxN = toIdx; + + for (int batch = 0; batch < 10000 && batch * BATCH_SIZE < diff; batch++) { + List indices = IntStream.range(fromIdx + batch * BATCH_SIZE, fromIdx + Math.min((batch + 1) * BATCH_SIZE - 1, maxN)).boxed().collect(Collectors.toList()); + Collections.shuffle(indices); + MutableInt ctr2 = new MutableInt(0); + int maxSize = indices.size(); + final int mBATCH = batch; + indices.parallelStream().forEach(idx -> { + if (ctr2.incrementAndGet() % 10 == 0) + System.out.println("Done: " + (mBATCH * BATCH_SIZE + ctr2.intValue()) + " / " + (mBATCH * BATCH_SIZE + maxSize)); + + + List ops = RewriterSearchUtils.decodeOrderedStatements(idx); + List stmts = RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true); + + for (RewriterStatement dag : stmts) { + List expanded = new ArrayList<>(); + expanded.add(dag); + if (includeDuplicateReferences) + expanded.addAll(RewriterSearchUtils.buildVariations(dag, ctx)); + if (includeRowColVectors) + expanded.addAll(RewriterSearchUtils.buildAssertionVariations(dag, ctx)); + + insertEquivalences(expanded); + } + }); + } + } + + public void randomSearch(int minExprSize, int maxExprSize, int numSamples) { + randomSearchFromIndex(RewriterSearchUtils.getMaxSearchNumberForNumOps(minExprSize-1)+1, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxExprSize), numSamples, true, false); + } + + /** + * Performs a random search. Samples numSamples expression groups (groups of expressions encoded by a single integer) + * @param fromIdx the start index + * @param toIdx the end index + * @param numSamples the number of sampmles + * @param includeDuplicateReferences if expressions such as +(A,A) should be included in the search + * @param includeRowColVectors if row-col vectors should be included in the search + */ + public void randomSearchFromIndex(int fromIdx, int toIdx, int numSamples, boolean includeDuplicateReferences, boolean includeRowColVectors) { + // Now we will just do random sampling for a few rounds + Random rd = new Random(42); + for (int batch = 0; batch < 200 && batch * BATCH_SIZE < numSamples; batch++) { + List indices = IntStream.range(batch * BATCH_SIZE, (batch + 1) * BATCH_SIZE - 1).boxed().map(v -> fromIdx + rd.nextInt(toIdx-fromIdx)).collect(Collectors.toList()); + MutableInt ctr2 = new MutableInt(0); + int maxSize = indices.size(); + final int mBATCH = batch; + indices.parallelStream().forEach(idx -> { + if (ctr2.incrementAndGet() % 10 == 0) + System.out.println("Done: " + (mBATCH * BATCH_SIZE + ctr2.intValue()) + " / " + (mBATCH * BATCH_SIZE + maxSize)); + + List ops = RewriterSearchUtils.decodeOrderedStatements(idx); + List stmts = RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true); + + for (RewriterStatement dag : stmts) { + List expanded = new ArrayList<>(); + expanded.add(dag); + if (includeDuplicateReferences) + expanded.addAll(RewriterSearchUtils.buildVariations(dag, ctx)); + if (includeRowColVectors) + expanded.addAll(RewriterSearchUtils.buildAssertionVariations(dag, ctx)); + + insertEquivalences(expanded); + } + }); + } + } + + private void insertEquivalences(List stmts) { + for (RewriterStatement stmt : stmts) { + try { + RewriterStatement canonicalForm = converter.apply(stmt); + + synchronized (this) { + if (pruneNovelExpressions && !equivalenceDB.containsEntry(canonicalForm)) + return; + + RewriterEquivalenceDatabase.DBEntry entry = equivalenceDB.insert(ctx, canonicalForm, stmt); + + // Now, we use common variables + if (entry.equivalences.size() > 1) { + RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(stmt, entry.equivalences.get(0), canonicalForm, entry.canonicalForm, ctx)._1; + entry.equivalences.set(entry.equivalences.size()-1, commonForm); + } + + if (entry.equivalences.size() == 2) + foundEquivalences.add(entry); + } + } catch (Exception e) { + System.err.println("Faulty expression: " + stmt.toParsableString(ctx)); + e.printStackTrace(); + } + } + } + + /** + * Create rules from all observed equivalences + * @param freeDBMemory if all the stored equivalences that are not needed for rule creation should be dropped immediately + */ + public void createRules(boolean freeDBMemory) { + System.out.println("===== SUGGESTED REWRITES ====="); + List, Long, Boolean>> rewrites = findSuggestedRewrites(foundEquivalences, MAX_COST_SAMPLES); + + if (freeDBMemory) { + db.clear(); + foundEquivalences.clear(); + equivalenceDB.clear(); + } + + // Here, we create any rule + List> allRules = new ArrayList<>(); + int mCtr = 0; + for (Tuple4, Long, Boolean> rewrite : rewrites) { + if (++mCtr % 100 == 0) + System.out.println("Creating rule: " + mCtr + " / " + rewrites.size()); + + try { + RewriterRule rule; + if (rewrite._4()) + rule = RewriterRuleCreator.createRuleFromCommonStatements(rewrite._1(), rewrite._2().get(0), ctx); + else + rule = RewriterRuleCreator.createConditionalRuleFromCommonStatements(rewrite._1(), rewrite._2(), ctx); + + allRules.add(new Tuple4<>(rule, rewrite._3(), rule.getStmt1().countInstructions(), rewrite._4())); + } catch (Exception e) { + System.err.println("An error occurred while trying to create a rule:"); + System.err.println(rewrite._1().toParsableString(ctx, true)); + for (RewriterStatement stmt : rewrite._2()) + System.err.println(stmt.toParsableString(ctx, true)); + e.printStackTrace(); + } + } + + System.out.println("Rule creation complete!"); + + allRules.sort(Comparator.comparing(Tuple4::_3)); + + System.out.println("Rules sorted!"); + + unconditionalRuleCreator = new RewriterRuleCreator(ctx); + List conditionalRules = new ArrayList<>(); + + mCtr = 0; + + for (Tuple4 t : allRules) { + if (++mCtr % 100 == 0) + System.out.println("Registering rule: " + mCtr + " / " + allRules.size()); + + try { + // First, without validating correctness + // This might throw out some fallback options if a rule turns out to be incorrect but we there is a huge performance benefit + if (!t._1().isConditionalMultiRule()) { + unconditionalRuleCreator.registerRule(t._1(), converter, ctx); + } else { + conditionalRules.add(t._1()); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + conditionalRuleSet = new RewriterRuleSet(ctx, conditionalRules); + } + + /** + * This function removes rules where the output of the origin expression does not match + * the output of the target expression. + */ + public void removeInvalidRules() { + unconditionalRuleCreator.throwOutInvalidRules(true, false); + } + + /** + * This function removes rules where the origin expression is modified by the HOP-DAG rewriter. + * We aim to remove rules that are already implemented by intercepting the HOP-DAG after rewriting. + * We disable operator fusion and sum-product rewrites during execution. + * However, we throw away any rule that does not match our expected DAG structure, which may affect + * valid rules that are not correctly extracted during runtime. + */ + public void removeInapplicableRules() { + unconditionalRuleCreator.throwOutInvalidRules(false, true); + } + + /** + * + * @return the unconditional rule set (includes rules where there is exactly one possible optimum per equality set) + */ + public RewriterRuleSet getUnconditionalRuleSet() { + return unconditionalRuleCreator.getRuleSet(); + } + + /** + * + * @return the conditional rule set (rules where the optimal expression may change, e.g., (A*B)+(A*C) <=> A*(B+C)) + */ + public RewriterRuleSet getConditionalRuleSet() { + return conditionalRuleSet; + } + + public static boolean saveRuleSet(String filePath, RewriterRuleSet ruleSet) { + try (FileWriter writer = new FileWriter(filePath)) { + writer.write(ruleSet.serialize()); + } catch (IOException ex) { + ex.printStackTrace(); + return false; + } + + return true; + } + + public static RewriterRuleSet loadRuleSet(String filePath) { + try { + List lines = Files.readAllLines(Paths.get(filePath)); + return RewriterRuleSet.deserialize(lines, RewriterUtils.buildDefaultContext()); + } catch (IOException ex) { + ex.printStackTrace(); + return null; + } + } + + public static boolean saveJavaCode(String filePath, RewriterRuleSet ruleSet, String className, boolean optimize) { + try (FileWriter writer = new FileWriter(filePath)) { + writer.write(ruleSet.toJavaCode(className, optimize)); + } catch (IOException ex) { + ex.printStackTrace(); + return false; + } + + return true; + } + + /** + * This function computes rewrite suggestions based on cost-estimates. To enable random sampling, sample_size should be bigger than 1. + * Note that random sampling might generate incorrect suggestions due to inaccurate cost-estimates (especially for fused ops) + * @param equivalences + * @param sample_size how many sparsity and dimension values should be sampled; a sample size of 1 uses a fixed cost esimtate with ncols=nrows=2000 and fully dense matrices + * @return + */ + private List, Long, Boolean>> findSuggestedRewrites(List equivalences, int sample_size) { + List, Long, Boolean>> suggestions = SynchronizedList.decorate(new ArrayList<>()); + + AtomicLong idCtr = new AtomicLong(); + equivalences.parallelStream().forEach(entry -> { + try { + List mEq = entry.equivalences; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(mEq.get(0), ctx); + + for (int i = 1; i < mEq.size(); i++) + RewriterAssertionUtils.buildImplicitAssertions(mEq.get(i), assertions, ctx); + + List, List>> costs = RewriterCostEstimator.compareCosts(mEq, assertions, ctx, true, sample_size); + + Set> rewriteProposals = RewriterCostEstimator.findOptima(costs); + long mId = idCtr.incrementAndGet(); + + if (!rewriteProposals.isEmpty()) { + int targetIdx = rewriteProposals.stream().findFirst().get()._2; + boolean hasOneTarget = rewriteProposals.stream().allMatch(t -> t._2 == targetIdx); + + // Group by origin expression + Map>> grouped = rewriteProposals.stream().collect(Collectors.groupingBy(Tuple2::_1)); + + for (List> proposalsFromSameOrigin : grouped.values()) { + suggestions.add(new Tuple4<>(mEq.get(proposalsFromSameOrigin.get(0)._1), proposalsFromSameOrigin.stream().map(t -> mEq.get(t._2)).collect(Collectors.toList()), mId, hasOneTarget)); + } + } + } catch (Exception e) { + //e.printStackTrace(); + } + }); + + return suggestions; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java new file mode 100644 index 00000000000..bfd9ca615db --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterInstruction extends RewriterStatement { + + private String id; + private String returnType; + private String instr; + private ArrayList operands = new ArrayList<>(); + private Function, Long> costFunction = null; + private boolean consolidated = false; + private int hashCode; + + public RewriterInstruction() { + } + + public RewriterInstruction(String instr, final RuleContext ctx, RewriterStatement... ops) { + id = UUID.randomUUID().toString(); + this.instr = instr; + withOps(ops); + consolidate(ctx); + } + + @Override + protected void compress(RewriterAssertions assertions) { + id = null; + operands.trimToSize(); + meta = null; + } + + @Override + public String getId() { + if (isDataOrigin()) { + if (trueInstruction().equals("const")) { + boolean regen = id == null; + if (!regen) { + try { + UUID.fromString(id); + regen = true; + } catch (Exception e) { + } + } + if (regen) { + id = "mConst" + new Random().nextInt(10000); + } + } else { + return getChild(0).getId(); + } + } + + return id; + } + + @Override + public String getResultingDataType(final RuleContext ctx) { + if (returnType != null) + return returnType; + + if (isArgumentList()) + returnType = getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "..."; + else + returnType = ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx); + + if (returnType == null) + throw new IllegalArgumentException("Return type not found for: " + trueTypedInstruction(ctx)); + + return returnType; + } + + @Override + public void refreshReturnType(final RuleContext ctx) { + returnType = null; + } + + @Override + public boolean isLiteral() { + return false; + } + + @Override + public Object getLiteral() { + return null; + } + + @Override + public RewriterStatement getLiteralStatement() { + for (RewriterStatement op : getChild(0).getOperands()) + if (op.isLiteral()) + return op; + + return null; + } + + @Override + public long intLiteral(boolean cast) { + throw new UnsupportedOperationException(); + } + + @Override + public double floatLiteral() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean boolLiteral() { + throw new UnsupportedOperationException(); + } + + @Override + public RewriterStatement consolidate(final RuleContext ctx) { + if (consolidated) + return this; + + if (instr == null || instr.isEmpty()) + throw new IllegalArgumentException("Instruction type cannot be empty"); + + if (getCostFunction(ctx) == null) + throw new IllegalArgumentException("Could not find a matching cost function for " + typedInstruction(ctx)); + + for (RewriterStatement operand : operands) + operand.consolidate(ctx); + + hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands); + consolidated = true; + + return this; + } + @Override + public int recomputeHashCodes(boolean recursively, final RuleContext ctx) { + if (recursively) { + operands.forEach(op -> op.recomputeHashCodes(true, ctx)); + } + + hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands.stream().map(RewriterStatement::structuralHashCode).collect(Collectors.toList())); + return hashCode; + } + + @Override + public boolean isConsolidated() { + return consolidated; + } + + @Override + public boolean match(final MatcherContext mCtx) { + RewriterStatement stmt = mCtx.currentStatement; + RuleContext ctx = mCtx.ctx; + + if (mCtx.isDebug()) + System.out.println("Matching: " + this.toString(ctx) + " <=> " + stmt.toString(ctx)); + + // Check for some meta information + if (mCtx.statementsCanBeVariables && getResultingDataType(ctx).equals("MATRIX")) { + if ((trueInstruction().equals("rowVec") && stmt.isRowVector()) + || (trueInstruction().equals("colVec") && stmt.isColVector())) { + RewriterStatement existingRef = mCtx.findInternalReference(this); + + if (existingRef != null) { + if (existingRef == stmt) + return true; + else { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + if (!mCtx.allowDuplicatePointers && mCtx.getInternalReferences().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + mCtx.getInternalReferences().put(this, stmt); + + if (stmt.isInstruction() && (stmt.trueInstruction().equals("rowVec") || stmt.trueInstruction().equals("colVec"))) + mCtx.getDependencyMap().put(getChild(0), stmt.getChild(0)); + else + mCtx.getDependencyMap().put(getChild(0), stmt); + + + return true; + } + } + + if (stmt instanceof RewriterInstruction && (getResultingDataType(ctx).equals(stmt.getResultingDataType(ctx)) || (mCtx.allowImplicitTypeConversions && RewriterUtils.isImplicitlyConvertible(stmt.getResultingDataType(ctx), getResultingDataType(ctx))))) { + RewriterInstruction inst = (RewriterInstruction)stmt; + + if(!inst.instr.equals(this.instr)) { + if (!mCtx.allowPropertyScan) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + Set props = inst.getProperties(ctx); + + if (props == null || !props.contains(typedInstruction(ctx))) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + if (this.operands.size() != inst.operands.size()) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + RewriterStatement existingRef = mCtx.findInternalReference(this); + + if (existingRef != null) { + if (existingRef == stmt) + return true; + else { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + if (!mCtx.allowDuplicatePointers && mCtx.getInternalReferences().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + RewriterRule.LinkObject ruleLink = mCtx.ruleLinks.get(this); + + if (ruleLink != null) + mCtx.getLinks().add(new RewriterRule.ExplicitLink(inst, ruleLink.stmt, ruleLink.transferFunction)); + + int s = inst.operands.size(); + + if (mCtx.findMinimalMismatchRoot) { + int mismatchCtr = 0; + + for (int i = 0; i < s; i++) { + mCtx.currentStatement = inst.operands.get(i); + + if (!operands.get(i).match(mCtx)) + mismatchCtr++; + } + + if (mismatchCtr == 0) + mCtx.getInternalReferences().put(this, stmt); + else if (mismatchCtr > 1) + mCtx.setFirstMismatch(this, stmt); + + return mismatchCtr == 0; + } else { + for (int i = 0; i < s; i++) { + mCtx.currentStatement = inst.operands.get(i); + + if (!operands.get(i).match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("Mismatch: " + operands.get(i) + " <=> " + inst.operands.get(i)); + return false; + } + } + + mCtx.getInternalReferences().put(this, stmt); + return true; + } + } + + mCtx.setFirstMismatch(this, stmt); + return false; + } + + @Override + public RewriterStatement copyNode() { + RewriterInstruction mCopy = new RewriterInstruction(); + mCopy.instr = instr; + mCopy.id = id; + mCopy.costFunction = costFunction; + mCopy.consolidated = consolidated; + mCopy.operands = new ArrayList<>(operands); + mCopy.returnType = returnType; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + else + mCopy.meta = null; + return mCopy; + } + + @Override + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx) { + RewriterStatement mCpy = copiedObjects.get(this); + if (mCpy != null) + return mCpy; + mCpy = injector.apply(this, parent, pIdx); + if (mCpy != null) { + // Then change the reference to the injected object + copiedObjects.put(this, mCpy); + return mCpy; + } + + RewriterInstruction mCopy = new RewriterInstruction(); + mCopy.instr = instr; + mCopy.id = id; + mCopy.costFunction = costFunction; + mCopy.consolidated = consolidated; + mCopy.operands = new ArrayList<>(operands.size()); + mCopy.returnType = returnType; + mCopy.hashCode = hashCode; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + else + mCopy.meta = null; + mCopy.nestedCopyOrInjectMetaStatements(copiedObjects, injector); + copiedObjects.put(this, mCopy); + + for (int i = 0; i < operands.size(); i++) + mCopy.operands.add(operands.get(i).nestedCopyOrInject(copiedObjects, injector, mCopy, i)); + + return mCopy; + } + + @Override + public boolean isArgumentList() { + return trueInstruction().equals("argList"); + } + + @Override + public List getArgumentList() { + return isArgumentList() ? getOperands() : null; + } + + @Override + public boolean isInstruction() { + return true; + } + + @Override + public boolean isEClass() { + return trueInstruction().equals("_EClass"); + } + + @Deprecated + @Override + public RewriterStatement clone() { + RewriterInstruction mClone = new RewriterInstruction(); + mClone.instr = instr; + mClone.id = id; + ArrayList clonedOperands = new ArrayList<>(operands.size()); + + for (RewriterStatement stmt : operands) + clonedOperands.add(stmt.clone()); + + mClone.operands = clonedOperands; + mClone.costFunction = costFunction; + mClone.consolidated = consolidated; + mClone.returnType = returnType; + mClone.meta = meta; + return mClone; + } + + @Override + public List getOperands() { + return operands == null ? Collections.emptyList() : operands; + } + + + @Override + public RewriterStatement simplify(final RuleContext ctx) { + for (int i = 0; i < operands.size(); i++) { + RewriterStatement stmt = operands.get(i).simplify(ctx); + if (stmt != null) + operands.set(i, stmt); + } + + Function rule = ctx.simplificationRules.get(typedInstruction(ctx)); + if (rule != null) { + RewriterStatement stmt = rule.apply(this); + + if (stmt != null) + return stmt; + } + return this; + } + + public RewriterInstruction withInstruction(String instr) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.instr = instr; + return this; + } + + public RewriterInstruction withOps(RewriterStatement... operands) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands = new ArrayList<>(Arrays.asList(operands)); + return this; + } + + public RewriterInstruction addOp(String id) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands.add(new RewriterDataType().as(id)); + return this; + } + + public RewriterInstruction addOp(RewriterStatement operand) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands.add(operand); + return this; + } + + public RewriterInstruction ofType(String type) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + RewriterStatement stmt = this.operands.get(this.operands.size()-1); + + if (stmt instanceof RewriterDataType) + ((RewriterDataType)stmt).ofType(type); + else + throw new IllegalArgumentException("Can only set the data type of RewriterDataType class"); + + return this; + } + + public Function, Long> getCostFunction(final RuleContext ctx) { + if (this.costFunction == null) + this.costFunction = ctx.instrCosts.get(typedInstruction(ctx)); + + return this.costFunction; + } + + public RewriterInstruction withCostFunction(Function, Long> costFunction) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.costFunction = costFunction; + return this; + } + + public Optional findOperand(String id) { + return this.operands.stream().filter(op -> op.getId().equals(id)).findFirst(); + } + + @Override + public RewriterInstruction as(String id) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.id = id; + return this; + } + + public String typedInstruction(final RuleContext ctx) { + return typedInstruction(this.instr, false, ctx); + } + + public String getInstr() { + return instr; + } + + private String typedInstruction(String instrName, boolean allowImplicitConversions, final RuleContext ctx) { + StringBuilder builder = new StringBuilder(); + builder.append(instrName); + builder.append("("); + + if (!operands.isEmpty()) { + String resultingDataType = operands.get(0).getResultingDataType(ctx); + if (allowImplicitConversions) + resultingDataType = RewriterUtils.convertImplicitly(resultingDataType); + builder.append(resultingDataType); + } + + if (!isArgumentList()) { + for (int i = 1; i < operands.size(); i++) { + builder.append(","); + String resultingDataType = operands.get(i).getResultingDataType(ctx); + if (allowImplicitConversions) + resultingDataType = RewriterUtils.convertImplicitly(resultingDataType); + builder.append(resultingDataType); + } + } + + builder.append(")"); + return builder.toString(); + } + + @Override + public int toParsableString(StringBuilder sb, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx) { + Integer ref = refs.get(this); + + if (ref != null) { + sb.append('$'); + sb.append(ref); + return maxRefId; + } + + if (refCtr > 1 || forceCreateRefs.contains(this)) { + maxRefId++; + sb.append('$'); + sb.append(maxRefId); + sb.append(':'); + refs.put(this, maxRefId); + } + + sb.append(instr); + sb.append('('); + + for (int i = 0; i < getOperands().size(); i++) { + if (i > 0) + sb.append(','); + + RewriterStatement op = getOperands().get(i); + maxRefId = op.toParsableString(sb, refs, maxRefId, vars, forceCreateRefs, ctx); + } + + sb.append(')'); + + return maxRefId; + } + + @Override + public String toString(final RuleContext ctx) { + Object varName = getMeta(META_VARNAME); + if (varName != null) + return varName.toString(); + + Object trueInstrObj = getMeta("trueInstr"); + String typedInstr = trueInstrObj != null ? typedInstruction((String)trueInstrObj, false, ctx) : typedInstruction(ctx); + BiFunction customStringFunc = ctx.customStringRepr.get(typedInstr); + if (customStringFunc != null) + return customStringFunc.apply(this, ctx); + + String instrName = meta == null ? instr : meta.getOrDefault("trueName", instr).toString(); + + StringBuilder builder = new StringBuilder(); + builder.append(instrName); + builder.append("("); + for (int i = 0; i < operands.size(); i++) { + if (i > 0) + builder.append(", "); + builder.append(operands.get(i).toString(ctx)); + } + builder.append(")"); + return builder + "[" + System.identityHashCode(this) + "]"; + } + + @Override + public int structuralHashCode() { + return hashCode; + } + + @Override + public RewriterStatement rename(String id) { + this.id = id; + return this; + } + + public String changeConsolidatedInstruction(String newName, final RuleContext ctx) { + String typedInstruction = newName; + String newInstrReturnType = ctx.instrTypes.get(typedInstruction); + if (newInstrReturnType == null || !newInstrReturnType.equals(getResultingDataType(ctx))) + throw new IllegalArgumentException("An instruction name can only be changed if it has the same signature (return type) [" + typedInstruction + "::" + newInstrReturnType + " <-> " + typedInstruction(ctx) + "::" + getResultingDataType(ctx) + "]"); + String oldName = instr; + instr = newName.substring(0, newName.indexOf('(')); + recomputeHashCodes(false, ctx); + return oldName; + } + + public boolean hasProperty(String property, final RuleContext ctx) { + Set properties = getProperties(ctx); + + if (properties == null) + return false; + + return properties.contains(property); + } + + public String trueInstruction() { + return instr; + } + + public String trueTypedInstruction(final RuleContext ctx) { + return typedInstruction(trueInstruction(), false, ctx); + } + + public String trueTypedInstruction(boolean allowImplicitConversions, final RuleContext ctx) { + return typedInstruction(trueInstruction(), allowImplicitConversions, ctx); + } + + public Set getProperties(final RuleContext ctx) { + Set ret = ctx.instrProperties.get(trueTypedInstruction(ctx)); + if (ret == null) + return Collections.emptySet(); + return ret; + } + + public void unsafeSetInstructionName(String str) { + this.instr = str; + } + +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java new file mode 100644 index 00000000000..9ed52b1506e --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java @@ -0,0 +1,938 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.IndexingOp; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; + +import javax.annotation.Nullable; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterRuntimeUtils { + public static final boolean interceptAll = false; + public static boolean printUnknowns = false; + public static final String dbFile = "./src/test/resources/rewriterframework/expressions.db"; + public static final boolean readDB = true; + public static final boolean writeDB = true; + + private static boolean setupComplete = false; + + private static HashMap unknownOps = new HashMap<>(); + private static boolean ENFORCE_FLOAT_OBSERVATIONS = true; // To force every data type to float + private static boolean OBSERVE_SELECTIONS = false; + private static boolean OBSERVE_RAND = false; + + public static void setupIfNecessary() { + if (setupComplete) + return; + + setupComplete = true; + + if (interceptAll) { + System.out.println("INTERCEPTOR"); + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; + System.out.println("OptLevel:" + OptimizerUtils.getOptLevel().toString()); + System.out.println("AllowOpFusion: " + OptimizerUtils.ALLOW_OPERATOR_FUSION); + System.out.println("AllowSumProductRewrites: " + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES); + System.out.println("AllowConstantFolding: " + OptimizerUtils.ALLOW_CONSTANT_FOLDING); + + // Setup default context + RuleContext ctx = RewriterUtils.buildDefaultContext(); + + RewriterDatabase exactExprDB = new RewriterDatabase(); + + if (readDB) { + try(BufferedReader reader = new BufferedReader(new FileReader(dbFile))) { + exactExprDB.deserialize(reader, ctx); + } catch (IOException ex) { + ex.printStackTrace(); + } + } + + RewriterRuntimeUtils.attachPreHopInterceptor(prog -> { + RewriterRuntimeUtils.forAllUniqueTranslatableStatements(prog, 4, mstmt -> {}, exactExprDB, ctx); + return true; // We will continue to extract the rewritten hop + }); + + RewriterRuntimeUtils.attachHopInterceptor(prog -> { + RewriterRuntimeUtils.forAllUniqueTranslatableStatements(prog, 4, mstmt -> {}, exactExprDB, ctx); + return false; // Then we cancel the excecution to save time + }); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + if (writeDB) { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(dbFile))) { + exactExprDB.serialize(writer, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + })); + } + } + + public static void attachHopInterceptor(Function interceptor) { + DMLScript.hopInterceptor = interceptor; + } + + public static void detachHopInterceptor() { + DMLScript.hopInterceptor = null; + } + + public static void attachPreHopInterceptor(Function interceptor) { + DMLScript.preHopInterceptor = interceptor; + } + + public static void detachPreHopInterceptor() { + DMLScript.preHopInterceptor = null; + } + + public static RewriterStatement buildDAGFromHop(Hop hop, int maxDepth, boolean mindDataCharacteristics, final RuleContext ctx) { + RewriterStatement out = buildDAGRecursively(hop, null, new HashMap<>(), 0, maxDepth, ctx); + + if (mindDataCharacteristics) + return populateDataCharacteristics(out, ctx); + + return out; + } + + public static RewriterStatement populateDataCharacteristics(RewriterStatement stmt, final RuleContext ctx) { + if (stmt == null) + return null; + + if (stmt instanceof RewriterDataType && stmt.getResultingDataType(ctx).equals("MATRIX")) { + Long nrow = (Long) stmt.getMeta("_actualNRow"); + Long ncol = (Long) stmt.getMeta("_actualNCol"); + int matType = 0; + + if (nrow != null && nrow == 1L) { + matType = 1; + } else if (ncol != null && ncol == 1L) { + matType = 2; + } + + if (matType > 0) { + return new RewriterInstruction() + .as(stmt.getId()) + .withInstruction(matType == 1L ? "rowVec" : "colVec") + .withOps(stmt) + .consolidate(ctx); + } + } + + Map createdObjects = new HashMap<>(); + + stmt.forEachPostOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child instanceof RewriterDataType && child.getResultingDataType(ctx).equals("MATRIX")) { + Long nrow = (Long) child.getMeta("_actualNRow"); + Long ncol = (Long) child.getMeta("_actualNCol"); + int matType = 0; + + if (nrow != null && nrow == 1L) { + matType = 1; + } else if (ncol != null && ncol == 1L) { + matType = 2; + } + + if (matType > 0) { + RewriterStatement created = createdObjects.get(child); + + if (created == null) { + created = new RewriterInstruction() + .as(stmt.getId()) + .withInstruction(matType == 1 ? "rowVec" : "colVec") + .withOps(child) + .consolidate(ctx); + createdObjects.put(child, created); + } + + cur.getOperands().set(i, created); + } + } + } + }, false); + + return stmt; + } + + public static void forAllUniqueTranslatableStatements(DMLProgram program, int maxDepth, Consumer stmt, RewriterDatabase db, final RuleContext ctx) { + try { + Set visited = new HashSet<>(); + + for (String namespaceKey : program.getNamespaces().keySet()) { + for (String fname : program.getFunctionStatementBlocks(namespaceKey).keySet()) { + FunctionStatementBlock fsblock = program.getFunctionStatementBlock(namespaceKey, fname); + handleStatementBlock(fsblock, maxDepth, stmt, visited, db, ctx); + } + } + + for (StatementBlock sb : program.getStatementBlocks()) { + handleStatementBlock(sb, maxDepth, stmt, visited, db, ctx); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + private static void handleStatementBlock(StatementBlock sb, int maxDepth, Consumer consumer, Set visited, RewriterDatabase db, final RuleContext ctx) { + if (sb instanceof FunctionStatementBlock) + { + FunctionStatementBlock fsb = (FunctionStatementBlock) sb; + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof WhileStatementBlock) + { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + forAllUniqueTranslatableStatements(wsb.getPredicateHops(), maxDepth, consumer, visited, db, ctx); + wstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof IfStatementBlock) + { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + forAllUniqueTranslatableStatements(isb.getPredicateHops(), maxDepth, consumer, visited, db, ctx); + istmt.getIfBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + istmt.getElseBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof ForStatementBlock) + { + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + forAllUniqueTranslatableStatements(fsb.getFromHops(), maxDepth, consumer, visited, db, ctx); + forAllUniqueTranslatableStatements(fsb.getToHops(), maxDepth, consumer, visited, db, ctx); + forAllUniqueTranslatableStatements(fsb.getIncrementHops(), maxDepth, consumer, visited, db, ctx); + fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else + { + if (sb.getHops() != null) + sb.getHops().forEach(hop -> forAllUniqueTranslatableStatements(hop, maxDepth, consumer, visited, db, ctx)); + } + } + + private static void forAllUniqueTranslatableStatements(Hop currentHop, int maxDepth, Consumer consumer, Set visited, RewriterDatabase db, final RuleContext ctx) { + if (currentHop == null || visited.contains(currentHop)) + return; + + visited.add(currentHop); + RewriterStatement stmt = buildDAGRecursively(currentHop, null, new HashMap<>(), 0, maxDepth, ctx); + + if (stmt instanceof RewriterInstruction) + stmt = ctx.metaPropagator.apply(stmt); + + if (stmt == null) { + // TODO: What to do about TWrite and PWrite? + // Just ignore these ops? + if (!currentHop.getOpString().startsWith("TWrite") && !currentHop.getOpString().startsWith("PWrite") && !currentHop.getValueType().toString().equals("STRING") && !currentHop.getOpString().startsWith("LiteralOp") && !currentHop.getOpString().startsWith("fcall") && !currentHop.getOpString().startsWith("TRead") && !currentHop.getOpString().startsWith("PRead")) + unknownOps.compute(currentHop.getOpString() + "::" + currentHop.getDataType() + "::" + currentHop.getValueType(), (k, v) -> v == null ? 1 : v + 1); + } + + if (stmt != null) { + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + } + + if (stmt != null && db.insertEntry(ctx, stmt)) { + RewriterStatement cpy = stmt.nestedCopyOrInject(new HashMap<>(), el -> null); + consumer.accept(cpy); + } + + if (currentHop.getInput() != null) + currentHop.getInput().forEach(child -> forAllUniqueTranslatableStatements(child, maxDepth, consumer, visited, db, ctx)); + } + + private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String expectedType, Map cache, int depth, int maxDepth, final RuleContext ctx) { + if (depth == maxDepth) + return buildLeaf(next, expectedType, ctx); + + if (cache.containsKey(next)) + return checkForCorrectTypes(cache.get(next), expectedType, next, ctx); + + if (next instanceof LiteralOp) { + RewriterStatement literal = buildLiteral((LiteralOp)next, expectedType, ctx); + literal = checkForCorrectTypes(literal, expectedType, next, ctx); + cache.put(next, literal); + return literal; + } + + if (next instanceof AggBinaryOp) { + RewriterStatement stmt = buildAggBinaryOp((AggBinaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof AggUnaryOp) { + RewriterStatement stmt = buildAggUnaryOp((AggUnaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof BinaryOp) { + RewriterStatement stmt = buildBinaryOp((BinaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof ReorgOp) { + RewriterStatement stmt = buildReorgOp((ReorgOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof UnaryOp) { + RewriterStatement stmt = buildUnaryOp((UnaryOp)next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof IndexingOp) { + RewriterStatement stmt = buildIndexingOp((IndexingOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof DataGenOp) { + List interestingHops = new ArrayList<>(); + RewriterStatement stmt = buildDataGenOp((DataGenOp)next, expectedType, ctx, interestingHops); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, interestingHops, cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof DataOp) { + DataOp dop = (DataOp) next; + + if (dop.isRead()) + return buildLeaf(next, expectedType, ctx); + } + + if (printUnknowns) { + System.out.println("Unknown Op: " + next); + System.out.println("Class: " + next.getClass().getSimpleName()); + System.out.println("OPString: " + next.getOpString()); + } + + return null; + } + + private static void insertDataCharacteristics(Hop hop, RewriterStatement stmt, final RuleContext ctx) { + if (stmt.getResultingDataType(ctx).equals("MATRIX")) { + if (hop.getDataCharacteristics() != null) { + long nrows = hop.getDataCharacteristics().getRows(); + long ncols = hop.getDataCharacteristics().getCols(); + if (nrows > 0) + stmt.unsafePutMeta("_actualNRow", nrows); + if (ncols > 0) + stmt.unsafePutMeta("_actualNCol", ncols); + } + } + } + + private static RewriterStatement checkForCorrectTypes(RewriterStatement stmt, @Nullable String expectedType, Hop hop, final RuleContext ctx) { + if (stmt == null) + return null; + + if (expectedType == null) + expectedType = stmt.getResultingDataType(ctx); + + String actualType = resolveExactDataType(hop); + + if (actualType == null) + return null; + + if (actualType.equals(expectedType)) + return stmt; + + if (actualType.equals("MATRIX")) { + HashMap oldTypes = new HashMap<>(); + oldTypes.put("A", stmt); + RewriterStatement newStmt = RewriterUtils.parseExpression("as.matrix(A)", new HashMap<>(), oldTypes, ctx); + return newStmt; + } + + return null; + } + + private static RewriterStatement buildLeaf(Hop hop, @Nullable String expectedType, final RuleContext ctx) { + String hopName = hop.getName(); + + // Check if hopName collides with literal values + if (RewriterUtils.LONG_PATTERN.matcher(hopName).matches()) + hopName = "int" + new Random().nextInt(1000); + if (RewriterUtils.DOUBLE_PATTERN.matcher(hopName).matches() || RewriterUtils.SPECIAL_FLOAT_PATTERN.matcher(hopName).matches()) + hopName = "float" + new Random().nextInt(1000); + + if (expectedType != null) { + RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, expectedType + ":" + hopName); + insertDataCharacteristics(hop, stmt, ctx); + return stmt; + } + + switch (hop.getDataType()) { + case SCALAR: + return buildScalarLeaf(hop, hopName, ctx); + case MATRIX: + RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, "MATRIX:" + hopName); + insertDataCharacteristics(hop, stmt, ctx); + return stmt; + } + + return null; // Not supported then + } + + private static RewriterStatement buildScalarLeaf(Hop hop, final RuleContext ctx) { + return buildScalarLeaf(hop, null, ctx); + } + + private static RewriterStatement buildScalarLeaf(Hop hop, @Nullable String newName, final RuleContext ctx) { + if (newName == null) + newName = hop.getName(); + + switch (hop.getValueType()) { + case FP64: + case FP32: + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + case INT64: + case INT32: + if (ENFORCE_FLOAT_OBSERVATIONS) + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + return RewriterUtils.parse(newName, ctx, "INT:" + newName); + case BOOLEAN: + if (ENFORCE_FLOAT_OBSERVATIONS) + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + return RewriterUtils.parse(newName, ctx, "BOOL:" + newName); + } + + return null; // Not supported then + } + + private static boolean buildInputs(RewriterStatement stmt, List inputs, Map cache, boolean fixedSize, int depth, int maxDepth, final RuleContext ctx) { + if (fixedSize && stmt.getOperands().size() != inputs.size()) + return false; + + List children = new ArrayList<>(); + int ctr = 0; + for (Hop in : inputs) { + RewriterStatement childStmt = buildDAGRecursively(in, fixedSize ? stmt.getOperands().get(ctr).getResultingDataType(ctx) : null, cache, depth + 1, maxDepth, ctx); + + if (childStmt == null) { + //System.out.println("Could not build child: " + in); + // TODO: Then just build leaf + //return false; + childStmt = buildLeaf(in, stmt.getOperands().get(ctr).getResultingDataType(ctx), ctx); + + if (childStmt == null) + return false; + } + + if (fixedSize && !RewriterUtils.convertImplicitly(childStmt.getResultingDataType(ctx), ENFORCE_FLOAT_OBSERVATIONS).equals(stmt.getOperands().get(ctr).getResultingDataType(ctx))) + throw new IllegalArgumentException("Different data type than expected: " + stmt.toString(ctx) + "; [" + ctr + "] " + childStmt.toString(ctx) + " ::" + childStmt.getResultingDataType(ctx)); + + children.add(childStmt); + ctr++; + } + + stmt.getOperands().clear(); + stmt.getOperands().addAll(children); + stmt.consolidate(ctx); + return true; + } + + private static RewriterStatement buildIndexingOp(IndexingOp op, @Nullable String expectedType, final RuleContext ctx) { + if (!OBSERVE_SELECTIONS) + return null; + + if (expectedType == null) { + expectedType = resolveExactDataType(op); + + if (expectedType == null) + return null; + } + + switch (op.getOpString()) { + case "rix": + return RewriterUtils.parse("[](A, i, j, k, l)", ctx, "MATRIX:A", "INT:i,j,k,l"); + } + + return null; + } + + private static RewriterStatement buildUnaryOp(UnaryOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType == null) { + expectedType = resolveExactDataType(op); + + if (expectedType == null) + return null; + } + + String fromType = resolveExactDataType(op.getInput(0)); + Types.DataType toDT = op.getDataType(); + + if (!toDT.isMatrix() && !toDT.isScalar()) + return null; + + switch(op.getOpString()) { + case "u(castdts)": + if (toDT.isMatrix()) + return RewriterUtils.parse("cast.MATRIX(A)", ctx, "MATRIX:A"); + if (fromType != null) + return RewriterUtils.parse("cast." + expectedType + "(A)", ctx, fromType + ":A"); + + return null; + case "u(castdtm)": + if (fromType != null) + return RewriterUtils.parse("cast.MATRIX(a)", ctx, fromType + ":a"); + + return null; + case "u(sqrt)": + return RewriterUtils.parse("sqrt(A)", ctx, fromType + ":A"); + case "u(!)": + return RewriterUtils.parse("!(A)", ctx, fromType + ":A"); + case "u(ncol)": + return RewriterUtils.parse("ncol(A)", ctx, "MATRIX:A"); + case "u(nrow)": + return RewriterUtils.parse("nrow(A)", ctx, "MATRIX:A"); + case "u(length)": + return RewriterUtils.parse("length(A)", ctx, "MATRIX:A"); + case "u(exp)": + return RewriterUtils.parse("exp(A)", ctx, fromType + ":A"); + case "u(round)": + return RewriterUtils.parse("round(A)", ctx, fromType + ":A"); + case "u(abs)": + return RewriterUtils.parse("abs(A)", ctx, fromType + ":A"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown UnaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildAggBinaryOp(AggBinaryOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + // Some placeholder definitions + switch(op.getOpString()) { + case "ba(+*)": // Matrix multiplication + return RewriterUtils.parse("%*%(A, B)", ctx, "MATRIX:A,B"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown AggBinaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildAggUnaryOp(AggUnaryOp op, @Nullable String expectedType, final RuleContext ctx) { + // Some placeholder definitions + switch(op.getOpString()) { + case "ua(+C)": // Matrix multiplication + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("colSums(A)", ctx, "MATRIX:A"); + case "ua(+R)": + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException("Unexpected type:" + expectedType); + return RewriterUtils.parse("rowSums(A)", ctx, "MATRIX:A"); + case "ua(+RC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("sum(A)", ctx, "MATRIX:A"); + case "ua(nrow)": + if (expectedType != null && !expectedType.equals("INT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("nrow(A)", ctx, "MATRIX:A"); + case "ua(ncol)": + if (expectedType != null && !expectedType.equals("INT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("ncol(A)", ctx, "MATRIX:A"); + case "ua(maxRC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("max(A)", ctx, "MATRIX:A"); + case "ua(minRC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("min(A)", ctx, "MATRIX:A"); + case "ua(traceRC)": + return RewriterUtils.parse("trace(A)", ctx, "MATRIX:A"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown AggUnaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildBinaryOp(BinaryOp op, @Nullable String expectedType, final RuleContext ctx) { + String t1 = resolveExactDataType(op.getInput().get(0)); + String t2 = resolveExactDataType(op.getInput().get(1)); + + if (t1 == null || t2 == null) + return null; + + t1 += ":a"; + t2 += ":b"; + + RewriterStatement parsed = null; + + switch(op.getOpString()) { + case "b(+)": // Addition + parsed = RewriterUtils.parse("+(a, b)", ctx, t1, t2); + break; + case "b(*)": // Matrix multiplication + parsed = RewriterUtils.parse("*(a, b)", ctx, t1, t2); + break; + case "b(-)": + parsed = RewriterUtils.parse("-(a, b)", ctx, t1, t2); + break; + case "b(/)": + parsed = RewriterUtils.parse("/(a, b)", ctx, t1, t2); + break; + case "b(||)": + parsed = RewriterUtils.parse("|(a, b)", ctx, t1, t2); + break; + case "b(!=)": + parsed = RewriterUtils.parse("!=(a, b)", ctx, t1, t2); + break; + case "b(==)": + parsed = RewriterUtils.parse("==(a, b)", ctx, t1, t2); + break; + case "b(&&)": + parsed = RewriterUtils.parse("&(a, b)", ctx, t1, t2); + break; + case "b(<)": + parsed = RewriterUtils.parse("<(a, b)", ctx, t1, t2); + break; + case "b(>)": + parsed = RewriterUtils.parse(">(a, b)", ctx, t1, t2); + break; + case "b(>=)": + parsed = RewriterUtils.parse(">=(a, b)", ctx, t1, t2); + break; + case "b(<=)": + parsed = RewriterUtils.parse("<=(a, b)", ctx, t1, t2); + break; + case "b(^)": + parsed = RewriterUtils.parse("^(a, b)", ctx, t1, t2); + break; + case "b(rbind)": + if (!t1.equals("MATRIX") || !t2.equals("MATRIX")) + return null; + return RewriterUtils.parse("RBind(a, b)", ctx, t1, t2); + case "b(cbind)": + if (!t1.equals("MATRIX") || !t2.equals("MATRIX")) + return null; + return RewriterUtils.parse("CBind(a, b)", ctx, t1, t2); + case "b(1-*)": + return RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B"); + } + + if (parsed != null) + return parsed.rename(op.getName()); + + if (printUnknowns) + DMLExecutor.println("Unknown BinaryOp: " + op.getOpString()); + return null; + } + + private static String resolveExactDataType(Hop hop) { + if (hop.getDataType() == Types.DataType.MATRIX) + return "MATRIX"; + + switch (hop.getValueType()) { + case FP64: + case FP32: + return "FLOAT"; + case INT64: + case INT32: + if (ENFORCE_FLOAT_OBSERVATIONS) + return "FLOAT"; + return "INT"; + case BOOLEAN: + if (ENFORCE_FLOAT_OBSERVATIONS) + return "FLOAT"; + return "BOOL"; + } + + if (printUnknowns) + DMLExecutor.println("Unknown type: " + hop + " -> " + hop.getDataType() + " : " + hop.getValueType()); + + return null; + } + + private static RewriterStatement buildReorgOp(ReorgOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + switch(op.getOpString()) { + case "r(r')": // Matrix multiplication + return RewriterUtils.parse("t(A)", ctx, "MATRIX:A"); + case "r(rev)": + return RewriterUtils.parse("rev(A)", ctx, "MATRIX:A"); + case "r(rdiag)": + return RewriterUtils.parse("diag(A)", ctx, "MATRIX:A"); + } + + //System.out.println("Unknown BinaryOp: " + op.getOpString()); + if (printUnknowns) + DMLExecutor.println("Unknown ReorgOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildDataGenOp(DataGenOp op, @Nullable String expectedType, final RuleContext ctx, List interestingHops) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + switch(op.getOpString()) { + case "dg(rand)": + if (OBSERVE_RAND) { + interestingHops.add(op.getParam("rows")); + interestingHops.add(op.getParam("cols")); + interestingHops.add(op.getParam("min")); + interestingHops.add(op.getParam("max")); + return RewriterUtils.parse("rand(i1, i2, f1, f2)", ctx, "INT:i1,i2", "FLOAT:f1,f2").rename(op.getName()); + } + return null; + } + + return null; + } + + private static RewriterStatement buildLiteral(LiteralOp literal, @Nullable String expectedType, final RuleContext ctx) { + if (literal.getDataType() != Types.DataType.SCALAR) + return null; // Then it is not supported yet + + String mType; + Object mValue; + + switch (literal.getValueType()) { + case FP64: + case FP32: + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType("FLOAT").asLiteral(literal.getDoubleValue()).consolidate(ctx); + case INT32: + case INT64: + if (expectedType != null) { + if (expectedType.equals("INT")) { + mType = expectedType; + mValue = literal.getLongValue(); + } else if (expectedType.equals("FLOAT")) { + mType = "FLOAT"; + mValue = (double)literal.getLongValue(); + } else { + throw new IllegalArgumentException(); + } + } else { + mType = "INT"; + mValue = literal.getLongValue(); + } + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType(mType).asLiteral(mValue).consolidate(ctx); + case BOOLEAN: + if (expectedType != null) { + if (expectedType.equals("FLOAT")) { + mType = expectedType; + mValue = literal.getBooleanValue() ? 1.0D : 0.0D; + } else if (expectedType.equals("INT")) { + mType = expectedType; + mValue = literal.getBooleanValue() ? 1L : 0L; + } else if (expectedType.equals("BOOL")) { + mType = expectedType; + mValue = literal.getBooleanValue(); + } else { + throw new IllegalArgumentException(); + } + } else { + mType = "BOOL"; + mValue = literal.getBooleanValue(); + } + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType(mType).asLiteral(mValue).consolidate(ctx); + default: + return null; // Not supported yet + } + } + + public static boolean executeScript(String script) { + try { + return DMLScript.executeScript(new String[]{"-s", script}); + } catch (Exception ex) { + ex.printStackTrace(); + return false; + } + } + + + /** + * Validates matrix dimensions to ensure that broadcasting still works afer the transformation + * @param hop1 the first HOP + * @param hop2 the second HOP + * @return if the new binary op would work in terms of broadcasting + */ + public static boolean validateBinaryBroadcasting(Hop hop1, Hop hop2) { + if (hop1.isMatrix() && hop2.isMatrix()) { + if (!hop1.dimsKnown() || !hop2.dimsKnown()) + return false; + + if (hop1.getDim1() == hop2.getDim1()) { + if (hop1.getDim2() == hop2.getDim2()) + return true; // Then both dimensions match + + return hop2.getDim2() == 1; // Otherwise we require a column vector + } else if (hop1.getDim2() == hop2.getDim2()) { + return hop2.getDim1() == 1; // We require a row vector + } + + // At least one dimension must match + return false; + } + + return true; + } + + public static boolean hasMatchingDims(Hop hop1, Hop hop2) { + return hop1.dimsKnown() && hop2.dimsKnown() && hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2(); + } + + public static boolean hasMatchingDims(Hop... hops) { + if (hops.length < 2) + return true; + + for (Hop hop : hops) + if (!hop.dimsKnown()) + return false; + + long dim1 = hops[0].getDim1(); + long dim2 = hops[0].getDim2(); + + for (int i = 1; i < hops.length; i++) + if (hops[i].getDim1() != dim1 && hops[i].getDim2() != dim2) + return false; + + return true; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java new file mode 100644 index 00000000000..e6e633100a7 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -0,0 +1,1092 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.logging.log4j.util.TriConsumer; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +public abstract class RewriterStatement { + public static final String META_VARNAME = "_varName"; + + + protected int rid = 0; + public int refCtr = 0; + protected long cost = -2; + + protected HashMap meta = null; + + + public static class MatchingSubexpression { + private final RewriterStatement expressionRoot; + private final RewriterStatement matchRoot; + private final RewriterPredecessor pred; + private final Map assocs; + private final List links; + public RewriterStatement newExprRoot; + + public MatchingSubexpression(RewriterStatement expressionRoot, RewriterStatement matchRoot, RewriterPredecessor pred, Map assocs, List links) { + this.expressionRoot = expressionRoot; + this.matchRoot = matchRoot; + this.pred = pred; + this.assocs = assocs; + this.links = links; + } + + public RewriterStatement getExpressionRoot() { + return expressionRoot; + } + + public RewriterStatement getMatchRoot() { + return matchRoot; + } + + public RewriterPredecessor getPredecessor() { + return pred; + } + + public Map getAssocs() { + return assocs; + } + + public List getLinks() { + return links; + } + + public RewriterStatement getNewExprRoot() { + return newExprRoot; + } + + public void setNewExprRoot(RewriterStatement exprRoot) { + newExprRoot = exprRoot; + } + } + + public static class MatcherContext { + final RuleContext ctx; + final boolean statementsCanBeVariables; + final boolean literalsCanBeVariables; + final boolean ignoreLiteralValues; + final boolean allowDuplicatePointers; + final boolean allowPropertyScan; + final boolean allowTypeHierarchy; + final boolean terminateOnFirstMatch; + final boolean findMinimalMismatchRoot; + final boolean traceVariableEliminations; + final boolean allowImplicitTypeConversions; + final Map ruleLinks; + final RewriterStatement expressionRoot; + final RewriterStatement thisExpressionRoot; + RewriterStatement matchRoot; + RewriterPredecessor pred; + + public RewriterStatement currentStatement; + + private Map dependencyMap; + private List links; + private DualHashBidiMap internalReferences; + + private List subMatches; + private Tuple2 firstMismatch; + private boolean debug; + private boolean assertionsFetched = false; + private RewriterAssertions assertionsThat; + private RewriterAssertions assertionsThis; + private Set dontVisitAgain; + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot) { + this(ctx, matchRoot, expressionRoot, thisExpressionRoot, false, false, false, false, false, false, false, false, false, Collections.emptyMap()); + } + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map ruleLinks) { + this.ctx = ctx; + this.matchRoot = matchRoot; + this.pred = new RewriterPredecessor(); + this.expressionRoot = expressionRoot; + this.thisExpressionRoot = thisExpressionRoot; + this.statementsCanBeVariables = statementsCanBeVariables; + this.currentStatement = matchRoot; + this.literalsCanBeVariables = literalsCanBeVariables; + this.ignoreLiteralValues = ignoreLiteralValues; + this.allowDuplicatePointers = allowDuplicatePointers; + this.allowPropertyScan = allowPropertyScan; + this.allowTypeHierarchy = allowTypeHierarchy; + this.terminateOnFirstMatch = terminateOnFirstMatch; + this.ruleLinks = ruleLinks; + this.findMinimalMismatchRoot = findMinimalMismatchRoot; + this.traceVariableEliminations = traceVariableEliminations; + this.allowImplicitTypeConversions = false; + this.debug = false; + } + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterPredecessor pred, RewriterStatement expressionRoot, RewriterStatement thisExprRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, boolean allowImplicitTypeConversions, final Map ruleLinks) { + this.ctx = ctx; + this.matchRoot = matchRoot; + this.pred = pred; + this.expressionRoot = expressionRoot; + this.thisExpressionRoot = thisExprRoot; + this.currentStatement = matchRoot; + this.statementsCanBeVariables = statementsCanBeVariables; + this.literalsCanBeVariables = literalsCanBeVariables; + this.ignoreLiteralValues = ignoreLiteralValues; + this.allowDuplicatePointers = allowDuplicatePointers; + this.allowPropertyScan = allowPropertyScan; + this.allowTypeHierarchy = allowTypeHierarchy; + this.terminateOnFirstMatch = terminateOnFirstMatch; + this.ruleLinks = ruleLinks; + this.findMinimalMismatchRoot = findMinimalMismatchRoot; + this.traceVariableEliminations = traceVariableEliminations; + this.allowImplicitTypeConversions = allowImplicitTypeConversions; + this.debug = false; + } + + private void fetchAssertions() { + if (!assertionsFetched) { + assertionsThat = (RewriterAssertions) expressionRoot.getMeta("_assertions"); + assertionsThis = (RewriterAssertions) thisExpressionRoot.getMeta("_assertions"); + assertionsFetched = true; + } + } + + public boolean allowsImplicitTypeConversions() { + return allowImplicitTypeConversions; + } + + public void dontVisitAgain(RewriterStatement stmt) { + if (dontVisitAgain == null) { + dontVisitAgain = new HashSet<>(); + } + + dontVisitAgain.add(stmt); + } + + public boolean wasVisited(RewriterStatement stmt) { + if (dontVisitAgain == null) + return false; + + return dontVisitAgain.contains(stmt); + } + + public RewriterAssertions getOldAssertionsThat() { + fetchAssertions(); + + return assertionsThat; + } + + public RewriterAssertions getOldAssertionsThis() { + fetchAssertions(); + + return assertionsThis; + } + + public Map getDependencyMap() { + if (dependencyMap == null) + if (allowDuplicatePointers) + dependencyMap = new HashMap<>(); + else + dependencyMap = new DualHashBidiMap(); + return dependencyMap; + } + + public List getLinks() { + if (links == null) + links = new ArrayList<>(); + return links; + } + + public RewriterStatement findInternalReference(RewriterStatement stmt) { + if (internalReferences == null) + return null; + return internalReferences.get(stmt); + } + + public RewriterStatement findReverseInternalReference(RewriterStatement stmt) { + if (internalReferences == null) + return null; + return internalReferences.getKey(stmt); + } + + public Map getInternalReferences() { + if (internalReferences == null) + internalReferences = new DualHashBidiMap<>(); + return internalReferences; + } + + public List getSubMatches() { + if (subMatches == null) + return Collections.emptyList(); + return subMatches; + } + + public boolean hasSubMatches() { + return subMatches != null && !subMatches.isEmpty(); + } + + public void addSubMatch(MatcherContext matcherContext) { + if (subMatches == null) + subMatches = new ArrayList<>(); + subMatches.addAll(matcherContext.getFlattenedSubMatches()); + } + + public List getFlattenedSubMatches() { + if (hasSubMatches()) + return subMatches.stream().flatMap(mCtx -> mCtx.getFlattenedSubMatches().stream()).collect(Collectors.toList()); + return Collections.emptyList(); + } + + public MatchingSubexpression toMatch() { + return new MatchingSubexpression(expressionRoot, matchRoot, pred, getDependencyMap(), getLinks()); + } + + public void reset() { + if (dependencyMap != null) + dependencyMap.clear(); + if (links != null) + links.clear(); + if (internalReferences != null) + internalReferences.clear(); + } + + public void setFirstMismatch(RewriterStatement stmt1, RewriterStatement stmt2) { + firstMismatch = new Tuple2<>(stmt1, stmt2); + } + + public Tuple2 getFirstMismatch() { + return firstMismatch; + } + + public MatcherContext debug(boolean debug) { + this.debug = debug; + return this; + } + + public boolean match() { + return thisExpressionRoot.match(this); + } + + public boolean isDebug() { + return debug; + } + + public static MatcherContext exactMatch(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExprRoot); + } + + public static MatcherContext exactMatchWithDifferentLiteralValues(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExprRoot, false, false, true, false, false, false, false, false, false, Collections.emptyMap()); + } + + public static MatcherContext findMinimalDifference(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExpressionRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExpressionRoot, false, false, true, false, false, false, false, true, false, Collections.emptyMap()); + } + } + + public static final class RewriterPredecessor { + private final Object obj; + private final Object meta; + + // Use iff the element is already the root + public RewriterPredecessor() { + obj = null; + meta = null; + } + + public RewriterPredecessor(RewriterStatement parent, Integer idx) { + obj = parent; + meta = idx; + } + + // Use iff the element is a meta object + public RewriterPredecessor(RewriterStatement parent, String meta) { + obj = parent; + this.meta = meta; + } + + public RewriterPredecessor(RewriterAssertions assertions, RewriterAssertions.RewriterAssertion assertion) { + obj = assertions; + meta = assertion; + } + + public boolean isOperand() { + return obj instanceof RewriterStatement && meta instanceof Integer; + } + + public boolean isRoot() { + return obj == null && meta == null; + } + + public boolean isMetaObject() { + return obj instanceof RewriterStatement && meta instanceof String; + } + + public boolean isAssertionObject() { + return obj instanceof RewriterAssertions && meta instanceof RewriterAssertions.RewriterAssertion; + } + + public RewriterStatement getParent() { + return (RewriterStatement) obj; + } + + public RewriterAssertions getAssertions() { + return (RewriterAssertions) obj; + } + + public RewriterAssertions.RewriterAssertion getAssertion() { + return (RewriterAssertions.RewriterAssertion) meta; + } + + public String getMetaKey() { + return (String) meta; + } + + public int getIndex() { + return (Integer) meta; + } + } + + public static enum ReferenceType { + ROOT, OPERAND, NCOL, NROW, BACKREF, ASSERTION + } + + public static class RewriterStatementReference { + public final ReferenceType referenceType; + public final RewriterStatement stmt; + public final Object parentRef; + public final Object ref; + + // TODO: What about root? + public RewriterStatementReference(ReferenceType type, RewriterStatement stmt, RewriterStatement parentRef) { + this.referenceType = type; + this.stmt = stmt; + this.parentRef = parentRef; + this.ref = null; + } + + public RewriterStatementReference(RewriterStatement stmt, RewriterStatement parentRef, int idx) { + this.referenceType = parentRef == null ? ReferenceType.ROOT : ReferenceType.OPERAND; + this.stmt = stmt; + this.parentRef = parentRef; + this.ref = idx; + } + + public RewriterStatementReference(RewriterStatement stmt, RewriterAssertions assertions, RewriterAssertions.RewriterAssertion assertion) { + this.referenceType = ReferenceType.ASSERTION; + this.stmt = stmt; + this.parentRef = assertions; + this.ref = assertion; + } + + public void replace(RewriterStatement newStmt) { + switch (referenceType) { + case ROOT: + throw new NotImplementedException(); + case OPERAND: + ((RewriterStatement) parentRef).getOperands().set((Integer)ref, newStmt); + break; + case NCOL: + ((RewriterStatement) parentRef).unsafePutMeta("ncol", newStmt); + break; + case NROW: + ((RewriterStatement) parentRef).unsafePutMeta("nrow", newStmt); + break; + case BACKREF: + ((RewriterStatement) parentRef).unsafePutMeta("backRef", newStmt); + break; + case ASSERTION: + ((RewriterAssertions) parentRef).replaceAssertionContent(stmt, newStmt, (RewriterAssertions.RewriterAssertion) ref); + break; + } + } + } + + public abstract String getId(); + public abstract String getResultingDataType(final RuleContext ctx); + public abstract boolean isLiteral(); + public abstract Object getLiteral(); + public abstract RewriterStatement getLiteralStatement(); + public long intLiteral() { + return intLiteral(false); + } + public abstract long intLiteral(boolean cast); + public abstract double floatLiteral(); + public abstract boolean boolLiteral(); + + public void setLiteral(Object literal) { + throw new IllegalArgumentException("This class does not support setting literals"); + } + public abstract RewriterStatement consolidate(final RuleContext ctx); + public abstract boolean isConsolidated(); + @Deprecated + public abstract RewriterStatement clone(); + public abstract RewriterStatement copyNode(); + // Performs a nested copy until a condition is met + public abstract RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx); + // Returns the new maxRefId + public abstract int toParsableString(StringBuilder builder, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx); + public abstract void refreshReturnType(final RuleContext ctx); + protected abstract void compress(RewriterAssertions assertions); + + public static String parsableDefinitions(Map> defs) { + StringBuilder sb = new StringBuilder(); + defs.forEach((k, v) -> { + sb.append(k); + sb.append(':'); + + int i = 0; + for (String varName : v) { + if (i > 0) + sb.append(','); + + sb.append(varName); + i++; + } + + sb.append('\n'); + }); + + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx, Map> defs) { + return toParsableString(ctx, defs, Collections.emptySet()); + } + + public String toParsableString(final RuleContext ctx, Map> defs, Set forceCreateRefs) { + StringBuilder sb = new StringBuilder(); + toParsableString(sb, new HashMap<>(), 0, defs, forceCreateRefs, ctx); + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx, boolean includeDefinitions) { + return toParsableString(ctx, includeDefinitions, Collections.emptySet()); + } + + public String toParsableString(final RuleContext ctx, boolean includeDefinitions, Set forceCreateRefs) { + StringBuilder sb = new StringBuilder(); + HashMap> defs = new HashMap<>(); + toParsableString(sb, new HashMap<>(), 0, defs, forceCreateRefs, ctx); + + if (includeDefinitions) + return parsableDefinitions(defs) + sb; + + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx) { + return toParsableString(ctx, false); + } + + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector) { + return nestedCopyOrInject(copiedObjects, injector, null, -1); + } + + public RewriterStatement nestedCopyOrInject(Map copiedObjects, Function injector) { + return nestedCopyOrInject(copiedObjects, (el, parent, pIdx) -> injector.apply(el), null, -1); + } + + public RewriterStatement nestedCopy(boolean copyAssertions) { + return nestedCopy(copyAssertions, new HashMap<>()); + } + + public RewriterStatement nestedCopy(boolean copyAssertions, Map createdObjects) { + RewriterStatement cpy = nestedCopyOrInject(createdObjects, el -> null); + + if (copyAssertions) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) { + cpy.unsafePutMeta("_assertions", RewriterAssertions.copy(assertions, createdObjects, true)); + } + } else { + cpy.unsafeRemoveMeta("_assertions"); + } + + return cpy; + } + + // Returns the root of the matching sub-statement, null if there is no match + public abstract boolean match(MatcherContext matcherContext); + + public abstract int recomputeHashCodes(boolean recursively, final RuleContext ctx); + public abstract RewriterStatement simplify(final RuleContext ctx); + public abstract RewriterStatement as(String id); + public abstract String toString(final RuleContext ctx); + public abstract boolean isArgumentList(); + public abstract List getArgumentList(); + public abstract boolean isInstruction(); + public abstract boolean isEClass(); + public abstract String trueInstruction(); + public abstract String trueTypedInstruction(final RuleContext ctx); + public abstract String trueTypedInstruction(boolean allowImplicitConversions, final RuleContext ctx); + public abstract int structuralHashCode(); + public abstract RewriterStatement rename(String id); + public void prepareDefinitions(final RuleContext ctx, final List strDefs, final Set varDefs) { + if (getMeta(META_VARNAME) != null) + return; + + if (getOperands() != null) + getOperands().forEach(op -> op.prepareDefinitions(ctx, strDefs, varDefs)); + + if (this instanceof RewriterInstruction) { + RewriterInstruction self = ((RewriterInstruction) this); + // Check if it is necessary to define variables + if (refCtr > 1 || self.trueInstruction().equals("_asVar")) { + Pattern pattern = Pattern.compile("[a-zA-Z0-9_]+"); + String instr = pattern.matcher(self.getInstr()).matches() ? self.getInstr() : "tmp"; + instr = instr.replace("_", ""); + String varName = "var_" + instr + "_"; + + int ctr = 1; + while (varDefs.contains(varName + ctr)) + ctr++; + + strDefs.add(varName + ctr + " = " + toString(ctx)); + varDefs.add(varName + ctr); + unsafePutMeta(META_VARNAME, varName + ctr); + } + } + } + + public void eraseDefinitions() { + unsafeRemoveMeta(META_VARNAME); + + if (getOperands() != null) + getOperands().forEach(RewriterStatement::eraseDefinitions); + } + + public List getOperands() { + return Collections.emptyList(); + } + + public int recomputeHashCodes(final RuleContext ctx) { + return recomputeHashCodes(true, ctx); + } + + public void prepareForHashing() { + resetRefCtrs(); + computeRefCtrs(); + resetIds(); + computeIds(1); + } + + protected void resetRefCtrs() { + refCtr = 0; + if (getOperands() != null) + getOperands().forEach(RewriterStatement::resetRefCtrs); + } + + protected void computeRefCtrs() { + refCtr++; + if (refCtr < 2 && getOperands() != null) + getOperands().forEach(RewriterStatement::computeRefCtrs); + } + + protected void resetIds() { + rid = 0; + if (getOperands() != null) + getOperands().forEach(RewriterStatement::resetIds); + } + + protected int computeIds(int id) { + rid = id++; + + if (getOperands() != null) { + for (RewriterStatement stmt : getOperands()) + id = stmt.computeIds(id); + } + + return id; + } + + /** + * Traverses the DAG in-order. If nodes with multiple parents exist, those are visited multiple times. + * If the function returns false, the sub-DAG of the current node will not be traversed. + * @param function test + */ + @Deprecated + public void forEachPreOrderWithDuplicates(Function function) { + if (function.apply(this) && getOperands() != null) + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPreOrderWithDuplicates(function); + } + + public void forEachPreOrder(Function function, boolean includeMeta) { + forEachPreOrder((el, pred) -> function.apply(el), includeMeta); + } + + public void forEachPreOrder(BiFunction function, boolean includeMeta) { + forEachPreOrder(function, new HashSet<>(), new RewriterPredecessor(), includeMeta); + } + + // We will also include metadata + private void forEachPreOrder(BiFunction function, Set visited, RewriterPredecessor pred, boolean includeMeta) { + if (!visited.add(this)) + return; + + if (function.apply(this, pred)) { + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPreOrder(function, visited, new RewriterPredecessor(this, i), includeMeta); + + if (includeMeta) + forEachMetaObject((stmt, mPred) -> stmt.forEachPreOrder(function, visited, mPred, includeMeta)); + } + } + + public void forEachPostOrder(BiConsumer consumer, boolean includeMeta) { + forEachPostOrder(consumer, new HashSet<>(), new RewriterPredecessor(), includeMeta); + } + + private void forEachPostOrder(BiConsumer consumer, Set visited, RewriterPredecessor pred, boolean includeMeta) { + if (!visited.add(this)) + return; + + if (getOperands() != null) + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPostOrder(consumer, visited, new RewriterPredecessor(this, i), includeMeta); + + if (includeMeta) + forEachMetaObject((stmt, mPred) -> stmt.forEachPostOrder(consumer, visited, mPred, includeMeta)); + + consumer.accept(this, pred); + } + + @Deprecated + public void forEachPostOrderWithDuplicates(TriConsumer consumer) { + forEachPostOrderWithDuplicates(consumer, null, -1); + } + + @Deprecated + private void forEachPostOrderWithDuplicates(TriConsumer consumer, RewriterStatement parent, int pIdx) { + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPostOrderWithDuplicates(consumer, this, i); + + consumer.accept(this, parent, pIdx); + } + + public void putMeta(String key, Object value) { + if (isConsolidated()) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + + if (meta == null) + meta = new HashMap<>(); + + meta.put(key, value); + } + + public void unsafePutMeta(String key, Object value) { + if (isLiteral()) + throw new UnsupportedOperationException("Cannot put meta for literals"); + + if (meta == null) + meta = new HashMap<>(); + + meta.put(key, value); + } + + public void unsafeRemoveMeta(String key) { + if (meta == null) + return; + + meta.remove(key); + + if (meta.isEmpty()) + meta = null; + } + + public Object getMeta(String key) { + if (meta == null) + return null; + + return meta.get(key); + } + + public long getCost() { + if (!isInstruction()) + return 0; + + return cost; + } + + public RewriterAssertions getAssertions(final RuleContext ctx) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + if (assertions == null) { + assertions = new RewriterAssertions(ctx); + if (!isLiteral()) // Otherwise the assertion object will just be temporary + unsafePutMeta("_assertions", assertions); + } + + return assertions; + } + + public RewriterStatement getNCol() { + return (RewriterStatement) getMeta("ncol"); + } + + public RewriterStatement getNRow() { + return (RewriterStatement) getMeta("nrow"); + } + + public RewriterStatement getBackRef() { + return (RewriterStatement) getMeta("_backRef"); + } + + public RewriterStatement getChild(int index) { + return getOperands().get(index); + } + + public RewriterStatement getChild(int... indices) { + RewriterStatement current = this; + + for (int i = 0; i < indices.length; i++) + current = current.getOperands().get(indices[i]); + + return current; + } + + // This can only be called from the root expression to add a new assertion manually + public RewriterStatement givenThatEqualDimensions(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + getAssertions(ctx).addEqualityAssertion(stmt1.getNRow(), stmt2.getNRow(), this); + getAssertions(ctx).addEqualityAssertion(stmt1.getNCol(), stmt2.getNCol(), this); + return this; + } + + // This can only be called from the root expression to add a new assertion manually + public RewriterStatement givenThatEqual(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + return givenThatEqual(stmt1, stmt2, this, ctx); + } + + public RewriterStatement givenThatEqual(RewriterStatement stmt1, RewriterStatement stmt2, RewriterStatement exprRoot, final RuleContext ctx) { + getAssertions(ctx).addEqualityAssertion(stmt1, stmt2, exprRoot); + return this; + } + + public RewriterStatement recomputeAssertions() { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + return assertions.update(this); + + return this; + } + + public static void transferMeta(RewriterRule.ExplicitLink link) { + if (link.oldStmt instanceof RewriterInstruction) { + for (RewriterStatement mNew : link.newStmt) { + if (mNew instanceof RewriterInstruction && + !((RewriterInstruction)mNew).trueInstruction().equals(((RewriterInstruction)link.oldStmt).trueInstruction())) { + ((RewriterInstruction) mNew).unsafeSetInstructionName(((RewriterInstruction)link.oldStmt).trueInstruction()); + } + } + } + + if (link.oldStmt.meta != null) { + link.newStmt.forEach(stmt -> { + HashMap newMap = new HashMap<>(link.oldStmt.meta); + stmt.overwriteImplicitMetaObjects(newMap); + stmt.meta = newMap; + }); + } + else + link.newStmt.forEach(RewriterStatement::cleanupMeta/*stmt.meta = null*/); + } + + public void moveRootTo(RewriterStatement newRoot) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null && !newRoot.isLiteral()) + newRoot.unsafePutMeta("_assertions", assertions); + } + + private void overwriteImplicitMetaObjects(Map map) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backref = getBackRef(); + + if (assertions != null) + map.put("_assertions", assertions); + + if (ncol != null) + map.put("ncol", ncol); + + if (nrow != null) + map.put("nrow", nrow); + + if (backref != null) + map.put("_backRef", backref); + } + + private void cleanupMeta() { + if (meta == null) + return; + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backref = getBackRef(); + + if (assertions == null && ncol == null && nrow == null && backref == null) + return; + + meta = new HashMap<>(); + + if (assertions != null) + meta.put("_assertions", assertions); + + if (ncol != null) + meta.put("ncol", ncol); + + if (nrow != null) + meta.put("nrow", nrow); + + if (backref != null) + meta.put("_backRef", ncol); + } + + @Override + public String toString() { + return toString(RuleContext.currentContext); + } + + public boolean isColVector() { + RewriterStatement nrow = getNRow(); + + if (nrow == null) + return false; + + if (nrow.isLiteral() && nrow.getLiteral().equals(1L)) + return true; + + if (nrow.isEClass() && nrow.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L))) + return true; + + return false; + } + + public boolean isRowVector() { + RewriterStatement ncol = getNCol(); + + if (ncol == null) + return false; + + if (ncol.isLiteral() && ncol.getLiteral().equals(1L)) + return true; + + if (ncol.isEClass() && ncol.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L))) + return true; + + return false; + } + + public List toExecutableString(final RuleContext ctx) { + ArrayList defList = new ArrayList<>(); + prepareDefinitions(ctx, defList, new HashSet<>()); + defList.add(toString(ctx)); + eraseDefinitions(); + + return defList; + } + + public void compress() { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + this.forEachPostOrder((cur, pred) -> { + cur.compress(assertions); + }, true); + } + + public long getCost(final RuleContext ctx) { + if (!this.isInstruction()) + return 0; + + if (cost != -2) + return cost; + + try { + cost = RewriterCostEstimator.estimateCost(this, ctx); + } catch (Exception e) { + cost = -1L; + } + + return cost; + } + + // This may create cycles if visited objects are not tracked + public void forEachMetaObject(BiConsumer consumer) { + RewriterStatement backref = getBackRef(); + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (backref != null) + consumer.accept(backref, new RewriterPredecessor(this, "_backRef")); + if (assertions != null) + assertions.forEachAssertionContents(consumer); + } + + public void updateMetaObjects(Function f) { + RewriterStatement backref = getBackRef(); + + RewriterStatement mNew; + + if (backref != null) { + mNew = f.apply(backref); + + if (backref != mNew) + unsafePutMeta("_backRef", backref); + } + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + assertions.updateAssertionContents(f); + } + + protected void nestedCopyOrInjectMetaStatements(Map copiedObjects, TriFunction injector) { + if (getNCol() != null) { + unsafePutMeta("ncol", getNCol().nestedCopyOrInject(copiedObjects, injector, this, -1)); + } + + if (getNRow() != null) + unsafePutMeta("nrow", getNRow().nestedCopyOrInject(copiedObjects, injector, this, -1)); + + RewriterStatement backRef = (RewriterStatement) getMeta("_backRef"); + + if (backRef != null) + unsafePutMeta("_backRef", backRef.nestedCopyOrInject(copiedObjects, injector, this, -1)); + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) { + assertions = assertions.nestedCopyOrInject(copiedObjects, injector, this); + unsafePutMeta("_assertions", assertions); + } + } + + // This returns a stream of all children including metadata and assertions if available + // This may contain loops in case of back references + public Stream> allChildren() { + Stream> stream = IntStream.range(0, getOperands().size()).mapToObj(i -> new Tuple2<>(getOperands().get(i), new RewriterPredecessor(this, i))); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backRef = getBackRef(); + + if (ncol != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(ncol, new RewriterPredecessor(this, "ncol")))); + if (nrow != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(nrow, new RewriterPredecessor(this, "nrow")))); + if (backRef != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(backRef, new RewriterPredecessor(this, "_backRef")))); + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + stream = Stream.concat(stream, assertions.streamOfContents()); + + return stream; + } + + public boolean isDataOrigin() { + if (!isInstruction()) + return true; + + switch (trueInstruction()) { + case "rowVec": + case "colVec": + case "const": + return true; + } + + return false; + } + + public int countInstructions() { + MutableInt i = new MutableInt(); + forEachPreOrder(cur -> { + if (!cur.isDataOrigin() || cur.isLiteral()) { + i.add(1 + cur.getOperands().size()); + } + return true; + }, false); + return i.getAndIncrement(); + } + + public static RewriterStatement argList(final RuleContext ctx, RewriterStatement... args) { + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(args).consolidate(ctx); + } + + public static RewriterStatement argList(final RuleContext ctx, List args) { + return argList(ctx, args.toArray(RewriterStatement[]::new)); + } + + public static RewriterStatement castFloat(final RuleContext ctx, RewriterStatement stmt) { + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("cast.FLOAT").withOps(stmt).consolidate(ctx); + } + + public static RewriterStatement nnz(RewriterStatement of, final RuleContext ctx) { + return nnz(of, ctx, false); + } + + public static RewriterStatement nnz(RewriterStatement of, final RuleContext ctx, boolean treatAsDense) { + if (treatAsDense) + return StatementUtils.length(ctx, of); + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_nnz").withOps(of).consolidate(ctx); + } + + public static RewriterStatement literal(final RuleContext ctx, Object literal) { + if (literal == null) + throw new IllegalArgumentException(); + + if (literal instanceof Double) { // We need to differentiate between -0.0 and 0.0 because otherwise this may leed to bugs + return new RewriterDataType().as(literal.toString()).ofType("FLOAT").asLiteral(((Double) literal).doubleValue() == -0.0 ? 0.0 : literal).consolidate(ctx); + } else if (literal instanceof Long) { + return new RewriterDataType().as(literal.toString()).ofType("INT").asLiteral(literal).consolidate(ctx); + } else if (literal instanceof Boolean) { + return new RewriterDataType().as(literal.toString()).ofType("BOOL").asLiteral(literal).consolidate(ctx); + } + + throw new IllegalArgumentException(); + } + + public static RewriterStatement multiArgInstr(final RuleContext ctx, String instrName, RewriterStatement... ops) { + RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(ops).consolidate(ctx); + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction(instrName).withOps(argList).consolidate(ctx); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java new file mode 100644 index 00000000000..80daebebc32 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import java.util.HashMap; + +public class RewriterStatementEntry { + private final RuleContext ctx; + final RewriterStatement instr; + + public RewriterStatementEntry(final RuleContext ctx, RewriterStatement instr) { + this.ctx = ctx; + this.instr = instr; + } + + @Override + public int hashCode() { + return instr.structuralHashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof RewriterStatement) { + if (instr == o) + return true; + if (instr.structuralHashCode() != ((RewriterStatement)o).structuralHashCode()) + return false; + return instr.match(new RewriterStatement.MatcherContext(ctx, (RewriterStatement) o, new RewriterStatement.RewriterPredecessor(), (RewriterStatement) o, instr, false, false, false, false, false, false, true, false, false, false, new HashMap<>())); + } + + if (o.hashCode() != hashCode()) + return false; + + if (o instanceof RewriterStatementEntry) { + if (instr == ((RewriterStatementEntry) o).instr) + return true; + return instr.match(new RewriterStatement.MatcherContext(ctx, ((RewriterStatementEntry) o).instr, new RewriterStatement.RewriterPredecessor(), ((RewriterStatementEntry) o).instr, instr, false, false, false, false, false, false, true, false, false, false, new HashMap<>())); + } + return false; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java b/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java new file mode 100644 index 00000000000..978cb62501f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +public class RuleContext { + public static RuleContext currentContext; + + public HashMap, Long>> instrCosts = new HashMap<>(); + + public HashMap instrTypes = new HashMap<>(); + + public HashMap> simplificationRules = new HashMap<>(); + + public HashMap> instrProperties = new HashMap<>(); + + public HashMap> typeHierarchy = new HashMap<>(); + + public HashMap> customStringRepr = new HashMap<>(); + + public Function metaPropagator = null; + + public static RuleContext floatArithmetic = new RuleContext(); + public static RuleContext selectionPushdownContext = new RuleContext(); + + static { + floatArithmetic.instrCosts.put("+(float,float)", d -> 1l); + floatArithmetic.instrCosts.put("*(float,float)", d -> 1l); + + floatArithmetic.instrTypes.put("+(float,float)", "float"); + floatArithmetic.instrTypes.put("*(float,float)", "float"); + + floatArithmetic.simplificationRules.put("+(float,float)", i -> { + RewriterStatement op1 = i.getOperands().get(0); + RewriterStatement op2 = i.getOperands().get(1); + + if (op1.isLiteral() && op2.isLiteral()) { + op1.setLiteral(((Float)op1.getLiteral()) + ((Float)op2.getLiteral())); + return op1; + } + + return null; + }); + floatArithmetic.simplificationRules.put("*(float, float)", i -> { + RewriterStatement op1 = i.getOperands().get(0); + RewriterStatement op2 = i.getOperands().get(1); + + if (op1.isLiteral() && op2.isLiteral()) { + op1.setLiteral(((Float)op1.getLiteral()) * ((Float)op2.getLiteral())); + return op1; + } + + return null; + }); + + selectionPushdownContext.instrCosts.put("RowSelectPushableBinaryInstruction(MATRIX,MATRIX)", d -> 1l); // Just temporary costs + selectionPushdownContext.instrTypes.put("RowSelectPushableBinaryInstruction(MATRIX,MATRIX)", "MATRIX"); + selectionPushdownContext.instrCosts.put("rowSelect(MATRIX,INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("rowSelect(MATRIX,INT,INT)", "MATRIX"); + selectionPushdownContext.instrCosts.put("min(INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("min(INT,INT)", "INT"); + selectionPushdownContext.instrCosts.put("max(INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("max(INT,INT)", "INT"); + + selectionPushdownContext.instrCosts.put("+(MATRIX,MATRIX)", d -> 1l); + selectionPushdownContext.instrTypes.put("+(MATRIX,MATRIX)", "MATRIX"); + } + + public static RuleContext createContext(String contextString) { + RuleContext ctx = new RuleContext(); + HashMap instrTypes = ctx.instrTypes; + HashMap> instrProps = ctx.instrProperties; + String[] lines = contextString.split("\n"); + String fName = null; + String fArgTypes = null; + String fReturnType = null; + for (String line : lines) { + line = line.replaceFirst("^\\s+", ""); + if (line.isEmpty()) + continue; + + if (line.startsWith("impl")) { + if (fArgTypes == null || fReturnType == null) + throw new IllegalArgumentException(); + String newFName = line.substring(4).replace(" ", ""); + if (newFName.isEmpty()) + throw new IllegalArgumentException(); + + instrTypes.put(newFName + fArgTypes, fReturnType); + + final String propertyFunction = fName + fArgTypes; + + if (instrProps.containsKey(newFName + fArgTypes)) { + HashSet props = instrProps.get(newFName + fArgTypes); + props.add(propertyFunction); + props.add(fName); + } else { + HashSet mset = new HashSet<>(); + mset.add(propertyFunction); + mset.add(fName); + instrProps.put(newFName + fArgTypes, mset); + } + + ctx.instrCosts.put(newFName + fArgTypes, d -> 1L); + } else if (line.startsWith("dtype ")) { + String[] dTypeStr = line.substring(6).split("::"); + if (dTypeStr.length > 1) { + Set mSet = ctx.typeHierarchy.compute(dTypeStr[0], (k, v) -> v == null ? new HashSet<>() : v); + for (int i = 1; i < dTypeStr.length; i++) + mSet.add(dTypeStr[i]); + } + + } else { + String[] keyVal = readFunctionDefinition(line); + fName = keyVal[0]; + fArgTypes = keyVal[1]; + fReturnType = keyVal[2]; + instrTypes.put(fName + fArgTypes, fReturnType); + ctx.instrCosts.put(fName + fArgTypes, d -> 1L); + } + } + + // Resolve transitive function properties + boolean changed = true; + while (changed) { + changed = false; + for (Map.Entry> pair : instrProps.entrySet()) { + HashSet toAdd = new HashSet<>(); + for (String propertyFunction : pair.getValue()) { + if (instrProps.containsKey(propertyFunction)) + toAdd.addAll(instrProps.get(propertyFunction)); + } + + changed |= pair.getValue().addAll(toAdd); + } + } + + changed = true; + while (changed) { + changed = false; + for (Map.Entry> pair : ctx.typeHierarchy.entrySet()) { + HashSet toAdd = new HashSet<>(); + for (String superTypes : pair.getValue()) { + if (instrProps.containsKey(superTypes)) + toAdd.addAll(instrProps.get(superTypes)); + } + + changed |= pair.getValue().addAll(toAdd); + } + } + + return ctx; + } + + public static String[] readFunctionDefinition(String line) { + int leftParanthesisIdx = line.indexOf('('); + + if (leftParanthesisIdx == -1) + throw new IllegalArgumentException(); + + String fName = line.substring(0, leftParanthesisIdx).replace(" ", ""); + String rest = line.substring(leftParanthesisIdx+1); + + int parenthesisCloseIdx = rest.indexOf(')'); + + if (parenthesisCloseIdx == -1) + throw new IllegalArgumentException(); + + String argsStr = rest.substring(0, parenthesisCloseIdx); + String[] args = argsStr.split(","); + + args = Arrays.stream(args).map(arg -> arg.replace(" ", "")).toArray(String[]::new); + + if (args.length != 1 && Arrays.stream(args).anyMatch(String::isEmpty)) + throw new IllegalArgumentException(); + + if (!rest.substring(parenthesisCloseIdx+1, parenthesisCloseIdx+3).equals("::")) + throw new IllegalArgumentException(); + + String returnDataType = rest.substring(parenthesisCloseIdx+3); + return new String[] { fName, "(" + String.join(",", args) + ")", returnDataType }; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java new file mode 100644 index 00000000000..94bf6f029fb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; + +// We assume that _argList() will have one unique parent +public class TopologicalSort { + protected static final Log LOG = LogFactory.getLog(TopologicalSort.class.getName()); + + public static boolean DEBUG = false; + + // All of these operators are sortable with argument lists (e.g. +(argList(1, 2, 3)) + private static final Set SORTABLE_ARGLIST_OPS = Set.of("+", "*", "_idxExpr", "_EClass", "rand", "_dummy"); + // All of these operators are sortable but have their operands directly as children (e.g. ==(a,b)) + private static final Set SORTABLE_OPS = Set.of("==", "!="); + + public static void sort(RewriterStatement root, final RuleContext ctx) { + sort(root, (el, parent) -> { + if (!el.isInstruction()) + return false; + + if (el.isArgumentList()) + return parent != null && SORTABLE_ARGLIST_OPS.contains(parent.trueInstruction()); + + return SORTABLE_OPS.contains(el.trueInstruction()); + }, ctx); + } + + public static void sort(RewriterStatement root, BiFunction isArrangable, final RuleContext ctx) { + // First, we setup an artificial root node to be able to sort E-Classes that are only included as meta-info not directly in the operand structure + Set hiddenEClasses = new HashSet<>(); + root.forEachPostOrder((stmt, pred) -> { + if (stmt instanceof RewriterDataType && !stmt.isLiteral() && stmt.getResultingDataType(ctx).equals("MATRIX")) { + if (stmt.getNRow().isInstruction() && stmt.getNRow().trueInstruction().equals("_EClass")) + hiddenEClasses.add(stmt.getNRow()); + + if (stmt.getNCol().isInstruction() && stmt.getNCol().trueInstruction().equals("_EClass")) + hiddenEClasses.add(stmt.getNCol()); + } + }, true); + + RewriterStatement oldRoot = root; + + if (!hiddenEClasses.isEmpty()) { + RewriterStatement argList = new RewriterInstruction().withInstruction("argList").withOps(hiddenEClasses.toArray(RewriterStatement[]::new)); + RewriterStatement dummy = new RewriterInstruction().withInstruction("_dummy").withOps(argList); + root = new RewriterInstruction().withInstruction("_root").withOps(root, dummy); + } + + List uncertainParents = setupOrderFacts(root, isArrangable, ctx); + + buildAddresses(root, ctx); + resolveAmbiguities(root, ctx, uncertainParents); + resetAddresses(uncertainParents); + + int factCtr = 0; + + // Now, we start introducing facts for the lowest level unordered sets + Set lowestUncertainties = findLowestUncertainties(root); + int ctr = 0; + + while (!lowestUncertainties.isEmpty()) { + if (DEBUG) { + LOG.trace("Uncertainties after iteration " + ctr + ": " + lowestUncertainties.size()); + LOG.trace("Lowest uncertainties: " + lowestUncertainties); + } + + factCtr = introduceFacts(lowestUncertainties, factCtr); + buildAddresses(root, ctx); + + if (DEBUG) { + LOG.trace("Built addresses:"); + for (UnorderedSet u : lowestUncertainties) { + for (RewriterStatement s : u.contents) { + LOG.trace("- " + s + " :: " + getAddress(s)); + } + } + } + + resolveAmbiguities(root, ctx, uncertainParents); + resetAddresses(uncertainParents); + + lowestUncertainties = findLowestUncertainties(root); + ctr++; + + if (ctr > 100) + throw new RuntimeException("Could not finish sorting process for expression:\n" + root.toParsableString(ctx)); // Should never get here but just to make sure + } + + // At the end + if (DEBUG) + LOG.trace("Before construction: " + oldRoot.toParsableString(ctx)); + constructNewDAG(oldRoot, ctx); + if (DEBUG) + LOG.trace("After construction: " + oldRoot.toParsableString(ctx)); + } + + // Returns all uncertain parents ordered in post order (elements without uncertain sub-DAGs come first in the list) + private static List setupOrderFacts(RewriterStatement root, BiFunction isArrangable, final RuleContext ctx) { + List uncertainParents = new ArrayList<>(); + + // Create a random global order which will be used for indistinguishable sub-DAGs + MutableInt nameCtr = new MutableInt(0); + root.forEachPostOrder((el, pred) -> { + if (el.isLiteral()) + return; + + el.unsafePutMeta("_tempName", nameCtr.intValue()); + nameCtr.increment(); + boolean arrangable = isArrangable.apply(el, pred.getParent()); + + el.unsafePutMeta("_arrangable", arrangable); + }, false); + + // Try to establish a first order + root.forEachPostOrder((el, pred) -> { + if (el.isLiteral()) + return; + + boolean arrangable = (boolean) el.getMeta("_arrangable"); + + List knownOrder = new ArrayList<>(); + el.unsafePutMeta("_knownOrder", knownOrder); + + if (arrangable) { + el.getOperands().sort((cmp1, cmp2) -> compare(cmp1, cmp2, ctx)); + + boolean containsUnorderedSet = false; + + List currSet = new ArrayList<>(); + currSet.add(el.getOperands().get(0)); + + for (int i = 1; i < el.getOperands().size(); i++) { + if (compare(el.getOperands().get(i-1), el.getOperands().get(i), ctx) != 0) { + if (currSet.size() == 1) { + knownOrder.add(currSet.get(0)); + currSet.clear(); + } else { + final RewriterStatement first = currSet.get(0); + if (currSet.stream().allMatch(mEl -> first == mEl)) { + // Then this is not an unordered set as it only contains one instance and the order doesn't matter + knownOrder.addAll(currSet); + currSet.clear(); + } else { + containsUnorderedSet = true; + currSet.forEach(cur -> { + if (!cur.isLiteral()) + cur.unsafePutMeta("_addresses", new ArrayList()); + }); + knownOrder.add(new UnorderedSet(currSet)); + currSet = new ArrayList<>(); + } + } + } + + currSet.add(el.getOperands().get(i)); + } + + if (currSet.size() == 1) + knownOrder.add(currSet.get(0)); + else { + final RewriterStatement first = currSet.get(0); + if (currSet.stream().allMatch(first::equals)) { + knownOrder.addAll(currSet); + } else { + containsUnorderedSet = true; + currSet.forEach(cur -> { + if (!cur.isLiteral()) + cur.unsafePutMeta("_addresses", new ArrayList()); + }); + knownOrder.add(new UnorderedSet(currSet)); + } + } + + if (containsUnorderedSet) + uncertainParents.add(el); + } else { + knownOrder.addAll(el.getOperands()); + } + + if (DEBUG) + LOG.trace("Initial known order of " + el.toParsableString(ctx) + ": " + knownOrder); + }, false); + + return uncertainParents; + } + + private static int introduceFacts(Collection sets, int factCtr) { + for (RewriterStatement stmt : allChildren(sets)) { + if (stmt.isLiteral()) + continue; + + if (stmt.getMeta("_addresses") == null) + stmt.unsafePutMeta("_addresses", new ArrayList<>()); + + if (stmt.getMeta("_fact") == null) + stmt.unsafePutMeta("_fact", factCtr++); + } + + return factCtr; + } + + // Returns a list of all unordered set that do not contain other unordered set + private static Set findLowestUncertainties(RewriterStatement root) { + Set set = new HashSet<>(); + recursivelyFindLowestUncertainties(root, set); + + List tmpList = new ArrayList<>(set); + Set minSet = new HashSet<>(); + // We have the issue that uncertainties might still depend on each other (e.g. {a,b}, {inv(a),inv(b)}), even if they are the lowest entries + // Theoretically, this comparison might still lead to amgibuities, but never occurred in our examples + int minCumSize = Integer.MAX_VALUE; + for (int i = 0; i < tmpList.size(); i++) { + int cumSize = tmpList.get(i).contents.stream().map(RewriterStatement::countInstructions).reduce(0, Integer::sum); + + if (cumSize < minCumSize) { + minSet.clear(); + minCumSize = cumSize; + } + + if (cumSize <= minCumSize) + minSet.add(tmpList.get(i)); + } + + return minSet; + } + + // All children in post order and unique + private static List allChildren(Collection unorderedSets) { + Set is = new HashSet<>(); + List children = new ArrayList<>(); + for (UnorderedSet set : unorderedSets) + for (RewriterStatement s : set.contents) + traverse(s, is, children); + + return children; + } + + private static void traverse(RewriterStatement stmt, Set visited, List l) { + if (visited.contains(stmt)) + return; + + visited.add(stmt); + stmt.getOperands().forEach(el -> traverse(el, visited, l)); + + l.add(stmt); + } + + private static boolean recursivelyFindLowestUncertainties(RewriterStatement current, Set lowestUncertainties) { + if (current.isLiteral()) + return false; + + List knownOrder = (List) current.getMeta("_knownOrder"); + boolean containsUncertainty = false; + + for (Object o : knownOrder) { + if (o instanceof RewriterStatement) { + containsUncertainty |= recursivelyFindLowestUncertainties((RewriterStatement) o, lowestUncertainties); + } else { + UnorderedSet set = (UnorderedSet) o; + containsUncertainty = true; + boolean foundEmbeddedUncertainty = set.contents.stream().anyMatch(stmt -> recursivelyFindLowestUncertainties(stmt, lowestUncertainties)); + + if (foundEmbeddedUncertainty) + lowestUncertainties.remove(set); + else + lowestUncertainties.add(set); + } + } + + return containsUncertainty; + } + + public static void constructNewDAG(RewriterStatement root, final RuleContext ctx) { + root.forEachPostOrder((cur, pred) -> { + List knownOrder = (List) cur.getMeta("_knownOrder"); + if (DEBUG) + LOG.trace("KnownOrder of " + cur.toParsableString(ctx) + ": " + knownOrder); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, (RewriterStatement) knownOrder.get(i)); + + cur.unsafeRemoveMeta("_knownOrder"); + cur.unsafeRemoveMeta("_addresses"); + cur.unsafeRemoveMeta("_address"); + cur.unsafeRemoveMeta("_arrangable"); + cur.unsafeRemoveMeta("_tempName"); + }, false); + + root.prepareForHashing(); + root.recomputeHashCodes(ctx); + } + + // Here, we try to infer new information given the address information + // This step also resets all addresses as they will change after one sorting step + private static boolean resolveAmbiguities(RewriterStatement root, final RuleContext ctx, List uncertainParents) { + boolean couldResolve = false; + boolean couldResolveAnyUncertainty = true; + + while (couldResolveAnyUncertainty) { + couldResolveAnyUncertainty = false; + + for (int i = 0; i < uncertainParents.size(); i++) { + List knownOrder = (List) uncertainParents.get(i).getMeta("_knownOrder"); + boolean uncertaintyRemaining = false; + + for (int j = 0; j < knownOrder.size(); j++) { + if (knownOrder.get(j) instanceof UnorderedSet) { + UnorderedSet set = (UnorderedSet) knownOrder.get(j); + + if (tryResolveUncertainties(set, ctx)) { + couldResolveAnyUncertainty = true; + couldResolve = true; + knownOrder.set(j, set.contents.get(0)); + knownOrder.addAll(j+1, set.contents.subList(1, set.contents.size())); + set.contents.forEach(el -> { + el.unsafeRemoveMeta("_addresses"); + el.unsafeRemoveMeta("_address"); + }); + } else { + uncertaintyRemaining = true; + } + } + } + + if (!uncertaintyRemaining) { + uncertainParents.remove(i); + i--; + } + } + } + + return couldResolve; + } + + private static void resetAddresses(List uncertainParents) { + for (RewriterStatement uParent : uncertainParents) { + List knownOrder = (List) uParent.getMeta("_knownOrder"); + + for (Object o : knownOrder) { + if (o instanceof UnorderedSet) { + ((UnorderedSet) o).contents.forEach(el -> { + List addresses = (List) el.getMeta("_addresses"); + + if (addresses == null) { + addresses = new ArrayList<>(); + el.unsafePutMeta("_addresses", addresses); + el.unsafeRemoveMeta("_address"); + } + + addresses.clear(); + }); + } + } + } + } + + private static boolean tryResolveUncertainties(UnorderedSet set, final RuleContext ctx) { + set.contents.sort((el1, el2) -> compare(el1, el2, ctx)); // We assume that every statement has an address, as it is uncertain + + RewriterStatement compareTo = set.contents.get(0); + // Check if ambiguity could be resolved + for (int i = 1; i < set.contents.size(); i++) { + if (compareTo.equals(set.contents.get(i))) + continue; // Ignore same instances + + if (compare(set.contents.get(i), compareTo, ctx) == 0) + return false; // Then there are still some ambiguities + + compareTo = set.contents.get(i); + } + + return true; + } + + private static List buildAddresses(RewriterStatement root, final RuleContext ctx) { + // First, catch all addresses + List elementsWithAddress = new ArrayList<>(); + recursivelyBuildAddresses(root, null, ctx, elementsWithAddress); + + // Now, we sort all addresses + for (RewriterStatement el : elementsWithAddress) { + List addresses = (List) el.getMeta("_addresses"); + Collections.sort(addresses); + String address = String.join(";", addresses); + el.unsafePutMeta("_address", address); + + if (DEBUG) + LOG.trace("Address of " + el + " :: " + address); + } + + return elementsWithAddress; + } + + private static void recursivelyBuildAddresses(RewriterStatement current, String currentAddress, final RuleContext ctx, List elementsWithAddress) { + List knownOrder = (List)current.getMeta("_knownOrder"); + List addresses = (List)current.getMeta("_addresses"); + + if (knownOrder == null) + knownOrder = Collections.emptyList(); + + + + if (DEBUG) { + LOG.trace("CUR: " + current); + LOG.trace("KnownOrder: " + knownOrder); + } + + if (addresses != null) { + if (addresses.isEmpty()) + elementsWithAddress.add(current); + + addresses.add(currentAddress); + } + + for (int i = 0; i < knownOrder.size(); i++) { + Object next = knownOrder.get(i); + String addr = currentAddress == null ? Integer.toString(i) : currentAddress + "." + i; + + if (next instanceof RewriterStatement) { + recursivelyBuildAddresses((RewriterStatement) next, addr, ctx, elementsWithAddress); + } else { + UnorderedSet set = (UnorderedSet) next; + set.contents.forEach(el -> recursivelyBuildAddresses(el, addr, ctx, elementsWithAddress)); + } + } + } + + private static String getAddress(RewriterStatement stmt) { + String addr = (String) stmt.getMeta("_address"); + + if (addr == null) + return null; + + return addr + (stmt.getMeta("_fact") == null ? "_" : "_" + stmt.getMeta("_fact")); + } + + // Expects that the children have already been sorted to the best of the current knowledge + public static int compare(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + int comp = toOrderString(ctx, stmt1, false).compareTo(toOrderString(ctx, stmt2, false)); + + if (comp != 0 || stmt1.equals(stmt2)) + return comp; + + List knownOrder1 = (List)stmt1.getMeta("_knownOrder"); + List knownOrder2 = (List)stmt2.getMeta("_knownOrder"); + + // Then the two statements are distinguishable by their number of unknowns + if (knownOrder1.size() != knownOrder2.size()) + return Integer.compare(knownOrder1.size(), knownOrder2.size()); + + for (int i = 0; i < knownOrder1.size() && comp == 0; i++) + comp = compare(knownOrder1.get(i), knownOrder2.get(i), ctx); + + if (comp == 0) { + String addr1 = getAddress(stmt1); + String addr2 = getAddress(stmt2); + + if (addr1 != null && addr2 != null) + return addr1.compareTo(addr2); + } + + return comp; + } + + public static int compare(Object obj1, Object obj2, final RuleContext ctx) { + boolean isStmt1 = obj1 instanceof RewriterStatement; + boolean isStmt2 = obj2 instanceof RewriterStatement; + + if (isStmt1 && !isStmt2) + return 1; + if (!isStmt1 && isStmt2) + return -1; + + if (isStmt1 && isStmt2) + return compare((RewriterStatement) obj1, (RewriterStatement) obj2, ctx); + + UnorderedSet set1 = (UnorderedSet) obj1; + UnorderedSet set2 = (UnorderedSet) obj2; + + if (set1.contents.size() < 2 || set2.contents.size() < 2) + throw new IllegalArgumentException(); // This should never happen as this would not be an unknown ordering + + if (set1.contents.size() != set2.contents.size()) + return Integer.compare(set1.contents.size(), set2.contents.size()); + + // Now, we can just choose any representant of the set + return compare(set1.contents.get(0), set2.contents.get(0), ctx); + } + + public static String toOrderString(final RuleContext ctx, RewriterStatement stmt, boolean useGlobalOrder) { + String globalOrderAddition = useGlobalOrder ? ((Integer)stmt.getMeta("_tempName")).toString() : ""; + + if (stmt.isInstruction()) { + return stmt.getResultingDataType(ctx) + ":" + stmt.trueTypedInstruction(ctx) + "[" + stmt.refCtr + "](" + stmt.getOperands().size() + ")" + globalOrderAddition + ";"; + } else { + return stmt.getResultingDataType(ctx) + ":" + (stmt.isLiteral() ? "L:" + stmt.getLiteral() : "V") + "[" + stmt.refCtr + "](0)" + globalOrderAddition + ";"; + } + } + + + + static class UnorderedSet { + List contents; + + public UnorderedSet(List contents) { + this.contents = contents; + } + + public String toString() { + return contents.toString(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java new file mode 100644 index 00000000000..d6b15d25b5b --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.assertions; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +public class RewriterAssertionUtils { + public static RewriterAssertions buildImplicitAssertions(RewriterStatement root, final RuleContext ctx) { + RewriterAssertions assertions = new RewriterAssertions(ctx); + buildImplicitAssertions(root, assertions, ctx); + return assertions; + } + + public static void buildImplicitAssertions(RewriterStatement root, RewriterAssertions assertions, final RuleContext ctx) { + root.forEachPreOrder(cur -> { + buildImplicitAssertion(cur, assertions, root, ctx); + return true; + }, false); + } + + public static boolean buildImplicitAssertion(RewriterStatement stmt, RewriterAssertions assertions, RewriterStatement exprRoot, final RuleContext ctx) { + if (!stmt.isInstruction()) + return false; + + switch (stmt.trueInstruction()) { + case "%*%": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "diag": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(0).getNRow(), exprRoot); + return true; + case "RBind": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + return true; + case "CBind": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "1-*": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "+*": + case "-*": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(2).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(2).getNRow(), exprRoot); + return true; + } + + switch (stmt.trueTypedInstruction(ctx)) { + case "trace(MATRIX)": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(0).getNCol(), exprRoot); + return true; + case "cast.FLOAT(MATRIX)": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(0).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), RewriterStatement.literal(ctx, 1L), exprRoot); + return true; + } + + if (((RewriterInstruction) stmt).hasProperty("ElementWiseInstruction", ctx)) { + if (stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX") + && stmt.getChild(1).getResultingDataType(ctx).equals("MATRIX")) { + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + } + } + + return false; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java new file mode 100644 index 00000000000..7da9da401dd --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.assertions; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RewriterAssertions { + private final RuleContext ctx; + private Map assertionMatcher = new HashMap<>(); + // Tracks which statements are part of which assertions + private Map> partOfAssertion = new HashMap<>(); + private Set allAssertions = new HashSet<>(); + + public RewriterAssertions(final RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterAssertions nestedCopyOrInject(Map createdObjects, TriFunction injector, RewriterStatement parent) { + RewriterAssertions out = new RewriterAssertions(ctx); + out.allAssertions = allAssertions.stream().map(assertion -> { + Set newSet = new HashSet<>(assertion.set.size()); + RewriterAssertion mapped = RewriterAssertion.from(newSet); + + if (assertion.stmt != null) { + mapped.stmt = assertion.stmt.nestedCopyOrInject(createdObjects, injector, parent, -1); + out.assertionMatcher.put(mapped.stmt, mapped); + } + + for (RewriterStatement entry : assertion.set) { + RewriterStatement newStmt = entry.nestedCopyOrInject(createdObjects, injector, parent, -1); + newSet.add(newStmt); + out.assertionMatcher.put(newStmt, mapped); + } + + if (assertion.backRef != null) { + mapped.backRef = assertion.backRef.nestedCopyOrInject(createdObjects, injector, parent, -1); + out.assertionMatcher.put(mapped.backRef, mapped); + } + + return mapped; + }).collect(Collectors.toSet()); + + for (RewriterAssertion assertion : out.allAssertions) { + forEachUniqueElementInAssertion(assertion, el -> { + Set partOfAssertions = out.partOfAssertion.get(el); + + if (partOfAssertions == null) { + partOfAssertions = new HashSet<>(); + out.partOfAssertion.put(el, partOfAssertions); + } + + partOfAssertions.add(assertion); + }); + } + + return out; + } + + public static RewriterAssertions copy(RewriterAssertions old, Map createdObjects, boolean removeOthers) { + RewriterAssertions newAssertions = new RewriterAssertions(old.ctx); + + Map mappedAssertions = new HashMap<>(); + + newAssertions.allAssertions = old.allAssertions.stream().map(assertion -> { + Set newSet = new HashSet<>(); + List backRefsToCheck = new ArrayList<>(); + + for (RewriterStatement oldEl : assertion.set) { + RewriterStatement cpy = createdObjects.get(oldEl); + + if (cpy == null) + cpy = oldEl.nestedCopyOrInject(createdObjects, stmt -> null); + + if (cpy.isInstruction() && cpy.trueInstruction().startsWith("_backRef.")) + backRefsToCheck.add(cpy); + + newSet.add(cpy); + } + + List backRefsToRemove = Collections.emptyList(); + + if (!backRefsToCheck.isEmpty()) { + backRefsToRemove = new ArrayList<>(); + + for (RewriterStatement backRef : backRefsToCheck) { + System.out.println("Candidate: " + backRef); + if (newSet.contains(backRef.getMeta("_backRef"))) { + newSet.remove(backRef); + backRefsToRemove.add(backRef); + } + } + } + + if (newSet.size() < 2) { + System.out.println("Removing E-Class: " + assertion); + return null; + } + + RewriterAssertion mapped = RewriterAssertion.from(newSet); + if (assertion.stmt != null) { + mapped.stmt = createdObjects.get(assertion.stmt); + + if (!backRefsToRemove.isEmpty()) { + mapped.stmt.getChild(0).getOperands().removeAll(backRefsToRemove); + } + } + if (assertion.backRef != null) + mapped.backRef = createdObjects.get(assertion.backRef); + mappedAssertions.put(assertion, mapped); + return mapped; + }).filter(Objects::nonNull).collect(Collectors.toSet()); + + for (Map.Entry> e : old.partOfAssertion.entrySet()) { + RewriterStatement k = createdObjects.get(e.getKey()); + + if (k == null) + continue; + + Set v = e.getValue(); + Set newV = v.stream().map(mappedAssertions::get).filter(Objects::nonNull).collect(Collectors.toSet()); + + newAssertions.partOfAssertion.put(k, newV); + } + + if (removeOthers) { + old.assertionMatcher.forEach((k, v) -> { + RewriterStatement newK = createdObjects.get(k); + + if (newK == null) + return; + + RewriterAssertion newV = mappedAssertions.get(v); + + if (newV == null) + return; + + newAssertions.assertionMatcher.put(newK, newV); + }); + } else { + old.assertionMatcher.forEach((k, v) -> { + RewriterStatement newK = createdObjects.getOrDefault(k, k); + RewriterAssertion newV = mappedAssertions.get(v); + + if (newV == null) + return; + + newAssertions.assertionMatcher.put(newK, newV); + }); + } + + return newAssertions; + } + + public void forEachAssertionContents(BiConsumer consumer) { + allAssertions.forEach(assertion -> assertion.set.forEach(set -> consumer.accept(set, new RewriterStatement.RewriterPredecessor(this, assertion)))); + } + + public void updateAssertionContents(Function f) { + for (RewriterAssertion assertion : allAssertions) { + Set toRemove = new HashSet<>(); + Map toReplace = new HashMap<>(); + + for (RewriterStatement stmt : assertion.set) { + RewriterStatement mNew = f.apply(stmt); + if (mNew != stmt) { + toRemove.add(stmt); + toReplace.put(stmt, mNew); + } + } + + if (toReplace.isEmpty()) + continue; + + toRemove.forEach(assertion.set::remove); + assertion.set.addAll(toReplace.values()); + + if (assertion.stmt != null) { + List argList = assertion.stmt.getChild(0).getOperands(); + for (int i = 0; i < argList.size(); i++) { + RewriterStatement replaced = toReplace.get(argList.get(i)); + + if (replaced != null) + argList.set(i, replaced); + } + } + + // Now, we have to recompute partOfAssertion for removed and newly added elements + for (RewriterStatement removed : toRemove) { + removed.forEachPreOrder((cur, pred) -> { + Set set = partOfAssertion.get(cur); + + if (set != null) + set.remove(assertion); + + return true; + }, false); + } + + forEachUniqueElementInAssertion(assertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(assertion); + return v; + }); + }); + } + } + + public Stream> streamOfContents() { + return allAssertions.stream().flatMap(assertion -> { + if (assertion.stmt != null) { + if (assertion.backRef != null) + return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion)), new Tuple2<>(assertion.backRef, new RewriterStatement.RewriterPredecessor(this, assertion))); + return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion))); + } else { + return assertion.set.stream().map(stmt -> new Tuple2<>(stmt, new RewriterStatement.RewriterPredecessor(this, assertion))); + } + }); + } + + public void replaceAssertionContent(RewriterStatement oldStmt, RewriterStatement newStmt, RewriterAssertion assertion) { + if (oldStmt == assertion.stmt) { + // Then we will remove this assertion + allAssertions.remove(assertion); + assertion.set.forEach(s -> this.assertionMatcher.remove(s)); + } + + assertion.set.remove(oldStmt); + assertion.set.add(newStmt); + + if (assertion.stmt != null) { + assertion.stmt.getChild(); + } + + throw new NotImplementedException(); + } + + public void resolveExistingAssertions(RewriterStatement root) { + List backRefs = new ArrayList<>(); + root.forEachPreOrder(stmt -> { + if (stmt.isEClass()) { + if (!assertionMatcher.containsKey(stmt)) { + RewriterAssertion assertion = new RewriterAssertion(); + assertion.stmt = stmt; + assertion.set = new HashSet<>(stmt.getChild(0).getOperands()); + allAssertions.add(assertion); + + for (RewriterStatement eStmt : assertion.set) + assertionMatcher.put(eStmt, assertion); + + forEachUniqueElementInAssertion(assertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(assertion); + return v; + }); + }); + } + } else if (stmt.isInstruction() && stmt.trueInstruction().equals("_backRef")) { + backRefs.add(stmt); + } + + return true; + }, false); + + for (RewriterStatement backRef : backRefs) { + RewriterAssertion assertion = getAssertionObj(backRef); + if (assertion != null) { + assertion.backRef = backRef; + } else { + // TODO + } + } + } + + public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement stmt2, RewriterStatement exprRoot) { + if (stmt1 == null || stmt2 == null) + throw new IllegalArgumentException("Cannot add an equality assertion to a null reference!"); + + if (stmt1 == stmt2 || (stmt1.isLiteral() && stmt2.isLiteral() && stmt1.getLiteral().equals(stmt2.getLiteral()))) + return false; + + if (stmt1.isLiteral() && stmt2.isLiteral() && !stmt1.getLiteral().equals(stmt2.getLiteral())) + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + + if (stmt1.hashCode() == 0) + throw new IllegalArgumentException(); + + RewriterStatement e1 = stmt1; + RewriterStatement e2 = stmt2; + RewriterAssertion stmt1Assertions = assertionMatcher.get(e1); + RewriterAssertion stmt2Assertions = assertionMatcher.get(e2); + + if (stmt1.isLiteral() || stmt2.isLiteral()) { + RewriterStatement literal = stmt1.isLiteral() ? stmt1 : stmt2; + + if (stmt1Assertions != null) { + Optional existingLiteral = stmt1Assertions.getLiteral(); + + if (existingLiteral.isPresent()) { + if (literal.getLiteral().equals(existingLiteral.get().getLiteral())) + return false; + else + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + } + } + + if (stmt2Assertions != null) { + Optional existingLiteral = stmt2Assertions.getLiteral(); + + if (existingLiteral.isPresent()) { + if (literal.getLiteral().equals(existingLiteral.get().getLiteral())) + return false; + else + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + } + } + + if (stmt1Assertions != null && stmt2Assertions != null) { + // Here, we need to check if both assertions already contain a literal + // If the literals are identical, we need to deduplicate, otherwise throw an error + Optional existingLiteral1 = stmt1Assertions.getLiteral(); + Optional existingLiteral2 = stmt2Assertions.getLiteral(); + + if (existingLiteral1.isPresent() && existingLiteral2.isPresent()) { + if (!existingLiteral1.get().getLiteral().equals(existingLiteral2.get().getLiteral())) + throw new IllegalArgumentException("Cannot assert equality of two different literal!"); + } + } + } + + if (stmt1Assertions == stmt2Assertions) { + if (stmt1Assertions == null) { + // Then we need to introduce a new equality set + Set newSet = new HashSet<>(); + newSet.add(e1); + newSet.add(e2); + + RewriterAssertion newAssertion = RewriterAssertion.from(newSet); + + assertionMatcher.put(e1, newAssertion); + assertionMatcher.put(e2, newAssertion); + + allAssertions.add(newAssertion); + + resolveCyclicAssertions(newAssertion); + + forEachUniqueElementInAssertion(newAssertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(newAssertion); + return v; + }); + }); + + return true; + } + + return false; // The assertion already exists + } + + if (stmt1Assertions == null || stmt2Assertions == null) { + boolean assert1 = stmt1Assertions == null; + RewriterStatement toAssert = assert1 ? stmt1 : stmt2; + RewriterAssertion existingAssertion = assert1 ? stmt2Assertions : stmt1Assertions; + existingAssertion.set.add(toAssert); + assertionMatcher.put(assert1 ? e1 : e2, existingAssertion); + if (existingAssertion.stmt != null) + updateInstance(existingAssertion.stmt.getChild(0), existingAssertion.set); + + resolveCyclicAssertions(existingAssertion); + + toAssert.forEachPreOrder(cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(existingAssertion); + return v; + }); + return true; + }, false); + + return true; + } + + // Otherwise we need to merge the assertions + + // For that, we choose the smaller set as we will need fewer operations + if (stmt1Assertions.set.size() > stmt2Assertions.set.size()) { + RewriterAssertion tmp = stmt1Assertions; + stmt1Assertions = stmt2Assertions; + stmt2Assertions = tmp; + } + + stmt2Assertions.set.addAll(stmt1Assertions.set); + allAssertions.remove(stmt1Assertions); + if (stmt2Assertions.stmt != null) + updateInstance(stmt2Assertions.stmt.getChild(0), stmt2Assertions.set); + + for (RewriterStatement stmt : stmt1Assertions.set) + assertionMatcher.put(stmt, stmt2Assertions); + + if (stmt1Assertions.stmt != null) + assertionMatcher.put(stmt1Assertions.stmt, stmt2Assertions); // Only temporary + + resolveCyclicAssertions(stmt2Assertions); + stmt2Assertions.deduplicate(); + + final RewriterAssertion assertionToRemove = stmt1Assertions; + final RewriterAssertion assertionToExtend = stmt2Assertions; + forEachUniqueElementInAssertion(stmt1Assertions, cur -> { + Set v = partOfAssertion.get(cur); + + if (v == null) + throw new IllegalArgumentException(cur.toString()); + + v.remove(assertionToRemove); + v.add(assertionToExtend); + }); + + if (assertionToRemove.stmt != null) { + exprRoot.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + if (child == assertionToRemove.stmt) + cur.getOperands().set(i, assertionToExtend.getEClassStmt(ctx, this)); + } + return true; + }, false); + } + + return true; + } + + public static RewriterStatement updateMergedEClasses(RewriterStatement exprRoot, Map legacyEClasses) { + exprRoot.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + if (child.isEClass()) { + RewriterStatement mapped = legacyEClasses.get(child); + if (mapped != null) + cur.getOperands().set(i, mapped); + } + } + return true; + }, false); + + if (exprRoot.isEClass()) { + RewriterStatement mapped = legacyEClasses.get(exprRoot); + if (mapped != null) + return mapped; + } + + return exprRoot; + } + + private void forEachUniqueElementInAssertion(RewriterAssertion assertion, Consumer consumer) { + Set visited = new HashSet<>(); + for (RewriterStatement eq : assertion.set) { + eq.forEachPreOrderWithDuplicates(cur -> { + if (!visited.add(cur)) + return false; + + consumer.accept(cur); + return true; + }); + } + } + + // Replace cycles with _backRef() + private void resolveCyclicAssertions(RewriterAssertion assertion) { + if (assertion.stmt == null) + return; + + RewriterStatement backref = assertion.getBackRef(ctx, this); + + for (RewriterStatement eq : assertion.set) { + eq.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) + if (!cur.getChild(i).isLiteral() && getAssertionObj(cur.getChild(i)) == assertion) + cur.getOperands().set(i, backref); + + return true; + }, false); + } + } + + public RewriterAssertion getAssertionObj(RewriterStatement stmt) { + return assertionMatcher.get(stmt); + } + + public Set getAssertions(RewriterStatement stmt) { + RewriterAssertion set = assertionMatcher.get(stmt); + return set == null ? Collections.emptySet() : set.set; + } + + public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterStatement parent) { + RewriterAssertion set = assertionMatcher.get(stmt); + + if (set == null || set.getEClassStmt(ctx, this).getChild(0) == parent) { + return stmt; + } + + if (parent != null && parent != set.getEClassStmt(ctx, this).getChild(0) && partOfAssertion.getOrDefault(parent, Collections.emptySet()).contains(set)) + return set.getBackRef(ctx, this); + + return set.getEClassStmt(ctx, this); + } + + public RewriterStatement update(RewriterStatement root) { + RewriterStatement eClass = getAssertionStatement(root, null); + + if (eClass == null) + eClass = root; + else if (root.getMeta("_assertions") != null) + eClass.unsafePutMeta("_assertions", root.getMeta("_assertions")); + + updateRecursively(eClass); + + return eClass; + } + + // This removes E-Classes that are not actually E-Classes like _EClass(argList(nrow(A), nrow(A))), or _EClass(argList(nrow(A), _backRef.INT())) + public RewriterStatement cleanupEClasses(RewriterStatement expressionRoot) { + Set toRemoveList = new HashSet<>(); + Map toRemove = new HashMap<>(); + + for (RewriterAssertion assertion : allAssertions) { + int previousSize = assertion.set.size(); + if (assertion.stmt != null) { + // Eliminate top-level back-refs + assertion.set.removeIf(el -> el.isInstruction() && el.trueInstruction().startsWith("_backRef") && el.getMeta("_backRef").equals(assertion.stmt)); + } + + if (assertion.set.size() < 2) { + toRemoveList.add(assertion); + + if (assertion.stmt != null) + toRemove.put(assertion.stmt, assertion.set.stream().findFirst().get()); + } + + if (previousSize != assertion.set.size() && assertion.stmt != null) { + // Then we need to update the EClass + assertion.stmt.getChild(0).getOperands().removeIf(el -> !assertion.set.contains(el)); + + if (assertion.stmt.getChild(0).getOperands().size() != assertion.set.size()) { + // Then there are still duplicates which we need to rule out + Set visited = new HashSet<>(); + List eItems = assertion.stmt.getChild(0).getOperands(); + for (int i = 0; i < eItems.size(); i++) { + if (!visited.add(eItems.get(i))) + eItems.remove(i--); + } + } + } + } + + if (!toRemoveList.isEmpty()) { + allAssertions.removeAll(toRemoveList); + + if (!toRemove.isEmpty()) { + if (expressionRoot.isEClass()) { + RewriterStatement mNew = toRemove.get(expressionRoot); + + if (mNew != null) + expressionRoot = mNew; + } + + expressionRoot.forEachPostOrder((cur, pred) -> { + cur.allChildren().forEach(t -> { + if (t._1.isEClass()) { + RewriterStatement mNew = toRemove.get(t._1); + if (mNew != null) { + if (t._2.isOperand()) { + cur.getOperands().set(t._2.getIndex(), mNew); + } else if (t._2.isMetaObject()) { + cur.unsafePutMeta(t._2.getMetaKey(), mNew); + } + } + } + }); + }, true); + } + } + + return expressionRoot; + } + + private void updateRecursively(RewriterStatement cur) { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + RewriterStatement eClass = getAssertionStatement(child, cur); + + if (eClass != child) + cur.getOperands().set(i, eClass); + + updateRecursively(cur.getChild(i)); + } + } + + @Override + public String toString() { + return allAssertions.toString(); + } + + private void updateInstance(RewriterStatement stmt, Set set) { + if (stmt != null) { + stmt.getOperands().clear(); + stmt.getOperands().addAll(set); + } + } + + public static class RewriterAssertion { + Set set; + RewriterStatement stmt; + RewriterStatement backRef; // The back-reference to this assertion + + public Collection getEClass() { + return set; + } + + public RewriterStatement getEClassStmt(final RuleContext ctx, RewriterAssertions assertions) { + if (stmt != null) + return stmt; + + RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(set.toArray(RewriterStatement[]::new)); + stmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_EClass").withOps(argList); + stmt.consolidate(ctx); + assertions.assertionMatcher.put(stmt, this); + assertions.partOfAssertion.compute(stmt, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + assertions.partOfAssertion.compute(argList, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + assertions.resolveCyclicAssertions(this); + return stmt; + } + + public RewriterStatement getBackRef(final RuleContext ctx, RewriterAssertions assertions) { + if (backRef != null) + return backRef; + + backRef = new RewriterInstruction() + .as(UUID.randomUUID().toString()) + .withInstruction("_backRef." + getEClassStmt(ctx, assertions).getResultingDataType(ctx)) + .consolidate(ctx); + backRef.unsafePutMeta("_backRef", getEClassStmt(ctx, assertions)); + assertions.partOfAssertion.compute(backRef, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + return backRef; + } + + // Returns a literal if available, otherwise null + public Optional getLiteral() { + return set.stream().filter(RewriterStatement::isLiteral).findFirst(); + } + + // Removes duplicate entries (e.g. duplicate literals etc.) + public void deduplicate() { + if (stmt != null && stmt.getChild(0).getOperands().size() != set.size()) { + List operands = stmt.getChild(0).getOperands(); + Set elementTracker = new HashSet<>(); + + for (int i = 0; i < operands.size(); i++) { + RewriterStatement el = operands.get(i); + + if (elementTracker.contains(el)) { + operands.remove(i); + i--; + } else { + elementTracker.add(el); + } + } + } + } + + @Override + public String toString() { + if (stmt != null) + return stmt + " -- " + System.identityHashCode(this); + + return set.toString() + " -- " + System.identityHashCode(this); + } + + static RewriterAssertion from(Set set) { + RewriterAssertion a = new RewriterAssertion(); + a.set = set; + return a; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java new file mode 100644 index 00000000000..f97f6360235 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.codegen; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.CodeGenUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; + +public class CodeGenCondition { + public enum ConditionType { + DATA_TYPE, VALUE_TYPE, UNIQUE_PARENTS, LITERAL, OP_CLASS, OP_CODE, NUM_INPUTS, ELSE + } + + public enum ConditionDataType { + SCALAR, MATRIX + } + + private ConditionType conditionType; + private Object conditionValue; + private List rulesIf; + private List applyAnyway; + private List relativeChildPath; + private RewriterStatement representant; + + private CodeGenCondition(ConditionType cType, Object cValue, List relativeChildPath, RewriterStatement representant, final RuleContext ctx) { + conditionType = cType; + conditionValue = cValue; + rulesIf = new ArrayList<>(); + applyAnyway = new ArrayList<>(); + this.relativeChildPath = relativeChildPath; + this.representant = representant; + + if (conditionType != ConditionType.ELSE) + buildConditionCheck(new StringBuilder(), ctx); + } + + public static List buildCondition(List rules, int maxNumRules, final RuleContext ctx) { + return buildCondition(rules, 3, maxNumRules, ctx); + } + + public static List buildCondition(List rules, int maxDepth, int maxNumRules, final RuleContext ctx) { + if (rules.isEmpty()) + return Collections.emptyList(); + List transformed = rules.stream().map(rule -> new Tuple2(rule, rule.getStmt1())).collect(Collectors.toList()); + List out = populateLayerRecursively(transformed, Collections.emptyList(), new LinkedList<>(), maxDepth, maxNumRules, ctx); + List cond = out.stream().filter(o -> o instanceof CodeGenCondition).map(o -> ((CodeGenCondition)o)).collect(Collectors.toList()); + return cond.isEmpty() ? List.of(conditionalElse(transformed, Collections.emptyList(), ((Tuple2) transformed.get(0))._2, ctx)) : cond; + } + + private static List populateLayerRecursively(List rules, List relativeChildPath, Queue, List>> queue, int maxDepth, int maxNumRules, final RuleContext ctx) { + if (rules.size() <= maxNumRules || maxDepth == 0) + return rules; + + List out = populateDataTypeLayer(rules, relativeChildPath, ctx); + + for (int i = 0; i < out.size(); i++) { + CodeGenCondition c = (CodeGenCondition) out.get(i); + + if (c.rulesIf.size() <= maxNumRules) + continue; + + c.rulesIf = populateOpClassLayer(c.rulesIf, relativeChildPath, ctx); + + for (int j = 0; j < c.rulesIf.size(); j++) { + CodeGenCondition c2 = (CodeGenCondition) c.rulesIf.get(j); + + if (c2.rulesIf.size() <= maxNumRules) + continue; + + c2.rulesIf = populateOpCodeLayer(c2.rulesIf, relativeChildPath, ctx); + + for (int k = 0; k < c2.rulesIf.size(); k++) { + CodeGenCondition c3 = (CodeGenCondition) c2.rulesIf.get(k); + + if (c3.rulesIf.size() <= maxNumRules) + continue; + + c3.rulesIf = populateInputSizeLayer(c3.rulesIf, relativeChildPath, ctx); + + for (int l = 0; l < c3.rulesIf.size(); l++) { + CodeGenCondition c4 = (CodeGenCondition) c3.rulesIf.get(l); + + if (((Tuple2) c4.rulesIf.get(0))._2 == null) + continue; + + final int maxIndex = ((Tuple2) c4.rulesIf.get(0))._2.getOperands().size(); + Set activeRules = c4.rulesIf.stream().map(o -> ((Tuple2) o)._1).collect(Collectors.toSet()); + Queue, List>> mQueue = new LinkedList<>(); + + for (Tuple2, List> t : queue) { + List mObj = new ArrayList<>(); + for (Object o : t._1) { + if (activeRules.contains(((Tuple2) o)._1)) + mObj.add(o); + } + + if (!mObj.isEmpty()) + mQueue.add(new Tuple2<>(mObj, t._2)); + } + + for (int idx = 0; idx < maxIndex; idx++) { + final int mIdx = idx; + final List newRelativeChildPath = new ArrayList<>(relativeChildPath); + newRelativeChildPath.add(mIdx); + List mList = new ArrayList<>(); + mQueue.add(new Tuple2<>(mList, newRelativeChildPath)); + + c4.rulesIf.forEach(o -> { + Tuple2 t = (Tuple2) o; + mList.add(new Tuple2(t._1, (t._2 == null ? null : (t._2.getOperands().isEmpty() ? null : t._2.getChild(mIdx))))); + }); + } + + if (!mQueue.isEmpty()) { + Tuple2, List> next = mQueue.poll(); + c4.rulesIf = populateLayerRecursively(next._1, next._2(), mQueue, maxDepth-1, maxNumRules, ctx); + } + } + } + } + } + + return out; + } + + private static boolean validateSizeMaintenance(List rules, List generatedConditions) { + int origSize = rules.size(); + int newSize = generatedConditions.stream().mapToInt(o -> ((CodeGenCondition)o).rulesIf.size()).sum(); + return origSize <= newSize; + } + + private static List populateDataTypeLayer(List rules, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List> defer = new ArrayList<>(); + + //System.out.println("====="); + + for (Object o : rules) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null) { + defer.add(t); + continue; + } + + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalDataType(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + StringBuilder sb = new StringBuilder(); + cond.buildConditionCheck(sb, ctx); + } else { + CodeGenCondition condse = (CodeGenCondition) conds.stream().filter(cond -> ((CodeGenCondition) cond).matchesCondition(t._1, t._2, ctx)).findFirst().get(); + StringBuilder sb = new StringBuilder(); + condse.buildConditionCheck(sb, ctx); + } + } + + if (!defer.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(new ArrayList<>(defer), relativeChildPath, null, ctx)); + } + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!validateSizeMaintenance(rules, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + private static List populateOpClassLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + try { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateOpClassCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalOpClass(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else { + remaining.add(t); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + return conds; + } + + private static List populateOpCodeLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateOpCodeCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalOpCode(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); + } else { + remaining.add(t); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + if (!validateSizeMaintenance(l, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + private static List populateInputSizeLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateInputSizeCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalInputSize(t._2.getOperands().size(), relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); + } else { + remaining.add(t); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + if (!validateSizeMaintenance(l, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + public String getVarName() { + if (relativeChildPath.isEmpty()) + return "hi"; + return "hi_" + relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_")); + } + + public void buildConditionCheck(StringBuilder sb, final RuleContext ctx) { + switch (conditionType) { + case DATA_TYPE: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(".getDataType() == "); + sb.append(CodeGenUtils.getReturnType(getDataType() == ConditionDataType.MATRIX ? "MATRIX" : "FLOAT")[0]); + break; + case OP_CLASS: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(" instanceof " + CodeGenUtils.getOpClass(representant, ctx)); + break; + case OP_CODE: + String hopVar = "hi"; + if (!relativeChildPath.isEmpty()) { + hopVar += "_"; + hopVar += relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_")); + } + + String specialInstr = CodeGenUtils.getSpecialOpCheck(representant, ctx, hopVar); + if (specialInstr != null) { + sb.append(specialInstr); + } else { + // Some type casting + sb.append("(( "); + sb.append(CodeGenUtils.getOpClass(representant, ctx)); + sb.append(" ) "); + sb.append(hopVar); + sb.append(" )"); + sb.append(".getOp() == "); + sb.append(CodeGenUtils.getOpCode(representant, ctx)); + } + break; + case NUM_INPUTS: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(".getInput().size() == "); + sb.append(conditionValue.toString()); + break; + default: + throw new IllegalArgumentException(conditionType.name()); + } + } + + public boolean insertIfMatches(Tuple2 t, final RuleContext ctx) { + if (matchesCondition(t._1, t._2, ctx)) { + rulesIf.add(t); + return true; + } + + return false; + } + + public boolean matchesCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + switch (conditionType) { + case DATA_TYPE: + return matchesDataTypeCondition(rule, stmt, ctx); + case OP_CLASS: + return matchesOpClassCondition(rule, stmt, ctx); + case OP_CODE: + return matchesOpCodeCondition(rule, stmt, ctx); + case NUM_INPUTS: + return matchesNumInputs(rule, stmt, ctx); + } + return false; + } + + public ConditionDataType getDataType() { + return (ConditionDataType) conditionValue; + } + + private boolean matchesNumInputs(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + return ((int)conditionValue) == stmt.getOperands().size(); + } + + private boolean matchesDataTypeCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + ConditionDataType cdt = getDataType(); + String dType = stmt.getResultingDataType(ctx); + + if (dType.equals("MATRIX")) + return cdt.equals(ConditionDataType.MATRIX); + else + return cdt.equals(ConditionDataType.SCALAR); + } + + private boolean matchesOpClassCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + try { + String opClass = (String) conditionValue; + String actualClass = CodeGenUtils.getOpClass(stmt, ctx); + + return opClass.equals(actualClass); + } catch (Exception e) { + System.err.println(rule.toParsableString(ctx)); + throw e; + } + } + + private boolean matchesOpCodeCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + String opType = (String) conditionValue; + String actualOpType = CodeGenUtils.getOpCode(stmt, ctx); + + return actualOpType.equals(opType); + } + + + public static CodeGenCondition conditionalDataType(RewriterStatement stmt, List i, RewriterStatement representant, final RuleContext ctx) { + ConditionDataType cdt = stmt.getResultingDataType(ctx).equals("MATRIX") ? ConditionDataType.MATRIX : ConditionDataType.SCALAR; + return new CodeGenCondition(ConditionType.DATA_TYPE, cdt, i, representant, ctx); + } + + public static CodeGenCondition conditionalOpClass(RewriterStatement op, List i, RewriterStatement representant, final RuleContext ctx) { + String opClass = CodeGenUtils.getOpClass(op, ctx); + return new CodeGenCondition(ConditionType.OP_CLASS, opClass, i, representant, ctx); + } + + public static boolean canGenerateOpClassCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalOpCode(RewriterStatement op, List i, RewriterStatement representant, final RuleContext ctx) { + String opCode = CodeGenUtils.getOpCode(op, ctx); + return new CodeGenCondition(ConditionType.OP_CODE, opCode, i, representant, ctx); + } + + public static boolean canGenerateOpCodeCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalInputSize(int inputSize, List i, RewriterStatement representant, final RuleContext ctx) { + return new CodeGenCondition(ConditionType.NUM_INPUTS, inputSize, i, representant, ctx); + } + + public static boolean canGenerateInputSizeCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalElse(List l, List relativeChildPath, RewriterStatement representant, final RuleContext ctx) { + CodeGenCondition cond = new CodeGenCondition(ConditionType.ELSE, null, relativeChildPath, representant, ctx); + cond.rulesIf = l; + return cond; + } + + public static String getSelectionString(List conds, int indentation, Map ruleFunctionMappings, final RuleContext ctx) { + StringBuilder sb = new StringBuilder(); + buildSelection(sb, conds, indentation, ruleFunctionMappings, ctx); + return sb.toString(); + } + + public static void buildSelection(StringBuilder sb, List conds, int indentation, Map ruleFunctionMappings, final RuleContext ctx) { + if (conds.isEmpty()) + return; + + CodeGenCondition firstCond = conds.get(0); + + if (firstCond.conditionType == ConditionType.ELSE) { + List nestedCondition = firstCond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, nestedCondition, indentation, ruleFunctionMappings, ctx); + if (nestedCondition.isEmpty()) { + List> cur = firstCond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + return; + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("if ( "); + firstCond.buildConditionCheck(sb, ctx); + sb.append(" ) {\n"); + + if (firstCond.conditionType == ConditionType.NUM_INPUTS) { + int numInputs = (int)firstCond.conditionValue; + + for (int i = 0; i < numInputs; i++) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("Hop "); + sb.append(firstCond.getVarName()); + sb.append("_"); + sb.append(i); + sb.append(" = "); + sb.append(firstCond.getVarName()); + sb.append(".getInput("); + sb.append(i); + sb.append(");\n"); + } + } + + List nestedCondition = firstCond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, nestedCondition, indentation + 1, ruleFunctionMappings, ctx); + + if (nestedCondition.isEmpty()) { + List> cur = firstCond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + if (cur.isEmpty()) + throw new IllegalArgumentException(firstCond.rulesIf.toString()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("}"); + + for (CodeGenCondition cond : conds.subList(1, conds.size())) { + if (cond.conditionType == ConditionType.ELSE) { + sb.append(" else {\n"); + } else { + sb.append(" else if ( "); + cond.buildConditionCheck(sb, ctx); + sb.append(" ) {\n"); + } + + if (cond.conditionType == ConditionType.NUM_INPUTS) { + int numInputs = (int)cond.conditionValue; + + for (int i = 0; i < numInputs; i++) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("Hop "); + sb.append(cond.getVarName()); + sb.append("_"); + sb.append(i); + sb.append(" = "); + sb.append(cond.getVarName()); + sb.append(".getInput("); + sb.append(i); + sb.append(");"); + } + } + + List mNestedCondition = cond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, mNestedCondition, indentation + 1, ruleFunctionMappings, ctx); + + if (mNestedCondition.isEmpty()) { + List> cur = cond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + if (cur.isEmpty()) + throw new IllegalArgumentException(cond.rulesIf.toString()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("}"); + } + + sb.append("\n"); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java new file mode 100644 index 00000000000..2274542d1e6 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -0,0 +1,841 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.codegen; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.CodeGenUtils; +import org.codehaus.janino.SimpleCompiler; +import scala.Tuple2; + +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.AbstractCollection; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterCodeGen { + public static boolean DEBUG = true; + + public static String generateRewritesFromFiles(List filePaths, String targetFile, boolean optimize, final RuleContext ctx) throws IOException { + return generateRewritesFromFiles(filePaths, targetFile, optimize, 2, true, true, ctx); + } + + public static String generateRewritesFromFiles(List filePaths, String targetFile, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, boolean maintainStatistics, final RuleContext ctx) throws IOException { + List lines = new ArrayList<>(); + + for (String path : filePaths) { + lines.addAll(Files.readAllLines(Paths.get(path))); + } + + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", optimize, maxOptimizationDepth, includePackageInfo, true, maintainStatistics); + + try (FileWriter writer = new FileWriter(targetFile)) { + writer.write(javaCode); + } catch (IOException e) { + throw e; + } + + return javaCode; + } + + public static Function compileRewrites(String className, List> rewrites, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) throws Exception { + String code = generateClass(className, rewrites, false, false, ctx, ignoreErrors, printErrors); + System.out.println("Compiling code:\n" + code); + SimpleCompiler compiler = new SimpleCompiler(); + compiler.cook(code); + Class mClass = compiler.getClassLoader().loadClass(className); + Object instance = mClass.getDeclaredConstructor().newInstance(); + return (Function) instance; + } + + public static Function compile(String javaCode, String className) { + try { + SimpleCompiler compiler = new SimpleCompiler(); + compiler.cook(javaCode); + Class mClass = compiler.getClassLoader().loadClass(className); + Object instance = mClass.getDeclaredConstructor().newInstance(); + return (Function) instance; + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + public static String generateClass(String className, List> rewrites, boolean optimize, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) { + return generateClass(className, rewrites, optimize, 2, includePackageInfo, ctx, ignoreErrors, printErrors, false); + } + + public static String generateClass(String className, List> rewrites, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors, boolean maintainRewriteStats) { + StringBuilder msb = new StringBuilder(); + + if (includePackageInfo) { + // Include license + msb.append("/*\n" + + " * Licensed to the Apache Software Foundation (ASF) under one\n" + + " * or more contributor license agreements. See the NOTICE file\n" + + " * distributed with this work for additional information\n" + + " * regarding copyright ownership. The ASF licenses this file\n" + + " * to you under the Apache License, Version 2.0 (the\n" + + " * \"License\"); you may not use this file except in compliance\n" + + " * with the License. You may obtain a copy of the License at\n" + + " *\n" + + " * http://www.apache.org/licenses/LICENSE-2.0\n" + + " *\n" + + " * Unless required by applicable law or agreed to in writing,\n" + + " * software distributed under the License is distributed on an\n" + + " * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + " * KIND, either express or implied. See the License for the\n" + + " * specific language governing permissions and limitations\n" + + " * under the License.\n" + + " */\n\n"); + msb.append("package org.apache.sysds.hops.rewriter.generated;\n\n"); + } + + msb.append("import java.util.ArrayList;\n"); + msb.append("import java.util.function.Function;\n"); + msb.append("\n"); + msb.append("import org.apache.sysds.utils.Statistics;\n"); + msb.append("import org.apache.sysds.hops.Hop;\n"); + msb.append("import org.apache.sysds.hops.LiteralOp;\n"); + msb.append("import org.apache.sysds.hops.UnaryOp;\n"); + msb.append("import org.apache.sysds.hops.BinaryOp;\n"); + msb.append("import org.apache.sysds.hops.ReorgOp;\n"); + msb.append("import org.apache.sysds.hops.AggUnaryOp;\n"); + msb.append("import org.apache.sysds.hops.AggBinaryOp;\n"); + msb.append("import org.apache.sysds.hops.DataGenOp;\n"); + msb.append("import org.apache.sysds.hops.TernaryOp;\n"); + msb.append("import org.apache.sysds.common.Types;\n"); + msb.append("import org.apache.sysds.hops.rewrite.HopRewriteUtils;\n"); + msb.append("import org.apache.sysds.hops.rewriter.dml.DMLExecutor;\n"); + msb.append("import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils;\n"); + msb.append("\n"); + msb.append("public class " + className + " implements Function {\n\n"); + + StringBuilder implSb = new StringBuilder(); + Set implemented = new HashSet<>(); + int implementedRules = 0; + for (Tuple2 appliedRewrites : rewrites) { + String mRewriteFn; + if (ignoreErrors) { + try { + mRewriteFn = generateRewriteFunction(appliedRewrites._2, appliedRewrites._1, 1, maintainRewriteStats, ctx); + implementedRules++; + } catch (Exception e) { + if (printErrors) + e.printStackTrace(); + + continue; + } + } else { + mRewriteFn = generateRewriteFunction(appliedRewrites._2, appliedRewrites._1, 1, maintainRewriteStats, ctx); + implementedRules++; + } + + implSb.append('\n'); + indent(1, implSb); + implSb.append("// Implementation of the rule " + appliedRewrites._2 + "\n"); + implSb.append(mRewriteFn); + implemented.add(appliedRewrites._1); + } + + indent(1, msb); + msb.append("@Override\n"); + indent(1, msb); + msb.append("public Object apply( Object _hi ) {\n"); + indent(2, msb); + msb.append("if ( _hi == null )\n"); + indent(3, msb); + msb.append("return null;\n\n"); + indent(2, msb); + msb.append("Hop hi = (Hop) _hi;\n\n"); + + if (optimize) { + List> implementedRewrites = rewrites.stream().filter(t -> implemented.contains(t._1)).collect(Collectors.toList()); + + List rules = rewrites.stream().map(t -> t._2).collect(Collectors.toList()); + Map ruleNames = new HashMap<>(); + + for (Tuple2 t : implementedRewrites) + ruleNames.put(t._2, t._1); + + List conditions = CodeGenCondition.buildCondition(rules, maxOptimizationDepth, 5, ctx); + CodeGenCondition.buildSelection(msb, conditions, 2, ruleNames, ctx); + } else { + for (Tuple2 appliedRewrites : rewrites) { + if (implemented.contains(appliedRewrites._1)) { + indent(2, msb); + msb.append("hi = " + appliedRewrites._1 + "((Hop) hi);\t\t// "); + msb.append(appliedRewrites._2.toString()); + msb.append('\n'); + } + } + } + + indent(2, msb); + msb.append("return hi;\n"); + + indent(1, msb); + msb.append("}\n"); + + msb.append(implSb); + + msb.append('\n'); + buildTypeCastFunction(msb, 1); + msb.append('\n'); + buildMinIdxFunction(msb, 1); + msb.append('\n'); + msb.append("}"); + System.out.println("Implemented rules: " + implementedRules); + return msb.toString(); + } + + private static String generateRewriteFunction(RewriterRule rule, String fName, int indentation, boolean maintainRewriteStats, final RuleContext ctx) { + try { + Tuple2, Boolean> t = RewriterCostEstimator.determineSingleReferenceRequirement(rule, ctx); + Set mSet = t._1; + if (mSet instanceof AbstractCollection) + mSet = new HashSet<>(mSet); + mSet.add(rule.getStmt1()); + boolean allowCombinedMultiRefs = t._2; + + StringBuilder sb = new StringBuilder(); + + // Append the function signature + indent(indentation, sb); + sb.append("private static Hop " + fName + "(Hop hi) {\n"); + + if (!allowCombinedMultiRefs) { + indent(indentation + 1, sb); + sb.append("boolean _multiReference = false;\n"); + } + + List tos = rule.isConditionalMultiRule() ? rule.getConditionalMultiRuleTargets() : List.of(rule.getStmt2()); + + // Build the function body + buildMatchingSequence(rule.toString(), rule.getStmt1(), tos, rule.getStmt1Cost(), rule.getStmt2Costs(), rule.getCombinedAssertions(), sb, ctx, indentation + 1, mSet, allowCombinedMultiRefs, maintainRewriteStats); + indent(indentation, sb); + + sb.append("}\n"); + + return sb.toString(); + } catch (Exception e) { + e.addSuppressed(new Exception("Failed to generate rewrite rule: " + rule.toString() + "\nAssertions: " + rule.getCombinedAssertions())); + throw e; + } + } + + private static void buildMatchingSequence(String name, RewriterStatement from, List tos, RewriterStatement fromCost, List toCosts, RewriterAssertions combinedAssertions, StringBuilder sb, final RuleContext ctx, int indentation, Set allowedMultiRefs, boolean allowCombinations, boolean maintainRewriteStats) { + Map vars = new HashMap<>(); + vars.put(from, "hi"); + recursivelyBuildMatchingSequence(from, sb, "hi", ctx, indentation, vars, allowedMultiRefs, allowCombinations); + + from.forEachPreOrder(el -> { + if (el.isInstruction() && el.trueInstruction().equals("const") && vars.get(el.getChild(0)) == null) { + vars.put(el.getChild(0), vars.get(el)); + } + + return true; + }, false); + + if (fromCost != null) { + List msb = new ArrayList<>(); + msb.add(new StringBuilder()); + Set> requirements = new HashSet<>(); + + buildCostFnRecursively(fromCost, vars, ctx, msb.get(0), requirements); + + for (RewriterStatement toCost : toCosts) { + StringBuilder msb2 = new StringBuilder(); + buildCostFnRecursively(toCost, vars, ctx, msb2, requirements); + msb.add(msb2); + } + + // First, we build the necessary checks (e.g. if we have nnz / dim information we need, otherwise this rewrite cannot be applied) + if (!requirements.isEmpty()) { + sb.append('\n'); + indent(indentation, sb); + sb.append("if ( "); + + int ctr = 0; + for (Tuple2 req : requirements) { + if (ctr != 0) + sb.append(" || "); + + sb.append(req._1); + switch (req._2) { + case "_nnz": + sb.append(".getNnz() == -1"); + break; + case "nrow": + sb.append(".getDim1() == -1"); + break; + case "ncol": + sb.append(".getDim2() == -1"); + break; + default: + throw new IllegalArgumentException(req._2); + } + + ctr++; + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + + // Then we build the cost functions + sb.append('\n'); + indent(indentation, sb); + sb.append("double[] costs = new double["); + sb.append(msb.size()); + sb.append("];\n"); + + for (int i = 0; i < msb.size(); i++) { + indent(indentation, sb); + sb.append("costs["); + sb.append(i); + sb.append("] = "); + sb.append(msb.get(i)); + sb.append(";\n"); + } + + indent(indentation, sb); + sb.append("int minIdx = minIdx(costs);\n\n"); + indent(indentation, sb); + sb.append("switch( minIdx ) {\n"); + + for (int i = 1; i < msb.size(); i++) { + indent(indentation+1, sb); + sb.append("case " + i + ": {"); + buildNewHop(name, from, tos.get(i-1), sb, combinedAssertions, new HashMap<>(vars), ctx, indentation+2, maintainRewriteStats); + indent(indentation+1, sb); + sb.append("}\n"); + } + + indent(indentation, sb); + sb.append("}\n"); + + indent(indentation, sb); + sb.append("return hi;\n"); + } else { + buildNewHop(name, from, tos.get(0), sb, combinedAssertions, vars, ctx, indentation, maintainRewriteStats); + } + } + + private static void buildNewHop(String rewriteName, RewriterStatement from, RewriterStatement to, StringBuilder sb, RewriterAssertions combinedAssertions, Map vars, final RuleContext ctx, int indentation, boolean maintainRewriteStats) { + sb.append('\n'); + indent(indentation, sb); + sb.append("// Now, we start building the new HOP-DAG: "); + sb.append(to.toParsableString(ctx)); + sb.append('\n'); + + Set activeStatements = buildRewrite(to, sb, combinedAssertions, vars, ctx, indentation); + + String newRoot = vars.get(to); + + sb.append('\n'); + indent(indentation, sb); + sb.append("Hop newRoot = " + newRoot + ";\n"); + indent(indentation, sb); + sb.append("if ( " + newRoot + ".getValueType() != hi.getValueType() ) {\n"); + indent(indentation + 1, sb); + sb.append("newRoot = castIfNecessary(newRoot, hi);\n"); + indent(indentation + 1, sb); + sb.append("if ( newRoot == null )\n"); + indent(indentation + 2, sb); + sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); + + + sb.append('\n'); + indent(indentation, sb); + sb.append("ArrayList parents = new ArrayList<>(hi.getParent());\n\n"); + indent(indentation, sb); + sb.append("for ( Hop p : parents )\n"); + indent(indentation + 1, sb); + sb.append("HopRewriteUtils.replaceChildReference(p, hi, newRoot);\n\n"); + + indent(indentation, sb); + sb.append("// Remove old unreferenced Hops\n"); + removeUnreferencedHops(from, activeStatements, sb, vars, ctx, indentation); + sb.append('\n'); + + if (DEBUG) { + indent(indentation, sb); + sb.append("DMLExecutor.println(\"Applying rewrite: " + rewriteName + "\");\n"); + } + + if (maintainRewriteStats) { + indent(indentation, sb); + sb.append("Statistics.applyGeneratedRewrite(\"" + rewriteName + "\");\n"); + } + + indent(indentation, sb); + sb.append("return newRoot;\n"); + } + + private static void buildTypeCastFunction(StringBuilder sb, int indentation) { + String str = "private static Hop castIfNecessary(Hop newRoot, Hop oldRoot) {\n" + + "\tTypes.OpOp1 cast = null;\n" + + "\tswitch ( oldRoot.getValueType().toExternalString() ) {\n" + + "\t\tcase \"DOUBLE\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_DOUBLE;\n" + + "\t\t\tbreak;\n" + + "\t\tcase \"INT\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_INT;\n" + + "\t\t\tbreak;\n" + + "\t\tcase \"BOOLEAN\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_BOOLEAN;\n" + + "\t\t\tbreak;\n" + + "\t\tdefault:\n" + + "\t\t\treturn null;\n" + + "\t}\n" + + "\n" + + "\treturn new UnaryOp(\"tmp\", oldRoot.getDataType(), oldRoot.getValueType(), cast, newRoot);\n" + + "}\n"; + + sb.append(indentMultilineString(str, indentation)); + } + + private static void buildMinIdxFunction(StringBuilder sb, int indentation) { + String str = "private static int minIdx(double[] l) {\n" + + "\tdouble minValue = Double.MAX_VALUE;\n" + + "\tint minIdx = -1;\n" + + "\n" + + "\tfor (int i = 0; i < l.length; i++) {\n" + + "\t\tif (l[i] < minValue) {\n" + + "\t\t\tminValue = l[i];\n" + + "\t\t\tminIdx = i;\n" + + "\t\t}\n" + + "\t}\n" + + "\n" + + "\treturn minIdx;\n" + + "}\n"; + + sb.append(indentMultilineString(str, indentation)); + } + + private static String indentMultilineString(String str, int indentation) { + String tabs = "\t".repeat(indentation); + return str.lines() // Split the string into lines + .map(line -> tabs + line) // Add tabs to the beginning of each line + .collect(Collectors.joining("\n")); // Join the lines back together + } + + private static void buildCostFnRecursively(RewriterStatement costFn, Map vars, final RuleContext ctx, StringBuilder sb, Set> requirements) { + if (costFn.isLiteral()) { + sb.append(costFn.floatLiteral()); + return; + } + + if (!costFn.isInstruction()) + throw new IllegalArgumentException(); + + List operands; + + if (!costFn.getOperands().isEmpty() && costFn.getChild(0).isArgumentList()) + operands = costFn.getChild(0).getOperands(); + else + operands = costFn.getOperands(); + + String varName; + + // Then, the cost function is an instruction + switch (costFn.trueInstruction()) { + case "_nnz": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(costFn.toParsableString(ctx)); + + requirements.add(new Tuple2<>(varName, "_nnz")); + sb.append(varName); + sb.append(".getNnz()"); + break; + + case "nrow": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(); + + requirements.add(new Tuple2<>(varName, "nrow")); + sb.append(varName); + sb.append(".getDim1()"); + break; + + case "ncol": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(); + + requirements.add(new Tuple2<>(varName, "ncol")); + sb.append(varName); + sb.append(".getDim2()"); + break; + + case "+": + case "*": + sb.append('('); + + for (int i = 0; i < operands.size(); i++) { + if (i != 0) { + sb.append(' '); + sb.append(costFn.trueInstruction()); + sb.append(' '); + } + + buildCostFnRecursively(operands.get(i), vars, ctx, sb, requirements); + } + + sb.append(')'); + break; + case "inv": + sb.append("(1.0 / "); + buildCostFnRecursively(operands.get(0), vars, ctx, sb, requirements); + sb.append(')'); + break; + case "min": + case "max": + sb.append("Math."); + sb.append(costFn.trueInstruction()); + sb.append('('); + for (int i = 0; i < operands.size(); i++) { + if (i != 0) + sb.append(", "); + + buildCostFnRecursively(operands.get(i), vars, ctx, sb, requirements); + } + sb.append(')'); + break; + case "_EClass": + // Here, we can just select a random representant + // Ideally, we would choose one that has dimensions available, but for now, we just take the first + buildCostFnRecursively(operands.get(0), vars, ctx, sb, requirements); + break; + default: + throw new IllegalArgumentException(costFn.trueInstruction()); + } + } + + // Returns the set of all active statements after the rewrite + private static Set buildRewrite(RewriterStatement newRoot, StringBuilder sb, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation) { + Set visited = new HashSet<>(); + recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT"), new ArrayList<>()); + + return visited; + } + + private static void removeUnreferencedHops(RewriterStatement oldRoot, Set activeStatements, StringBuilder sb, Map vars, final RuleContext ctx, int indentation) { + oldRoot.forEachPreOrder(cur -> { + if (activeStatements.contains(cur)) + return true; + + indent(indentation, sb); + sb.append("HopRewriteUtils.cleanupUnreferenced(" + vars.get(cur) + ");\n"); + return true; + }, false); + } + + private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation, int varCtr, Set visited, boolean enforceRootDataType, List createdOps) { + visited.add(cur); + if (vars.containsKey(cur)) + return varCtr; + + for (RewriterStatement child : cur.getOperands()) + varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false, createdOps); + + if (cur instanceof RewriterDataType) { + if (cur.isLiteral()) { + indent(indentation, sb); + String name = "l" + (varCtr++); + String literalStr = cur.getLiteral().toString(); + + if (enforceRootDataType) { + sb.append("LiteralOp " + name + ";\n"); + indent(indentation, sb); + sb.append("switch (hi.getValueType()) {\n"); + indent(indentation+1, sb); + sb.append("case FP64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.floatLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case INT64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.intLiteral(true) + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case BOOLEAN:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.boolLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("default:\n"); + indent(indentation+2, sb); + sb.append("return hi;\n"); + indent(indentation+1, sb); + sb.append("}\n"); + } else { + sb.append("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n"); + } + vars.put(cur, name); + createdOps.add(name); + } + + return varCtr; + } else { + String opClass = CodeGenUtils.getOpClass(cur, ctx); + String[] operandRefs = cur.getOperands().stream().map(vars::get).toArray(String[]::new); + + if (CodeGenUtils.opRequiresBinaryBroadcastingMatch(cur, ctx)) { + // Then we need to validate that broadcasting still works after rearranging + indent(indentation, sb); + sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") ) {\n"); + for (String createdOp : createdOps) { + // Properly remove the references to the newly constructed ops + indent(indentation+1, sb); + sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n"); + } + indent(indentation+1, sb); + sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); + } else { + List matchingDims = CodeGenUtils.matchingDimRequirement(cur, ctx); + + if (!matchingDims.isEmpty()) { + // Then we need to validate that broadcasting still works after rearranging + indent(indentation, sb); + sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") ) {\n"); + for (String createdOp : createdOps) { + // Properly remove the references to the newly constructed ops + indent(indentation+1, sb); + sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n"); + } + indent(indentation+1, sb); + sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); + } + } + + String constructor = CodeGenUtils.getHopConstructor(cur, assertions, vars, ctx, operandRefs); + String name = "v" + (varCtr++); + indent(indentation, sb); + sb.append(opClass + " " + name + " = " + constructor + ";\n"); + + vars.put(cur, name); + createdOps.add(name); + } + + return varCtr; + } + + private static void recursivelyBuildMatchingSequence(RewriterStatement cur, StringBuilder sb, String curVar, final RuleContext ctx, int indentation, Map map, Set allowedMultiRefs, boolean allowCombinations) { + if (cur.isLiteral()) { + String[] types = CodeGenUtils.getReturnType(cur, ctx); + indent(indentation, sb); + sb.append("if ( !(" + curVar + " instanceof LiteralOp) )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + indent(indentation, sb); + String lVar = "l_" + curVar; + sb.append("LiteralOp " + lVar + " = (LiteralOp) " + curVar + ";\n\n"); + indent(indentation, sb); + sb.append("if ( " + lVar + ".getDataType() != " + types[0]); + sb.append("|| !" + lVar + ".getValueType().isNumeric()"); + sb.append(" )\n"); + + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + indent(indentation, sb); + sb.append("if ( " + lVar + "." + CodeGenUtils.literalGetterFunction(cur, ctx) + " != " + cur.getLiteral() + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + return; + } + + // Check if we have to ensure a single reference to this object + if (cur.isInstruction() && !allowedMultiRefs.contains(cur)) { + if (allowCombinations && !allowedMultiRefs.contains(cur)) { + indent(indentation, sb); + sb.append("if ("); + sb.append(curVar); + sb.append(".getParent().size() > 1)\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n"); + } else if (!allowedMultiRefs.contains(cur)) { + indent(indentation, sb); + sb.append("if ("); + sb.append(curVar); + sb.append(".getParent().size() > 1) {\n"); + indent(indentation + 1, sb); + sb.append("if (_multiReference)\n"); + indent(indentation + 2, sb); + sb.append("return hi;\n"); + indent(indentation + 1, sb); + sb.append("else\n"); + indent(indentation + 2, sb); + sb.append("_multiReference = true;\n"); + indent(indentation + 1, sb); + sb.append("}\n"); + } + } + + String specialOpCheck = CodeGenUtils.getSpecialOpCheck(cur, ctx, curVar); + + // E.g. A %*% B, which is an AggBinaryOp consisting of multiple OpCodes + if (specialOpCheck != null) { + indent(indentation, sb); + sb.append("if ( !" + specialOpCheck + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } else if (!cur.isDataOrigin()) { + String opClass = CodeGenUtils.getOpClass(cur, ctx); + + // Generate initial class check + indent(indentation, sb); + sb.append("if ( !(" + curVar + " instanceof " + opClass + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + // Cast the expression to the corresponding op-class + String cCurVar = "c_" + curVar; + indent(indentation, sb); + sb.append(opClass + " " + cCurVar + " = (" + opClass + ") " + curVar + ";\n\n"); + + String opCode = CodeGenUtils.getOpCode(cur, ctx); + + // Check if the instruction matches + indent(indentation, sb); + if (opCode != null) { + sb.append("if ( " + cCurVar + ".getOp() != " + opCode); + sb.append(" || !" + cCurVar + ".getValueType().isNumeric()"); + } else { + sb.append("if ( !" + cCurVar + ".getValueType().isNumeric()"); + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + String additionalCheck = CodeGenUtils.getAdditionalCheck(cur, ctx, cCurVar); + + if (additionalCheck != null) { + indent(indentation, sb); + sb.append("if ( !(" + additionalCheck + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + } else { + indent(indentation, sb); + String[] types = CodeGenUtils.getReturnType(cur, ctx); + sb.append("if ( " + curVar + ".getDataType() != " + types[0]); + sb.append(" || !" + curVar + ".getValueType().isNumeric()"); + + if (cur.isRowVector()) { + sb.append(" || " + curVar + ".getDim2() != 1L"); + } else if (cur.isColVector()) { + sb.append(" || " + curVar + ".getDim1() != 1L"); + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + String additionalCheck = CodeGenUtils.getAdditionalCheck(cur, ctx, curVar); + + if (additionalCheck != null) { + indent(indentation, sb); + sb.append("if ( !(" + additionalCheck + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + } + + // Now, we match the children + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement stmt = cur.getChild(i); + + String existingVar = map.get(stmt); + + if (existingVar != null) { + String name = resolveOperand(cur, i, sb, curVar, ctx, indentation); + sb.append('\n'); + // Just check if they are identical + indent(indentation, sb); + sb.append("if ( " + existingVar + " != " + name + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + continue; + } + + // Build the variable definition + String name = resolveOperand(cur, i, sb, curVar, ctx, indentation); + if (name != null) { + map.put(stmt, name); + sb.append('\n'); + recursivelyBuildMatchingSequence(stmt, sb, name, ctx, indentation, map, allowedMultiRefs, allowCombinations); + } + } + } + + private static String resolveOperand(RewriterStatement stmt, int idx, StringBuilder sb, String curVar, final RuleContext ctx, int indentation) { + String accessor = CodeGenUtils.getChildAccessor(curVar, stmt, idx); + if (accessor == null) + return null; // Then we do not need to traverse the sub-dag further + String name = curVar + "_" + idx; + indent(indentation, sb); + sb.append("Hop " + name + " = " + accessor + ";\n"); + return name; + } + + public static void indent(int depth, StringBuilder sb) { + sb.append("\t".repeat(depth)); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java new file mode 100644 index 00000000000..9baba0c3fa1 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.dml; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +public class DMLCodeGenerator { + public static final long MATRIX_DIMS = 100; + public static final double EPS = 1e-10; + public static Random rd = new Random(42); + + + private static final HashSet printAsBinary = new HashSet<>(); + private static final HashMap, Boolean>> customEncoders = new HashMap<>(); + private static final RuleContext ctx = RewriterUtils.buildDefaultContext(); + + static { + printAsBinary.add("+"); + printAsBinary.add("-"); + printAsBinary.add("*"); + printAsBinary.add("/"); + printAsBinary.add("^"); + printAsBinary.add("&"); + printAsBinary.add("|"); + printAsBinary.add("=="); + printAsBinary.add("!="); + printAsBinary.add(">"); + printAsBinary.add(">="); + printAsBinary.add("<"); + printAsBinary.add("<="); + printAsBinary.add("%*%"); + + customEncoders.put("[]", (stmt, sb, tmpVars) -> { + if (stmt.getOperands().size() == 3) { + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append('['); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(", "); + appendExpression(stmt.getChild(2), sb, tmpVars); + sb.append(']'); + return true; + } else if (stmt.getOperands().size() == 5) { + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append('['); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(" : "); + appendExpression(stmt.getChild(2), sb, tmpVars); + sb.append(", "); + appendExpression(stmt.getChild(3), sb, tmpVars); + sb.append(" : "); + appendExpression(stmt.getChild(4), sb, tmpVars); + sb.append(']'); + return true; + } + + return false; + }); + + customEncoders.put("const", (stmt, sb, tmpVars) -> { + sb.append("matrix("); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(", rows="); + sb.append(MATRIX_DIMS); + sb.append(", cols="); + sb.append(MATRIX_DIMS); + sb.append(')'); + + return true; + }); + + customEncoders.put("cast.MATRIX", (stmt, sb, tmpVars) -> { + sb.append("as.matrix("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + + return true; + }); + + customEncoders.put("cast.FLOAT", (stmt, sb, tmpVars) -> { + if (stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + sb.append("as.scalar("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + } else { + sb.append("as.double("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + } + + return true; + }); + } + + public static Consumer ruleValidationScript(String ruleName, String sessionId, Consumer validator) { + return line -> { + if (!line.startsWith(sessionId)) + return; + + if (line.endsWith("valid: TRUE")) { + validator.accept(true); + } else { + validator.accept(false); + } + }; + } + + public static String generateRuleValidationDML(RewriterRule rule, String sessionId, final RuleContext ctx) { + return generateRuleValidationDML(rule, EPS, sessionId, ctx); + } + + public static String generateRuleValidationDML(RewriterRule rule, double eps, String sessionId, final RuleContext ctx) { + RewriterStatement stmtFrom = RewriterUtils.unfuseOperators(rule.getStmt1(), ctx); + RewriterStatement stmtTo = RewriterUtils.unfuseOperators(rule.getStmt2(), ctx); + + Set vars = new HashSet<>(); + List> orderedTmpVars = new ArrayList<>(); + Map tmpVars = new HashMap<>(); + MutableInt tmpVarCtr = new MutableInt(0); + + stmtFrom.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + else + createTmpVars(stmt, orderedTmpVars, tmpVars, tmpVarCtr); + }, false); + + stmtTo.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + else + createTmpVars(stmt, orderedTmpVars, tmpVars, tmpVarCtr); + }, false); + + Set toRemove = vars.stream().filter(t -> t.isInstruction() && !t.trueInstruction().equals("const")).map(instr -> instr.getChild(0)).collect(Collectors.toSet()); + vars.removeAll(toRemove); + + StringBuilder sb = new StringBuilder(); + + sb.append(generateDMLVariables(vars)); + + Map incrementingTmpVars = new HashMap<>(); + + for (Tuple2 t : orderedTmpVars) { + sb.append(t._2); + sb.append(" = "); + sb.append(generateDML(t._1, incrementingTmpVars)); + sb.append('\n'); + incrementingTmpVars.put(t._1, t._2); + } + + sb.append('\n'); + sb.append("R1 = "); + sb.append(generateDML(stmtFrom, tmpVars)); + sb.append('\n'); + sb.append("R2 = "); + sb.append(generateDML(stmtTo, tmpVars)); + sb.append('\n'); + sb.append("print(\""); + sb.append(sessionId); + sb.append(" valid: \" + ("); + sb.append(generateEqualityCheck("R1", "R2", stmtFrom.getResultingDataType(ctx), eps)); + sb.append("))"); + + return sb.toString(); + } + + private static boolean createTmpVars(RewriterStatement stmt, List> orderedTmpVars, Map tmpVars, MutableInt tmpVarCtr) { + if (stmt.isInstruction() && stmt.trueInstruction().equals("[]")) { + // Then we need to put the child into a variable + RewriterStatement child = stmt.getChild(0); + if (child.isInstruction() || child.isLiteral()) { + String tmpVar = "tmp" + tmpVarCtr.getAndIncrement(); + tmpVars.put(child, tmpVar); + orderedTmpVars.add(new Tuple2<>(child, tmpVar)); + return true; + } + } + + return false; + } + + public static Set getVariables(RewriterStatement root) { + Set vars = new HashSet<>(); + root.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + }, false); + + Set toRemove = vars.stream().filter(stmt -> stmt.isInstruction() && !stmt.trueInstruction().equals("const")).map(instr -> instr.getChild(0)).collect(Collectors.toSet()); + vars.removeAll(toRemove); + + return vars; + } + + public static String generateDMLVariables(RewriterStatement root) { + return generateDMLVariables(getVariables(root)); + } + + public static String generateDMLVariables(Set vars) { + StringBuilder sb = new StringBuilder(); + + for (RewriterStatement var : vars) { + + switch (var.getResultingDataType(ctx)) { + case "MATRIX": + String mId = var.getId(); + long nrow = MATRIX_DIMS; + long ncol = MATRIX_DIMS; + if (var.isInstruction()) { + if (var.trueInstruction().equals("rowVec")) { + mId = var.getChild(0).getId(); + nrow = 1L; + } else if (var.trueInstruction().equals("colVec")) { + mId = var.getChild(0).getId(); + ncol = 1L; + } else if (var.trueInstruction().equals("const")) { + sb.append(var.getId()); + sb.append(" = matrix(" + var.getChild(1).getLiteral() + ", rows=" + nrow + ", cols=" + ncol + ")\n"); + continue; + } + } + sb.append(mId + " = cos((rand(rows=" + nrow + ", cols=" + ncol + ") * rand(rows=" + nrow + ", cols=" + ncol + ", min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n"); + break; + case "FLOAT": + sb.append(var.getId() + " = cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n"); + break; + case "INT": + sb.append(var.getId() + " = as.integer(cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand()+200000.0)), seed=" + rd.nextInt(1000) + "))^as.scalar(rand())))\n"); + break; + case "BOOL": + sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n"); + break; + default: + throw new NotImplementedException(var.getResultingDataType(ctx)); + } + } + + return sb.toString(); + } + + public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) { + switch (dataType) { + case "MATRIX": + return "sum(abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps + ") == length(" + stmt1Var + ")"; + case "INT": + case "BOOL": + return stmt1Var + " == " + stmt2Var; + case "FLOAT": + return "abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps; + } + + throw new NotImplementedException(); + } + + public static String generateDMLDefs(RewriterStatement stmt) { + Map vars = new HashMap<>(); + + stmt.forEachPostOrder((cur, pred) -> { + if (!cur.isInstruction() && !cur.isLiteral()) + vars.put(cur.getId(), cur); + }, false); + + return generateDMLDefs(vars); + } + + public static String generateDMLDefs(Map defs) { + StringBuilder sb = new StringBuilder(); + + defs.forEach((k, v) -> { + sb.append(k); + sb.append(" = "); + sb.append(generateDML(v)); + sb.append('\n'); + }); + + return sb.toString(); + } + + public static String generateDML(RewriterStatement root) { + return generateDML(root, Collections.emptyMap()); + } + + public static String generateDML(RewriterStatement root, Map tmpVars) { + StringBuilder sb = new StringBuilder(); + appendExpression(root, sb, tmpVars); + + return sb.toString(); + } + + private static void appendExpression(RewriterStatement cur, StringBuilder sb, Map tmpVars) { + String tmpVar = tmpVars.get(cur); + + if (tmpVar != null) { + sb.append(tmpVar); + return; + } + + if (cur.isInstruction()) { + if (cur.isDataOrigin()) + sb.append(cur.getId()); + else + resolveExpression((RewriterInstruction) cur, sb, tmpVars); + } else { + if (cur.isLiteral()) + sb.append(cur.getLiteral()); + else + sb.append(cur.getId()); + } + } + + private static void resolveExpression(RewriterInstruction expr, StringBuilder sb, Map tmpVars) { + String typedInstr = expr.trueTypedInstruction(ctx); + String unTypedInstr = expr.trueInstruction(); + + if (expr.getOperands().size() == 2 && (printAsBinary.contains(typedInstr) || printAsBinary.contains(unTypedInstr))) { + sb.append('('); + appendExpression(expr.getChild(0), sb, tmpVars); + sb.append(") "); + sb.append(unTypedInstr); + sb.append(" ("); + appendExpression(expr.getChild(1), sb, tmpVars); + sb.append(')'); + return; + } + + TriFunction, Boolean> customEncoder = customEncoders.get(typedInstr); + + if (customEncoder == null) + customEncoder = customEncoders.get(unTypedInstr); + + if (customEncoder == null) { + sb.append(unTypedInstr); + sb.append('('); + + for (int i = 0; i < expr.getOperands().size(); i++) { + if (i != 0) + sb.append(", "); + + appendExpression(expr.getChild(i), sb, tmpVars); + } + + sb.append(')'); + } else { + customEncoder.apply(expr, sb, tmpVars); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java new file mode 100644 index 00000000000..0b07a84a7e9 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.dml; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; + +import java.io.OutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + +public class DMLExecutor { + private static PrintStream origPrintStream = System.out; + private static PrintStream origErrPrintStream = System.out; + + public static boolean APPLY_INJECTED_REWRITES = false; + public static Function REWRITE_FUNCTION = null; + + private static List lastErr; + + public static void executeCode(String code, boolean intercept, String... additionalArgs) { + executeCode(code, intercept ? s -> {} : null, additionalArgs); + } + + // Returns if true if the run was successful without any errors + public static boolean executeCode(String code, Consumer consoleInterceptor, String... additionalArgs) { + return executeCode(code, consoleInterceptor, null, additionalArgs); + } + + // This cannot run in parallel + public static synchronized boolean executeCode(String code, Consumer consoleInterceptor, Function injectedRewriteClass, String... additionalArgs) { + lastErr = new ArrayList<>(); + boolean exceptionOccurred = false; + + try { + if (consoleInterceptor != null) + System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor))); + + System.setErr(new PrintStream(new CustomOutputStream(System.err, lastErr::add))); + + String[] args = new String[additionalArgs.length + 2]; + + for (int i = 0; i < additionalArgs.length; i++) + args[i] = additionalArgs[i]; + + args[additionalArgs.length] = "-s"; + args[additionalArgs.length + 1] = code; + + if (injectedRewriteClass != null) { + APPLY_INJECTED_REWRITES = true; + REWRITE_FUNCTION = injectedRewriteClass; + } + + // To allow the discovery of sum((a*A)*B) which would usually be converted to n* + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; + + DMLScript.executeScript(args); + + } catch (Exception e) { + e.printStackTrace(); + exceptionOccurred = true; + } + + APPLY_INJECTED_REWRITES = false; + REWRITE_FUNCTION = null; + + if (consoleInterceptor != null) + System.setOut(origPrintStream); + + System.setErr(origErrPrintStream); + + return !exceptionOccurred && lastErr.isEmpty(); + } + + public static List getLastErr() { + return lastErr; + } + + // Bypasses the interceptor + public static void println(Object o) { + origPrintStream.println(o); + } + + private static class CustomOutputStream extends OutputStream { + private PrintStream ps; + private StringBuilder buffer = new StringBuilder(); + private Consumer lineHandler; + + public CustomOutputStream(PrintStream actualPrintStream, Consumer lineHandler) { + this.ps = actualPrintStream; + this.lineHandler = lineHandler; + } + + @Override + public void write(int b) { + char c = (char) b; + if (c == '\n') { + lineHandler.accept(buffer.toString()); + buffer.setLength(0); // Clear the buffer after handling the line + } else { + buffer.append(c); // Accumulate characters until newline + } + } + + @Override + public void write(byte[] b, int off, int len) { + for (int i = off; i < off + len; i++) { + write(b[i]); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java new file mode 100644 index 00000000000..658a1114214 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java @@ -0,0 +1,947 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.estimators; + +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableLong; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterCostEstimator { + private static final long INSTRUCTION_OVERHEAD = 10; + private static final long MALLOC_COST = 10000; + public static final Function DEFAULT_COST_FN = el -> 2000L; + public static final BiFunction, Long> DEFAULT_NNZ_FN = (el, tpl) -> tpl._1 * tpl._2; + + // This is an important check as many intermediate matrices do not contain any sparsity information + // Thus, we want to use cost functions without sparsity information if possible + public static boolean doesHaveAnImpactOnOptimalExpression(List, Long, Long>> list, boolean sparsity, boolean sort, int costThreshhold) { + if (sort) + sort(list); + + int diff = 0; + Tuple3, Long, Long> last = null; + + for (Tuple3, Long, Long> t : list) { + if (Math.abs(t._2() - t._3()) < costThreshhold) + continue; + + if (last == null || (sparsity && !hasSameDims(last._1(), t._1()))) { + last = t; + diff = Long.signum(t._2() - t._3()); + continue; + } + + int mDiff = Long.signum(t._2() - t._3()); + + if (diff != mDiff && Math.abs(t._2() - t._3() - last._2() + last._3()) > costThreshhold) + return true; + } + + return false; + } + + private static boolean hasSameDims(List l1, List l2) { + int maxN = Math.min(l1.size(), l2.size()); + + for (int i = 0; i < maxN; i++) { + Number el1 = l1.get(i); + Number el2 = l2.get(i); + + if (el1 instanceof Long && el1.longValue() != el2.longValue()) + return false; + } + + return true; + } + + private static void sort(List, Long, Long>> list) { + list.sort((t1, t2) -> { + int size = Math.min(t1._1().size(), t2._1().size()); + for (int i = 0; i < size; i++) { + int cmp = Double.compare(t1._1().get(i).doubleValue(), t2._1().get(i).doubleValue()); + if (cmp != 0) + return cmp; // Return non-zero comparison result if elements differ + } + + return Integer.compare(t1._1().size(), t2._1().size()); + }); + } + + public static Set> findOptima(List, List>> data) { + Set> outSet = new HashSet<>(); + data.stream().forEach(t -> { + int minIdx = -1; + long minValue = Long.MAX_VALUE; + for (int i = 0; i < t._2.size(); i++) { + if (t._2.get(i) < minValue) { + minValue = t._2.get(i); + minIdx = i; + } + } + + for (int i = 0; i < t._2.size(); i++) { + if (t._2.get(i) > minValue) + outSet.add(new Tuple2<>(i, minIdx)); + } + }); + + return outSet; + } + + public static List, List>> compareCosts(List statements, RewriterAssertions jointAssertions, final RuleContext ctx, boolean sample, int sampleSize) { + List> estimates = statements.stream().map(stmt -> RewriterSparsityEstimator.estimateAllNNZ(stmt, ctx)).collect(Collectors.toList()); + + MutableObject assertionRef = new MutableObject<>(jointAssertions); + List costFns = statements.stream().map(stmt -> getRawCostFunction(stmt, ctx, assertionRef, false)).collect(Collectors.toList()); + + for (int i = 0; i < estimates.size(); i++) { + costFns.set(i, RewriterSparsityEstimator.rollupSparsities(costFns.get(i), estimates.get(i), ctx)); + } + + long[] dimVals = new long[] {10, 5000}; + double[] sparsities = new double[] {1.0D, 0.000001D}; + + Map createdObjects = new HashMap<>(); + List costFnCpys = costFns.stream().map(fn -> fn.nestedCopy(false, createdObjects)).collect(Collectors.toList()); + RewriterAssertions jointAssertionsCpy = RewriterAssertions.copy(jointAssertions, createdObjects, false); + + Set dimsToPopulate = new HashSet<>(); + Set nnzsToPopulate = new HashSet<>(); + + List costs = costFnCpys.stream().map(costFnCpy -> { + try { + return computeCostFunction(costFnCpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + } catch (Exception e) { + //e.printStackTrace(); + System.err.println("Error while estimating the cost: " + e.getMessage()); + return null; + } + }).collect(Collectors.toList()); + + int nDimsToPopulate = dimsToPopulate.size(); + int nNNZsToPopulate = nnzsToPopulate.size(); + + List firstList = new ArrayList<>(); + for (int i = 0; i < nDimsToPopulate; i++) + firstList.add(2000L); + for (int i = 0; i < nNNZsToPopulate; i++) + firstList.add(1.0D); + + List, List>> out = new ArrayList<>(); + out.add(new Tuple2<>(firstList, costs)); + + if (sampleSize < 2) + return out; + + List> nums = new ArrayList<>(); + List dimList = Arrays.stream(dimVals).mapToObj(dim -> ((Number)dim)).collect(Collectors.toList()); + List sparsityList = Arrays.stream(sparsities).mapToObj(s -> ((Number)s)).collect(Collectors.toList()); + + int numCombinations = 1; + + for (int i = 0; i < nDimsToPopulate; i++) { + nums.add(dimList); + numCombinations *= dimList.size(); + } + + for (int i = 0; i < nNNZsToPopulate; i++) { + nums.add(sparsityList); + numCombinations *= sparsityList.size(); + } + + Set samples = new HashSet<>(); + + if (sample) { + if (sampleSize < numCombinations) { + Random rd = new Random(); + + while (samples.size() < sampleSize) + samples.add(rd.nextInt(numCombinations)); + } else { + sample = false; + } + } + + final boolean doSample = sample; + + MutableInt ctr = new MutableInt(); + + if (nums.size() > 16) { + System.err.println("Could not properly sample: " + statements); + return out; + } + + RewriterUtils.cartesianProduct(nums, new Number[nums.size()], stack -> { + if (doSample && !samples.contains(ctr.getAndIncrement())) + return true; + + int sparsityStart = 0; + + for (Number num : stack) { + if (num instanceof Double) + break; + + sparsityStart++; + } + + final int fSparsityStart = sparsityStart; + + Map replace = new HashMap<>(); + + MutableInt dimCtr = new MutableInt(); + MutableInt sCtr = new MutableInt(); + + Map mCreatedObjects = new HashMap<>(); + List mCostFnCpys = costFns.stream().map(cpy -> cpy.nestedCopy(false, mCreatedObjects)).collect(Collectors.toList()); + RewriterAssertions mAssertionsCpy = RewriterAssertions.copy(jointAssertions, mCreatedObjects, false); + + List mCosts = mCostFnCpys.stream().map(mCpy -> { + try { + return computeCostFunction(mCpy, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + //System.out.println("populated size with: " + literal); + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long) Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + } catch (Exception e) { + e.printStackTrace(); + return null; + } + }).collect(Collectors.toList()); + + out.add(new Tuple2<>(new ArrayList<>(Arrays.asList(stack)), mCosts)); + + return true; + }); + + return out; + } + + // Computes the cost of an expression using different matrix dimensions and sparsities + public static List, Long, Long>> compareCosts(RewriterStatement stmt1, RewriterStatement stmt2, RewriterAssertions jointAssertions, final RuleContext ctx, boolean sample, int sampleSize, boolean returnOnDifference) { + Map estimates1 = RewriterSparsityEstimator.estimateAllNNZ(stmt1, ctx); + Map estimates2 = RewriterSparsityEstimator.estimateAllNNZ(stmt2, ctx); + + MutableObject assertionRef = new MutableObject<>(jointAssertions); + RewriterStatement costFn1 = getRawCostFunction(stmt1, ctx, assertionRef, false); + RewriterStatement costFn2 = getRawCostFunction(stmt2, ctx, assertionRef, false); + + costFn1 = RewriterSparsityEstimator.rollupSparsities(costFn1, estimates1, ctx); + costFn2 = RewriterSparsityEstimator.rollupSparsities(costFn2, estimates2, ctx); + + final RewriterStatement fCostFn1 = costFn1; + final RewriterStatement fCostFn2 = costFn2; + + long[] dimVals = new long[] {10, 5000}; + double[] sparsities = new double[] {1.0D, 0.05D}; + + Map createdObjects = new HashMap<>(); + RewriterStatement costFn1Cpy = costFn1.nestedCopy(true, createdObjects); + RewriterStatement costFn2Cpy = costFn2.nestedCopy(false, createdObjects); + RewriterAssertions jointAssertionsCpy = RewriterAssertions.copy(jointAssertions, createdObjects, false); + + Set dimsToPopulate = new HashSet<>(); + Set nnzsToPopulate = new HashSet<>(); + + long cost1 = computeCostFunction(costFn1Cpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + long cost2 = computeCostFunction(costFn2Cpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + + int nDimsToPopulate = dimsToPopulate.size(); + int nNNZsToPopulate = nnzsToPopulate.size(); + + List firstList = new ArrayList<>(); + for (int i = 0; i < nDimsToPopulate; i++) + firstList.add(2000L); + for (int i = 0; i < nNNZsToPopulate; i++) + firstList.add(1.0D); + + List, Long, Long>> out = new ArrayList<>(); + out.add(new Tuple3<>(firstList, cost1, cost2)); + + if (returnOnDifference && cost1 != cost2) + return out; + + List> nums = new ArrayList<>(); + List dimList = Arrays.stream(dimVals).mapToObj(dim -> ((Number)dim)).collect(Collectors.toList()); + List sparsityList = Arrays.stream(sparsities).mapToObj(s -> ((Number)s)).collect(Collectors.toList()); + + int numCombinations = 1; + + for (int i = 0; i < nDimsToPopulate; i++) { + nums.add(dimList); + numCombinations *= dimList.size(); + } + + for (int i = 0; i < nNNZsToPopulate; i++) { + nums.add(sparsityList); + numCombinations *= sparsityList.size(); + } + + Set samples = new HashSet<>(); + + if (sample) { + if (sampleSize < numCombinations) { + Random rd = new Random(); + + while (samples.size() < sampleSize) + samples.add(rd.nextInt(numCombinations)); + } else { + sample = false; + } + } + + final boolean doSample = sample; + + MutableInt ctr = new MutableInt(); + + RewriterUtils.cartesianProduct(nums, new Number[nums.size()], stack -> { + if (doSample && !samples.contains(ctr.getAndIncrement())) + return true; + + int sparsityStart = 0; + + for (Number num : stack) { + if (num instanceof Double) + break; + + sparsityStart++; + } + + final int fSparsityStart = sparsityStart; + + Map replace = new HashMap<>(); + + MutableInt dimCtr = new MutableInt(); + MutableInt sCtr = new MutableInt(); + + Map mCreatedObjects = new HashMap<>(); + RewriterStatement mCpy1 = fCostFn1.nestedCopy(false, mCreatedObjects); + RewriterStatement mCpy2 = fCostFn2.nestedCopy(false, mCreatedObjects); + RewriterAssertions mAssertionsCpy = RewriterAssertions.copy(jointAssertions, mCreatedObjects, false); + + long mCost1 = computeCostFunction(mCpy1, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long)Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + long mCost2 = computeCostFunction(mCpy2, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long)Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + + out.add(new Tuple3<>(new ArrayList<>(Arrays.asList(stack)), mCost1, mCost2)); + + return !returnOnDifference || mCost1 == mCost2; + }); + + return out; + } + + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterRule rule, final RuleContext ctx) { + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx, assertionRef); + long maxCost = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + return RewriterCostEstimator.determineSingleReferenceRequirement(rule.getStmt2(), RewriterCostEstimator.DEFAULT_COST_FN, RewriterCostEstimator.DEFAULT_NNZ_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + } + + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterStatement root, Function costFn, RewriterAssertions assertions, long fullCost, long maxCost, final RuleContext ctx) { + return determineSingleReferenceRequirement(root, costFn, RewriterCostEstimator.DEFAULT_NNZ_FN, assertions, fullCost, maxCost, ctx); + } + + // Returns all (upmost) sub-DAGs that can have multiple references and true as a second arg if all statements can have multiple references at once + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterStatement root, Function costFn, BiFunction, Long> nnzFn, RewriterAssertions assertions, long fullCost, long maxCost, final RuleContext ctx) { + if (fullCost >= maxCost) + return new Tuple2<>(Collections.emptySet(), true); + + List> subDAGCosts = new ArrayList<>(); + + root.forEachPreOrder((cur, pred) -> { + if (pred.isRoot() || !cur.isInstruction()) + return true; + + long cost = estimateCost(cur, costFn, nnzFn, ctx, new MutableObject<>(assertions)); + + if (fullCost + cost <= maxCost) { + subDAGCosts.add(new Tuple2<>(cur, cost)); + return false; + } + + return true; + }, true); + + boolean canCombine = true; + long curCost = fullCost; + + for (Tuple2 t : subDAGCosts) { + curCost += t._2; + + if (curCost > maxCost) { + canCombine = false; + break; + } + } + + return new Tuple2<>(subDAGCosts.stream().map(t -> t._1).collect(Collectors.toSet()), canCombine); + } + + public static long estimateCost(RewriterStatement stmt, final RuleContext ctx) { + return estimateCost(stmt, DEFAULT_COST_FN, ctx); + } + + public static long estimateCost(RewriterStatement stmt, final RuleContext ctx, MutableObject assertionRef) { + return estimateCost(stmt, DEFAULT_COST_FN, DEFAULT_NNZ_FN, ctx, assertionRef); + } + + public static long estimateCost(RewriterStatement stmt, Function propertyGenerator, final RuleContext ctx) { + return estimateCost(stmt, propertyGenerator, DEFAULT_NNZ_FN, ctx, null); + } + + public static long estimateCost(RewriterStatement stmt, Function propertyGenerator, BiFunction, Long> nnzGenerator, final RuleContext ctx, MutableObject assertionRef) { + if (assertionRef == null) + assertionRef = new MutableObject<>(); + + RewriterStatement costFn = getRawCostFunction(stmt, ctx, assertionRef, false); + return computeCostFunction(costFn, propertyGenerator, nnzGenerator, assertionRef.getValue(), ctx); + } + + public static RewriterStatement getRawCostFunction(RewriterStatement stmt, final RuleContext ctx, MutableObject assertionRef, boolean treatAsDense) { + RewriterAssertions assertions = assertionRef != null && assertionRef.getValue() != null ? assertionRef.getValue() : new RewriterAssertions(ctx); + + if (assertionRef != null) + assertionRef.setValue(assertions); + + RewriterStatement costFn = propagateCostFunction(stmt, ctx, assertions, treatAsDense); + Map estimations = RewriterSparsityEstimator.estimateAllNNZ(costFn, ctx); + RewriterSparsityEstimator.rollupSparsities(costFn, estimations, ctx); + costFn = assertions.update(costFn); + costFn = RewriterUtils.foldConstants(costFn, ctx); + + return costFn; + } + + public static long computeCostFunction(RewriterStatement costFn, Function propertyGenerator, BiFunction, Long> nnzGenerator, RewriterAssertions assertions, final RuleContext ctx) { + Map map = new HashMap<>(); + + costFn.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement op = cur.getChild(i); + + RewriterStatement mNew = map.get(op); + if (mNew != null) { + cur.getOperands().set(i, mNew); + continue; + } + + if (op.isEClass()) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op); + Optional literal = assertion != null ? assertion.getLiteral() : Optional.empty(); + + mNew = literal.orElseGet(() -> RewriterStatement.literal(ctx, propertyGenerator.apply(op))); + + map.put(op, mNew); + cur.getOperands().set(i, mNew); + } else if (op.isInstruction()) { + if (op.trueInstruction().equals("ncol") || op.trueInstruction().equals("nrow")) { + RewriterStatement eClassStmt = assertions.getAssertionStatement(op, null); + mNew = RewriterStatement.literal(ctx, propertyGenerator.apply(eClassStmt)); + map.put(eClassStmt, mNew); + cur.getOperands().set(i, mNew); + } + } + } + + return true; + }, false); + + costFn.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement op = cur.getChild(i); + + RewriterStatement mNew = map.get(op); + if (mNew != null) { + cur.getOperands().set(i, mNew); + continue; + } + + if (op.isInstruction() && op.trueInstruction().equals("_nnz")) { + RewriterStatement ncolLiteral = map.get(op.getChild(0).getNCol()); + + if (ncolLiteral == null) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op.getChild(0).getNCol()); + + if (assertion != null) { + RewriterStatement assStmt = assertion.getEClassStmt(ctx, assertions); + ncolLiteral = map.get(assStmt); + + if (ncolLiteral == null) { + ncolLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(assStmt)); + map.put(assStmt, ncolLiteral); + } + } else { + ncolLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(op.getChild(0).getNCol())); + map.put(op.getChild(0).getNCol(), ncolLiteral); + } + } + + RewriterStatement nrowLiteral = map.get(op.getChild(0).getNRow()); + + if (nrowLiteral == null) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op.getChild(0).getNRow()); + + if (assertion != null) { + RewriterStatement assStmt = assertion.getEClassStmt(ctx, assertions); + nrowLiteral = map.get(assStmt); + + if (nrowLiteral == null) { + nrowLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(assStmt)); + map.put(assStmt, nrowLiteral); + } + } else { + nrowLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(op.getChild(0).getNRow())); + map.put(op.getChild(0).getNRow(), nrowLiteral); + } + } + + mNew = RewriterStatement.literal(ctx, nnzGenerator.apply(op, new Tuple2<>(nrowLiteral.intLiteral(false), ncolLiteral.intLiteral(false)))); + map.put(op, mNew); + cur.getOperands().set(i, mNew); + } + } + + return true; + }, false); + + costFn.forEachPreOrder(cur -> { + if (cur.isInstruction()) + cur.refreshReturnType(ctx); + + return true; + }, false); + + costFn = RewriterUtils.foldConstants(costFn, ctx); + + if (!costFn.isLiteral()) { + throw new IllegalArgumentException("Cost function must be a literal: " + costFn.toParsableString(ctx)); + } + + if (costFn.getLiteral() instanceof Double) + return (long)((double)costFn.getLiteral()); + + return (long)costFn.getLiteral(); + } + + private static RewriterStatement propagateCostFunction(RewriterStatement stmt, final RuleContext ctx, RewriterAssertions assertions, boolean treatAsDense) { + List includedCosts = new ArrayList<>(); + MutableLong instructionOverhead = new MutableLong(0); + + stmt.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterInstruction)) + return; + + computeCostOf((RewriterInstruction) cur, ctx, includedCosts, assertions, instructionOverhead, treatAsDense, stmt); + instructionOverhead.add(INSTRUCTION_OVERHEAD); + }, false); + + includedCosts.add(RewriterStatement.literal(ctx, instructionOverhead.longValue())); + + RewriterStatement argList = RewriterStatement.argList(ctx, includedCosts); + RewriterStatement add = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(argList).consolidate(ctx); + add.unsafePutMeta("_assertions", assertions); + return add; + } + + private static RewriterStatement computeCostOf(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong instructionOverhead, boolean treatAsDense, RewriterStatement exprRoot) { + if (instr.getResultingDataType(ctx).equals("MATRIX")) + return computeMatrixOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead, treatAsDense, exprRoot); + else + return computeScalarOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead, treatAsDense, exprRoot); + } + + private static RewriterStatement computeMatrixOpCost(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong overhead, boolean treatAsDense, RewriterStatement exprRoot) { + RewriterAssertionUtils.buildImplicitAssertion(instr, assertions, exprRoot, ctx); + + RewriterStatement cost = null; + Map map = new HashMap<>(); + + switch (instr.trueInstruction()) { + case "%*%": + map.put("A", instr.getChild(0)); + map.put("B", instr.getChild(1)); + map.put("nrowA", instr.getChild(0).getNRow()); + map.put("ncolA", instr.getChild(0).getNCol()); + map.put("nrowB", instr.getChild(1).getNRow()); + map.put("ncolB", instr.getChild(1).getNCol()); + map.put("mulCost", atomicOpCostStmt("*", ctx)); + map.put("sumCost", atomicOpCostStmt("+", ctx)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("nnzB", RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense)); + // Rough estimation + cost = RewriterUtils.parse("*(argList(min(nnzA, nnzB), ncolA, +(argList(mulCost, sumCost))))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "t": + case "rev": + cost = RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense);//RewriterUtils.parse("_nnz(A)", ctx, map); + overhead.add(MALLOC_COST); + break; + case "rowSums": + case "colSums": + map.put("A", instr.getChild(0)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + RewriterStatement aoc = atomicOpCostStmt("+", ctx); + map.put("opcost", aoc); + // Rough estimation + cost = RewriterUtils.parse("*(argList(nnzA, opcost))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "diag": + map.put("nrowA", instr.getChild(0).getNRow()); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("A", instr.getChild(0)); + cost = RewriterUtils.parse("min(nnzA, nrowA)", ctx, map); + overhead.add(MALLOC_COST); + break; + case "cast.MATRIX": + cost = RewriterStatement.literal(ctx, 20L); + break; + case "[]": + cost = RewriterStatement.literal(ctx, 0L); + break; // I assume that nothing is materialized + case "RBind": + case "CBind": + map.put("A", instr.getChild(0)); + map.put("B", instr.getChild(1)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("nnzB", RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense)); + cost = RewriterUtils.parse("+(argList(nnzA, nnzB))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "rand": + cost = RewriterStatement.nnz(instr, ctx, treatAsDense); + overhead.add(MALLOC_COST); + break; + case "1-*": + RewriterStatement subtractionCost = atomicOpCostStmt("-", ctx); + RewriterStatement mulCost = atomicOpCostStmt("*", ctx); + RewriterStatement sparsityAwareMul = RewriterStatement.multiArgInstr(ctx, "*", mulCost, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense))); + RewriterStatement oneMinus = RewriterStatement.multiArgInstr(ctx, "*", subtractionCost, instr.getNCol(), instr.getNRow()); + cost = RewriterStatement.multiArgInstr(ctx, "+", oneMinus, sparsityAwareMul); + overhead.add(MALLOC_COST); + break; + case "+*": + RewriterStatement additionCost = atomicOpCostStmt("+", ctx); + mulCost = atomicOpCostStmt("*", ctx); + RewriterStatement sum = RewriterStatement.multiArgInstr(ctx, "+", additionCost, mulCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(2), ctx, treatAsDense))); + overhead.add(MALLOC_COST + 50); // To make it worse than 1-* + break; + case "-*": + subtractionCost = atomicOpCostStmt("-", ctx); + mulCost = atomicOpCostStmt("*", ctx); + sum = RewriterStatement.multiArgInstr(ctx, "+", subtractionCost, mulCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(2), ctx, treatAsDense))); + overhead.add(MALLOC_COST + 50); // To make it worse than 1-* + break; + case "*2": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("*2", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "sq": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("sq", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "sqrt": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("sqrt", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "exp": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("exp", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "log_nz": { + // Must be a matrix + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + RewriterStatement neqCost = atomicOpCostStmt("!=", ctx); + sum = RewriterStatement.multiArgInstr(ctx, "+", neqCost, instr.getOperands().size() == 2 ? twoLogCost : logCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + } + case "log": + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", instr.getOperands().size() == 2 ? twoLogCost : logCost, instr.getNCol(), instr.getNRow()); + overhead.add(MALLOC_COST); + } else { + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + cost = instr.getOperands().size() == 2 ? twoLogCost : logCost; + } + break; + case "const": + case "rowVec": + case "colVec": + case "cellMat": + cost = RewriterStatement.literal(ctx, 0L); + break; + } + + if (cost == null) { + if (instr.hasProperty("ElementWiseInstruction", ctx)) { + RewriterStatement firstMatrix = null; + RewriterStatement secondMatrix = null; + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + firstMatrix = instr.getChild(0); + } + + if (instr.getChild(1).getResultingDataType(ctx).equals("MATRIX")) { + if (firstMatrix == null) + firstMatrix = instr.getChild(1); + else + secondMatrix = instr.getChild(1); + } + + RewriterStatement opCost = atomicOpCostStmt(instr.trueInstruction(), ctx); + + if (firstMatrix != null) { + switch (instr.trueInstruction()) { + case "*": + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, secondMatrix != null ? StatementUtils.min(ctx, RewriterStatement.nnz(firstMatrix, ctx, treatAsDense), RewriterStatement.nnz(secondMatrix, ctx, treatAsDense)) : RewriterStatement.nnz(firstMatrix, ctx, treatAsDense))); + break; + case "/": + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense))); + else + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, StatementUtils.length(ctx, firstMatrix))); + + break; + case "+": + case "-": + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, secondMatrix != null ? StatementUtils.add(ctx, RewriterStatement.nnz(firstMatrix, ctx, treatAsDense), RewriterStatement.nnz(secondMatrix, ctx, treatAsDense)) : RewriterStatement.nnz(firstMatrix, ctx, treatAsDense))); + break; + default: + cost = RewriterStatement.multiArgInstr(ctx, "*", opCost, instr.getNRow(), instr.getNCol()); + break; + } + + overhead.add(MALLOC_COST); + } else { + cost = opCost; + } + } else if (instr.hasProperty("UnaryElementWiseOperator", ctx)) { + RewriterStatement opCost = atomicOpCostStmt(instr.trueInstruction(), ctx); + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense))); + overhead.add(MALLOC_COST); + } else { + throw new IllegalArgumentException("Unknown instruction: " + instr.trueTypedInstruction(ctx)); + } + } + + uniqueCosts.add(cost); + return cost; + } + + private static RewriterStatement computeScalarOpCost(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong overhead, boolean treatAsDense, RewriterStatement exprRoot) { + RewriterAssertionUtils.buildImplicitAssertion(instr, assertions, exprRoot, ctx); + Map map = new HashMap<>(); + switch (instr.trueTypedInstruction(ctx)) { + case "sum(MATRIX)": + case "min(MATRIX)": + case "max(MATRIX)": + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + uniqueCosts.add(RewriterUtils.parse("nnzA", ctx, map)); + return uniqueCosts.get(uniqueCosts.size()-1); + case "sumSq(MATRIX)": + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + uniqueCosts.add(RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.literal(ctx, 2L))); + return uniqueCosts.get(uniqueCosts.size()-1); + case "trace(MATRIX)": + uniqueCosts.add(StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), instr.getChild(0).getNRow())); + return uniqueCosts.get(uniqueCosts.size()-1); + case "[](MATRIX,INT,INT)": + return RewriterStatement.literal(ctx, 0L); + case "cast.FLOAT(MATRIX)": + return RewriterStatement.literal(ctx, INSTRUCTION_OVERHEAD); + case "const(MATRIX,FLOAT)": + case "_nnz(MATRIX)": + return RewriterStatement.literal(ctx, 0L); + } + + double opCost = atomicOpCost(instr.trueInstruction()); + uniqueCosts.add(RewriterUtils.parse(Double.toString(opCost), ctx, "LITERAL_FLOAT:" + opCost)); + return uniqueCosts.get(uniqueCosts.size()-1); + } + + private static RewriterStatement atomicOpCostStmt(String op, final RuleContext ctx) { + double opCost = atomicOpCost(op); + return RewriterUtils.parse(Double.toString(opCost), ctx, "LITERAL_FLOAT:" + opCost); + } + + private static double atomicOpCost(String op) { + switch (op) { + case "+": + case "-": + return 1; + case "*": + return 2; + case "*2": + return 1; // To make *2 cheaper than A+A + case "/": + case "inv": + return 3; + case "length": + case "nrow": + case "ncol": + case "_nnz": + return 0; // These just fetch metadata + case "sqrt": + return 10; + case "sq": + return 1.8; // To make it cheaper than *(A,A) + case "exp": + case "log": + case "^": + return 20; + case "!": + case "|": + case "&": + case ">": + case ">=": + case "<": + case "<=": + case "==": + case "!=": + return 1; + case "round": + return 2; + case "abs": + return 2; + case "cast.FLOAT": + return 1; + case "literal.FLOAT": + case "literal.INT": + case "literal.BOOL": + return 0; + } + + throw new IllegalArgumentException("Unknown instruction: " + op); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java new file mode 100644 index 00000000000..22de98abcb2 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.estimators; + +import org.apache.sysds.hops.rewriter.utils.ConstantFoldingUtils; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; + +import java.util.HashMap; +import java.util.Map; + +public class RewriterSparsityEstimator { + public static RewriterStatement rollupSparsities(RewriterStatement sparsityEstimate, Map sparsityMap, final RuleContext ctx) { + sparsityEstimate.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child.isInstruction() && child.trueInstruction().equals("_nnz")) { + RewriterStatement subEstimate = sparsityMap.get(child.getChild(0)); + + if (subEstimate != null) { + cur.getOperands().set(i, subEstimate); + } + } + } + return true; + }, false); + + return sparsityEstimate; + } + + public static Map estimateAllNNZ(RewriterStatement stmt, final RuleContext ctx) { + Map map = new HashMap<>(); + stmt.forEachPostOrder((cur, pred) -> { + RewriterStatement estimation = estimateNNZ(cur, ctx); + if (estimation != null) + map.put(cur, estimation); + }, false); + + return map; + } + + public static RewriterStatement estimateNNZ(RewriterStatement stmt, final RuleContext ctx) { + if (!stmt.isInstruction() || !stmt.getResultingDataType(ctx).equals("MATRIX")) + return null; + switch (stmt.trueInstruction()) { + case "%*%": + RewriterStatement min1 = StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(stmt.getChild(0), ctx), new RewriterInstruction("inv", ctx, stmt.getChild(0).getNRow())), RewriterStatement.literal(ctx, 1.0D)); + RewriterStatement min2 = StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(stmt.getChild(1), ctx), new RewriterInstruction("inv", ctx, stmt.getChild(1).getNCol())), RewriterStatement.literal(ctx, 1.0D)); + return RewriterStatement.multiArgInstr(ctx, "*", min1, min2, stmt.getNRow(), stmt.getNCol()); + } + + switch (stmt.trueTypedInstruction(ctx)) { + case "*(MATRIX,MATRIX)": + return StatementUtils.min(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)); + case "*(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.overwritesLiteral(((Double) stmt.getChild(1).getLiteral()), "*", ctx) != null) + return RewriterStatement.literal(ctx, 0L); + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "*(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.overwritesLiteral(((Double) stmt.getChild(0).getLiteral()), "*", ctx) != null) + return RewriterStatement.literal(ctx, 0L); + return RewriterStatement.nnz(stmt.getChild(1), ctx); + case "+(MATRIX,MATRIX)": + case "-(MATRIX,MATRIX)": + return StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)), StatementUtils.length(ctx, stmt)); + case "+(MATRIX,FLOAT)": + case "-(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(0), ctx); + return StatementUtils.length(ctx, stmt); + case "+(FLOAT,MATRIX)": + case "-(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(0).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(1), ctx); + return StatementUtils.length(ctx, stmt); + case "!=(MATRIX,MATRIX)": + if (stmt.getChild(0).equals(stmt.getChild(1))) + return RewriterStatement.literal(ctx, 0L); + return StatementUtils.length(ctx, stmt); + + case "sqrt(MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + + case "diag(MATRIX)": + return StatementUtils.min(ctx, stmt.getNRow(), RewriterStatement.nnz(stmt.getChild(0), ctx)); + + case "/(MATRIX,FLOAT)": + case "/(MATRIX,MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "/(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(0).getLiteral(), "+")) + return RewriterStatement.literal(ctx, 0L); + return StatementUtils.length(ctx, stmt); + + case "RBind(MATRIX,MATRIX)": + case "CBind(MATRIX,MATRIX)": + return StatementUtils.add(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)); + + // Fused operators + case "log_nz(MATRIX)": + case "*2(MATRIX)": + case "sq(MATRIX)": + case "t(MATRIX)": + case "rev(MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "1-*(MATRIX,MATRIX)": + return StatementUtils.length(ctx, stmt); + case "+*(MATRIX,FLOAT,MATRIX)": + case "-*(MATRIX,FLOAT,MATRIX)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(0), ctx); + return StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(2), ctx)), StatementUtils.length(ctx, stmt)); + case "const(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.literal(ctx, 0L); + case "rowSums(MATRIX)": + case "colSums(MATRIX)": + StatementUtils.min(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), StatementUtils.length(ctx, stmt)); + } + + return StatementUtils.length(ctx, stmt); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java new file mode 100644 index 00000000000..09aea53048d --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java @@ -0,0 +1,4104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.generated; + +import java.util.ArrayList; +import java.util.function.Function; + +import org.apache.sysds.utils.Statistics; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.TernaryOp; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; + +public class GeneratedRewriteClass implements Function { + + @Override + public Object apply( Object _hi ) { + if ( _hi == null ) + return null; + + Hop hi = (Hop) _hi; + + if ( hi.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite0(hi); // *(0.0,a) => 0.0 + hi = _applyRewrite1(hi); // *(a,0.0) => 0.0 + hi = _applyRewrite23(hi); // sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732) + hi = _applyRewrite27(hi); // sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790))) + } else if ( hi.getDataType() == Types.DataType.MATRIX ) { + if ( hi instanceof BinaryOp ) { + if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.PLUS ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MINUS ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite7(hi); // +(-(0.0,A),B) => -(B,A) + hi = _applyRewrite10(hi); // +(-(A,a),b) => +(A,-(b,a)) + hi = _applyRewrite12(hi); // +(-(a,A),b) => -(+(a,b),A) + hi = _applyRewrite20(hi); // +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880) + hi = _applyRewrite31(hi); // +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a) + hi = _applyRewrite37(hi); // +(-(*(C,b),d),A) => -(+*(A,b,C),d) + hi = _applyRewrite38(hi); // +(-(*(D,c),B),A) => -(A,-*(B,c,D)) + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite41(hi); // +(-(f45081,A),B) => +(f45081,-(B,A)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite46(hi); // +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D))) + hi = _applyRewrite54(hi); // +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d) + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MULT ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite18(hi); // +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + hi = _applyRewrite32(hi); // +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606)) + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite43(hi); // +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316)) + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite3(hi); // +(0.0,A) => A + hi = _applyRewrite11(hi); // +(a,-(A,b)) => +(A,-(a,b)) + hi = _applyRewrite13(hi); // +(a,-(b,A)) => -(+(a,b),A) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.MINUS ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MINUS ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite14(hi); // -(-(A,a),b) => -(A,+(b,a)) + hi = _applyRewrite16(hi); // -(-(a,A),b) => -(-(a,b),A) + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite30(hi); // -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113)) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite47(hi); // -(-(f43240,A),f67634) => -(-(f43240,f67634),A) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + hi = _applyRewrite52(hi); // -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233)) + hi = _applyRewrite53(hi); // -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233)) + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.PLUS ) { + hi = _applyRewrite28(hi); // -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a) + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite8(hi); // -(0.0,-(B,A)) => -(A,B) + hi = _applyRewrite15(hi); // -(a,-(A,b)) => -(+(a,b),A) + hi = _applyRewrite17(hi); // -(a,-(b,A)) => +(-(a,b),A) + hi = _applyRewrite21(hi); // -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996)) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.MULT ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.DIV ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite19(hi); // *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995) + hi = _applyRewrite34(hi); // *(/(1.0,B),a) => /(a,B) + hi = _applyRewrite44(hi); // *(/(1.0,M13119),A) => /(A,M13119) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MULT ) { + hi = _applyRewrite25(hi); // *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( hi_0 instanceof AggBinaryOp ) { + hi = _applyRewrite26(hi); // *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite6(hi); // *(0.0,A) => const(A,0.0) + hi = _applyRewrite33(hi); // *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + hi = _applyRewrite36(hi); // *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b)) + hi = _applyRewrite50(hi); // *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833)) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.DIV ) { + hi = _applyRewrite35(hi); // /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b)) + hi = _applyRewrite45(hi); // /(M43656,2.0) => *(0.5,M43656) + hi = _applyRewrite48(hi); // /(M62235,2000.0) => *(5.0E-4,M62235) + } + } else if ( hi instanceof ReorgOp ) { + hi = _applyRewrite22(hi); // t(==(key_unique,t(key))) => ==(key,t(key_unique)) + } else if ( hi instanceof AggBinaryOp ) { + hi = _applyRewrite24(hi); // %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))} + } + } + return hi; + } + + // Implementation of the rule *(0.0,a) => 0.0 + private static Hop _applyRewrite0(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: 0.0 + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(0.0,a) => 0.0"); + return newRoot; + } + + // Implementation of the rule *(a,0.0) => 0.0 + private static Hop _applyRewrite1(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: 0.0 + + Hop newRoot = hi_1; + if ( hi_1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: *(a,0.0) => 0.0"); + return newRoot; + } + + // Implementation of the rule +(A,0.0) => A + private static Hop _applyRewrite2(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(A,0.0) => A"); + return newRoot; + } + + // Implementation of the rule +(0.0,A) => A + private static Hop _applyRewrite3(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_1; + if ( hi_1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(0.0,A) => A"); + return newRoot; + } + + // Implementation of the rule -(A,0.0) => A + private static Hop _applyRewrite4(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(A,0.0) => A"); + return newRoot; + } + + // Implementation of the rule *(A,0.0) => const(A,0.0) + private static Hop _applyRewrite5(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: const(A,0.0) + DataGenOp v1 = ((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(HopRewriteUtils.createUnary(hi_0, Types.OpOp1.NROW),HopRewriteUtils.createUnary(hi_0, Types.OpOp1.NCOL),0.0D)); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + + DMLExecutor.println("Applying rewrite: *(A,0.0) => const(A,0.0)"); + return newRoot; + } + + // Implementation of the rule *(0.0,A) => const(A,0.0) + private static Hop _applyRewrite6(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: const(A,0.0) + DataGenOp v1 = ((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(HopRewriteUtils.createUnary(hi_1, Types.OpOp1.NROW),HopRewriteUtils.createUnary(hi_1, Types.OpOp1.NCOL),0.0D)); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + + DMLExecutor.println("Applying rewrite: *(0.0,A) => const(A,0.0)"); + return newRoot; + } + + // Implementation of the rule +(-(0.0,A),B) => -(B,A) + private static Hop _applyRewrite7(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(B,A) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(0.0,A),B) => -(B,A)"); + return newRoot; + } + + // Implementation of the rule -(0.0,-(B,A)) => -(A,B) + private static Hop _applyRewrite8(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,B) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0, Types.OpOp2.MINUS); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(0.0,-(B,A)) => -(A,B)"); + return newRoot; + } + + // Implementation of the rule *(A,/(1.0,B)) => /(A,B) + private static Hop _applyRewrite9(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.DIV || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,B) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: *(A,/(1.0,B)) => /(A,B)"); + return newRoot; + } + + // Implementation of the rule +(-(A,a),b) => +(A,-(b,a)) + private static Hop _applyRewrite10(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(A,-(b,a)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(A,a),b) => +(A,-(b,a))"); + return newRoot; + } + + // Implementation of the rule +(a,-(A,b)) => +(A,-(a,b)) + private static Hop _applyRewrite11(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(A,-(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(a,-(A,b)) => +(A,-(a,b))"); + return newRoot; + } + + // Implementation of the rule +(-(a,A),b) => -(+(a,b),A) + private static Hop _applyRewrite12(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(a,A),b) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule +(a,-(b,A)) => -(+(a,b),A) + private static Hop _applyRewrite13(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(a,-(b,A)) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(-(A,a),b) => -(A,+(b,a)) + private static Hop _applyRewrite14(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,+(b,a)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(A,a),b) => -(A,+(b,a))"); + return newRoot; + } + + // Implementation of the rule -(a,-(A,b)) => -(+(a,b),A) + private static Hop _applyRewrite15(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_0, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(a,-(A,b)) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(-(a,A),b) => -(-(a,b),A) + private static Hop _applyRewrite16(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(a,A),b) => -(-(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(a,-(b,A)) => +(-(a,b),A) + private static Hop _applyRewrite17(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(a,-(b,A)) => +(-(a,b),A)"); + return newRoot; + } + + // Implementation of the rule +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + private static Hop _applyRewrite18(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if (hi_0_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0_1 = (BinaryOp) hi_0_0_1; + + if ( c_hi_0_0_1.getOp() != Types.OpOp2.MINUS || !c_hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1_0 = hi_0_0_1.getInput(0); + + if ( hi_0_0_1_0.getDataType() != Types.DataType.SCALAR || !hi_0_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1_1 = hi_0_0_1.getInput(1); + + if ( hi_0_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if (hi_1_1.getParent().size() > 1) + return hi; + if ( !(hi_1_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_1 = (BinaryOp) hi_1_1; + + if ( c_hi_1_1.getOp() != Types.OpOp2.PLUS || !c_hi_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_0 = hi_1_1.getInput(0); + + if ( hi_1_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_1 = hi_1_1.getInput(1); + + if ( hi_1_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1_0, hi_0_0_1_1, Types.OpOp2.MINUS); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + return hi; + } + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1_1, hi_1_1_0) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + HopRewriteUtils.removeAllChildReferences(v3); + return hi; + } + BinaryOp v4 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1_1, hi_1_1_0, Types.OpOp2.PLUS); + BinaryOp v5 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, v4, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v3, v5) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + HopRewriteUtils.removeAllChildReferences(v3); + HopRewriteUtils.removeAllChildReferences(v4); + HopRewriteUtils.removeAllChildReferences(v5); + return hi; + } + BinaryOp v6 = HopRewriteUtils.createAutoGeneratedBinary(v3, v5, Types.OpOp2.PLUS); + + Hop newRoot = v6; + if ( v6.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0_1); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_1); + + DMLExecutor.println("Applying rewrite: +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071)))"); + return newRoot; + } + + // Implementation of the rule *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995) + private static Hop _applyRewrite19(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(tmp41945,tmp5995) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995)"); + return newRoot; + } + + // Implementation of the rule +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880) + private static Hop _applyRewrite20(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(tmp63699,tmp80035),f12880) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880)"); + return newRoot; + } + + // Implementation of the rule -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996)) + private static Hop _applyRewrite21(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(-(tmp66496,tmp91996)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996))"); + return newRoot; + } + + // Implementation of the rule t(==(key_unique,t(key))) => ==(key,t(key_unique)) + private static Hop _applyRewrite22(Hop hi) { + if ( !(hi instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi = (ReorgOp) hi; + + if ( c_hi.getOp() != Types.ReOrgOp.TRANS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.EQUAL || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi_0_1 = (ReorgOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.ReOrgOp.TRANS || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: ==(key,t(key_unique)) + ReorgOp v1 = HopRewriteUtils.createTranspose(hi_0_0); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, v1, Types.OpOp2.EQUAL); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: t(==(key_unique,t(key))) => ==(key,t(key_unique))"); + return newRoot; + } + + // Implementation of the rule sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732) + private static Hop _applyRewrite23(Hop hi) { + if ( !(hi instanceof AggUnaryOp) ) + return hi; + + AggUnaryOp c_hi = (AggUnaryOp) hi; + + if ( c_hi.getOp() != Types.AggOp.SUM || !c_hi.getValueType().isNumeric() ) + return hi; + + if ( !(c_hi.getDirection() == Types.Direction.RowCol) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(sum(tmp83271),tmp60732) + AggUnaryOp v1 = HopRewriteUtils.createAggUnaryOp(hi_0_0, Types.AggOp.SUM, Types.Direction.RowCol); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732)"); + return newRoot; + } + + // Implementation of the rule %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))} + private static Hop _applyRewrite24(Hop hi) { + if ( !HopRewriteUtils.isMatrixMultiply(hi) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi_0 = (ReorgOp) hi_0; + + if ( c_hi_0.getOp() != Types.ReOrgOp.TRANS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_1.getDim2() == -1 || hi_1.getNnz() == -1 || hi_0_0.getNnz() == -1 || hi_0_0.getDim2() == -1 || hi_1.getDim1() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = (hi_0_0.getNnz() + (Math.min(hi_0_0.getNnz(), hi_1.getNnz()) * hi_1.getDim1() * 3.0) + 20020.0); + costs[1] = (hi_1.getNnz() + (Math.min(hi_1.getNnz(), hi_0_0.getNnz()) * hi_1.getDim1() * 3.0) + (Math.min((hi_1.getNnz() * (1.0 / hi_1.getDim2())), 1.0) * Math.min((hi_0_0.getNnz() * (1.0 / hi_0_0.getDim2())), 1.0) * hi_1.getDim2() * hi_0_0.getDim2()) + 30030.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: t(%*%(t(tmp92007),X_batch)) + ReorgOp v1 = HopRewriteUtils.createTranspose(hi_1); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_0_0); + ReorgOp v3 = HopRewriteUtils.createTranspose(v2); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + private static Hop _applyRewrite25(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_1 = (BinaryOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.OpOp2.MINUS || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.SCALAR || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, hi_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MINUS); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + return hi; + } + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr))"); + return newRoot; + } + + // Implementation of the rule *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + private static Hop _applyRewrite26(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_0) ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_0_0.getDim1() == -1 || hi_0_1.getDim2() == -1 || hi_0_1.getNnz() == -1 || hi_0_0.getNnz() == -1 || hi_0_1.getDim1() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = ((Math.min(hi_0_0.getNnz(), hi_0_1.getNnz()) * hi_0_1.getDim1() * 3.0) + (2.0 * (Math.min((hi_0_0.getNnz() * (1.0 / hi_0_0.getDim1())), 1.0) * Math.min((hi_0_1.getNnz() * (1.0 / hi_0_1.getDim2())), 1.0) * hi_0_0.getDim1() * hi_0_1.getDim2())) + 20020.0); + costs[1] = ((2.0 * hi_0_0.getNnz()) + (Math.min(hi_0_0.getNnz(), hi_0_1.getNnz()) * hi_0_1.getDim1() * 3.0) + 20020.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: %*%(*(tmp43267,scale_lambda),parsertemp150455) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_0, Types.OpOp2.MULT); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_0_1); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790))) + private static Hop _applyRewrite27(Hop hi) { + if ( !(hi instanceof AggUnaryOp) ) + return hi; + + AggUnaryOp c_hi = (AggUnaryOp) hi; + + if ( c_hi.getOp() != Types.AggOp.SUM || !c_hi.getValueType().isNumeric() ) + return hi; + + if ( !(c_hi.getDirection() == Types.Direction.RowCol) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(tmp30390,sum(*(tmp97178,tmp8790))) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); + AggUnaryOp v2 = HopRewriteUtils.createAggUnaryOp(v1, Types.AggOp.SUM, Types.Direction.RowCol); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1, v2, Types.OpOp2.MULT); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790)))"); + return newRoot; + } + + // Implementation of the rule -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a) + private static Hop _applyRewrite28(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.PLUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a)"); + return newRoot; + } + + // Implementation of the rule -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + private static Hop _applyRewrite29(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.PLUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(obj,tmp6500),tmp26035) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035)"); + return newRoot; + } + + // Implementation of the rule -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113)) + private static Hop _applyRewrite30(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(tmp68530,+(tmp73960,tmp29113)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113))"); + return newRoot; + } + + // Implementation of the rule +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a) + private static Hop _applyRewrite31(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a)"); + return newRoot; + } + + // Implementation of the rule +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606)) + private static Hop _applyRewrite32(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_0_1 != hi_1_1 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(missing_mask_Y,+(tmp99142,tmp58606)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, v1, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606))"); + return newRoot; + } + + // Implementation of the rule *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + private static Hop _applyRewrite33(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1) ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_1_0.getNnz() == -1 || hi_1_1.getDim2() == -1 || hi_1_0.getDim1() == -1 || hi_1_0.getDim2() == -1 || hi_1_1.getNnz() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = ((Math.min(hi_1_0.getNnz(), hi_1_1.getNnz()) * hi_1_0.getDim2() * 3.0) + (2.0 * (Math.min((hi_1_0.getNnz() * (1.0 / hi_1_0.getDim1())), 1.0) * Math.min((hi_1_1.getNnz() * (1.0 / hi_1_1.getDim2())), 1.0) * hi_1_0.getDim1() * hi_1_1.getDim2())) + 20020.0); + costs[1] = ((2.0 * hi_1_0.getNnz()) + (Math.min(hi_1_0.getNnz(), hi_1_1.getNnz()) * hi_1_0.getDim2() * 3.0) + 20020.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: %*%(*(tmp43267,scale_lambda),parsertemp150455) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MULT); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_1_1); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule *(/(1.0,B),a) => /(a,B) + private static Hop _applyRewrite34(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(a,B) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,B),a) => /(a,B)"); + return newRoot; + } + + // Implementation of the rule /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b)) + private static Hop _applyRewrite35(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(/(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.DIV); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b))"); + return newRoot; + } + + // Implementation of the rule *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b)) + private static Hop _applyRewrite36(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(*(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MULT); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b))"); + return newRoot; + } + + // Implementation of the rule +(-(*(C,b),d),A) => -(+*(A,b,C),d) + private static Hop _applyRewrite37(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+*(A,b,C),d) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, hi_0_0_0) ) { + return hi; + } + TernaryOp v1 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.PLUS_MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(*(C,b),d),A) => -(+*(A,b,C),d)"); + return newRoot; + } + + // Implementation of the rule +(-(*(D,c),B),A) => -(A,-*(B,c,D)) + private static Hop _applyRewrite38(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,-*(B,c,D)) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0_1, hi_0_0_0) ) { + return hi; + } + TernaryOp v1 = HopRewriteUtils.createTernary(hi_0_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.MINUS_MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(*(D,c),B),A) => -(A,-*(B,c,D))"); + return newRoot; + } + + // Implementation of the rule +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite39(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if (hi_1_1.getParent().size() > 1) + return hi; + if ( !(hi_1_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_1 = (BinaryOp) hi_1_1; + + if ( c_hi_1_1.getOp() != Types.OpOp2.MULT || !c_hi_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_0 = hi_1_1.getInput(0); + + if ( hi_1_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_1 = hi_1_1.getInput(1); + + if ( hi_1_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1_0) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1_0, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_1_1, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_1); + + DMLExecutor.println("Applying rewrite: +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + private static Hop _applyRewrite40(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.PLUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1_0) ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.MATRIX || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(y,%*%(X,B)),intercept) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0_0, hi_1_0_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, v1, Types.OpOp2.MINUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept)"); + return newRoot; + } + + // Implementation of the rule +(-(f45081,A),B) => +(f45081,-(B,A)) + private static Hop _applyRewrite41(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(f45081,-(B,A)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(f45081,A),B) => +(f45081,-(B,A))"); + return newRoot; + } + + // Implementation of the rule +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite42(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !(hi_1_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_0 = (BinaryOp) hi_1_0; + + if ( c_hi_1_0.getOp() != Types.OpOp2.MULT || !c_hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.SCALAR || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0_1, hi_1_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0_1, hi_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite43(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0_0, hi_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_0, hi_0_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + TernaryOp v2 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule *(/(1.0,M13119),A) => /(A,M13119) + private static Hop _applyRewrite44(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,M13119) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,M13119),A) => /(A,M13119)"); + return newRoot; + } + + // Implementation of the rule /(M43656,2.0) => *(0.5,M43656) + private static Hop _applyRewrite45(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 2.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(0.5,M43656) + LiteralOp l1 = new LiteralOp( 0.5 ); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(l1, hi_0, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(M43656,2.0) => *(0.5,M43656)"); + return newRoot; + } + + // Implementation of the rule +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D))) + private static Hop _applyRewrite46(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_0_1) ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(b,-(A,%*%(C,D))) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_0_1_0, hi_0_1_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.PLUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D)))"); + return newRoot; + } + + // Implementation of the rule -(-(f43240,A),f67634) => -(-(f43240,f67634),A) + private static Hop _applyRewrite47(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(f43240,f67634),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(f43240,A),f67634) => -(-(f43240,f67634),A)"); + return newRoot; + } + + // Implementation of the rule /(M62235,2000.0) => *(5.0E-4,M62235) + private static Hop _applyRewrite48(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 2000.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(5.0E-4,M62235) + LiteralOp l1 = new LiteralOp( 5.0E-4 ); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(l1, hi_0, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(M62235,2000.0) => *(5.0E-4,M62235)"); + return newRoot; + } + + // Implementation of the rule *(A,/(1.0,M13119)) => /(A,M13119) + private static Hop _applyRewrite49(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.DIV || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,M13119) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: *(A,/(1.0,M13119)) => /(A,M13119)"); + return newRoot; + } + + // Implementation of the rule *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833)) + private static Hop _applyRewrite50(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(M48693,-(0.0,f68833)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, v1, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833))"); + return newRoot; + } + + // Implementation of the rule -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + private static Hop _applyRewrite51(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !(hi_1_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_0 = (BinaryOp) hi_1_0; + + if ( c_hi_1_0.getOp() != Types.OpOp2.MULT || !c_hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.SCALAR || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -*(M22650,f97734,*(M97683,M67673)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.MINUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673))"); + return newRoot; + } + + // Implementation of the rule -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233)) + private static Hop _applyRewrite52(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_0_1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233))"); + return newRoot; + } + + // Implementation of the rule -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233)) + private static Hop _applyRewrite53(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_1 = (BinaryOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.OpOp2.MULT || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, hi_0_1_1) ) { + return hi; + } + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233))"); + return newRoot; + } + + // Implementation of the rule +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d) + private static Hop _applyRewrite54(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1) ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(C,%*%(A,B)),d) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0, hi_1_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); + return hi; + } + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d)"); + return newRoot; + } + + private static Hop castIfNecessary(Hop newRoot, Hop oldRoot) { + Types.OpOp1 cast = null; + switch ( oldRoot.getValueType().toExternalString() ) { + case "DOUBLE": + cast = Types.OpOp1.CAST_AS_DOUBLE; + break; + case "INT": + cast = Types.OpOp1.CAST_AS_INT; + break; + case "BOOLEAN": + cast = Types.OpOp1.CAST_AS_BOOLEAN; + break; + default: + return null; + } + + return new UnaryOp("tmp", oldRoot.getDataType(), oldRoot.getValueType(), cast, newRoot); + } + private static int minIdx(double[] l) { + double minValue = Double.MAX_VALUE; + int minIdx = -1; + + for (int i = 0; i < l.length; i++) { + if (l[i] < minValue) { + minValue = l[i]; + minIdx = i; + } + } + + return minIdx; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java b/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java new file mode 100644 index 00000000000..d8a05d85ebd --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.generated; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewrite.HopRewriteRule; +import org.apache.sysds.hops.rewrite.ProgramRewriteStatus; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +public class RewriteAutomaticallyGenerated extends HopRewriteRule { + public static final String FILE_PATH = null; + public static RewriteAutomaticallyGenerated existingRewrites; + + private Function rewriteFn; + public static long totalTimeNanos = 0; + public static long callCount = 0; + public static long maxTimeNanos = -1; + + // This constructor could be used to dynamically compile generated rewrite rules from a file + @Deprecated + public RewriteAutomaticallyGenerated() { + // Try to read the file + try { + final RuleContext ctx = RewriterUtils.buildDefaultContext(); + List lines = Files.readAllLines(Paths.get(FILE_PATH)); + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + + rewriteFn = ruleSet.compile("AutomaticallyGeneratedRewriteFunction", false); + existingRewrites = this; + } catch (IOException e) { + } + } + + public RewriteAutomaticallyGenerated(Function rewriteFn) { + this.rewriteFn = rewriteFn; + } + + @Override + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { + if( roots == null || rewriteFn == null ) + return roots; + + long startNanos = System.nanoTime(); + + //one pass rewrite-descend (rewrite created pattern) + for( Hop h : roots ) + rule_apply( h, false ); + Hop.resetVisitStatus(roots, true); + + //one pass descend-rewrite (for rollup) + for( Hop h : roots ) + rule_apply( h, true ); + + long diff = System.nanoTime() - startNanos; + totalTimeNanos += diff; + callCount++; + if (maxTimeNanos == -1 || maxTimeNanos < diff) + maxTimeNanos = diff; + + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if( root == null || rewriteFn == null ) + return root; + + long startNanos = System.nanoTime(); + + //one pass rewrite-descend (rewrite created pattern) + rule_apply( root, false ); + + root.resetVisitStatus(); + + //one pass descend-rewrite (for rollup) + rule_apply( root, true ); + + long diff = System.nanoTime() - startNanos; + totalTimeNanos += diff; + callCount++; + if (maxTimeNanos == -1 || maxTimeNanos < diff) + maxTimeNanos = diff; + + return root; + } + + private void rule_apply(Hop hop, boolean descendFirst) + { + if(hop.isVisited()) + return; + + //recursively process children + for( int i=0; i f; + private final boolean accelerated; + + public RewriterHeuristic(RewriterRuleSet ruleSet) { + this(ruleSet, true); + } + + public RewriterHeuristic(RewriterRuleSet ruleSet, boolean accelerated) { + this.ruleSet = ruleSet; + this.accelerated = accelerated; + this.f = null; + } + + public RewriterHeuristic(Function f) { + this.ruleSet = null; + this.accelerated = false; + this.f = f; + } + + public void forEachRuleSet(Consumer consumer, boolean printNames) { + consumer.accept(ruleSet); + } + + public RewriterStatement apply(RewriterStatement current) { + return apply(current, null); + } + + public RewriterStatement apply(RewriterStatement current, @Nullable BiFunction handler) { + return apply(current, handler, new MutableBoolean(false), true); + } + + public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFunction handler, MutableBoolean foundRewrite, boolean print) { + if (f != null) + return f.apply(currentStmt); + + RuleContext.currentContext = ruleSet.getContext(); + + if (handler != null && !handler.apply(currentStmt, null)) + return currentStmt; + + RewriterRuleSet.ApplicableRule rule; + if (accelerated) + rule = ruleSet.acceleratedFindFirst(currentStmt); + else + throw new NotImplementedException("Must use accelerated mode"); + + if (rule != null) + foundRewrite.setValue(true); + + for (int i = 0; i < 500 && rule != null; i++) { + currentStmt = rule.rule.apply(rule.matches.get(0), currentStmt, rule.forward, false); + + if (handler != null && !handler.apply(currentStmt, rule.rule)) { + rule = null; + break; + } + + if (!(currentStmt instanceof RewriterInstruction)) { + rule = null; + break; + } + + if (accelerated) + rule = ruleSet.acceleratedFindFirst(currentStmt); + else + throw new IllegalArgumentException("Must use accelerated mode!"); + } + + if (rule != null) + throw new IllegalArgumentException("Expression did not converge:\n" + currentStmt.toParsableString(ruleSet.getContext(), true) + "\nRule: " + rule); + + return currentStmt; + } + + @Override + public String toString() { + return ruleSet.toString(); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java new file mode 100644 index 00000000000..4a62323c77b --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; + +import javax.annotation.Nullable; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public interface RewriterHeuristicTransformation { + RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print); + + void forEachRuleSet(Consumer consumer, boolean printNames); + + default RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func) { + return apply(stmt, func, new MutableBoolean(false), true); + } + + default RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, boolean print) { + return apply(stmt, func, new MutableBoolean(false), print); + } + + default RewriterStatement apply(RewriterStatement stmt) { + return apply(stmt, null, new MutableBoolean(false), true); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java new file mode 100644 index 00000000000..681ac34e1a9 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class RewriterHeuristics implements RewriterHeuristicTransformation { + protected static final Log LOG = LogFactory.getLog(RewriterHeuristic.class.getName()); + List heuristics = new ArrayList<>(); + + public void forEachRuleSet(Consumer consumer, boolean printNames) { + heuristics.forEach(entry -> { + if (printNames) { + LOG.info("\n"); + LOG.info("> " + entry.name + " <"); + LOG.info("\n"); + } + entry.heuristics.forEachRuleSet(consumer, printNames); + }); + } + + public void add(String name, RewriterHeuristicTransformation heur) { + heuristics.add(new HeuristicEntry(name, heur)); + } + + public void addRepeated(String name, RewriterHeuristicTransformation heur) { + heuristics.add(new HeuristicEntry(name, new RepeatedHeuristics(heur))); + } + + @Override + public RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print) { + for (HeuristicEntry entry : heuristics) { + if (print) { + System.out.println("\n"); + System.out.println("> " + entry.name + " <"); + System.out.println("\n"); + } + + stmt = entry.heuristics.apply(stmt, func, bool, print); + } + + return stmt; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + for (HeuristicEntry entry : heuristics) { + sb.append("\n> "); + sb.append(entry.name); + sb.append(" <\n"); + + sb.append(entry.heuristics.toString()); + } + + return sb.toString(); + } + + class RepeatedHeuristics implements RewriterHeuristicTransformation { + RewriterHeuristicTransformation heuristic; + + public RepeatedHeuristics(RewriterHeuristicTransformation heuristic) { + this.heuristic = heuristic; + } + + @Override + public RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print) { + bool.setValue(true); + + while (bool.getValue()) { + bool.setValue(false); + stmt = heuristic.apply(stmt, func, bool, print); + } + + return stmt; + } + + @Override + public void forEachRuleSet(Consumer consumer, boolean printNames) { + heuristic.forEachRuleSet(consumer, printNames); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + sb.append("\n===== REPEAT =====\n"); + + for (HeuristicEntry entry : heuristics) { + sb.append("\n> "); + sb.append(entry.name); + sb.append(" <\n"); + + sb.append(entry.heuristics.toString()); + } + + sb.append("\n===== END REPEAT ====="); + + return sb.toString(); + } + } + + + class HeuristicEntry { + String name; + RewriterHeuristicTransformation heuristics; + + public HeuristicEntry(String name, RewriterHeuristicTransformation heuristics) { + this.name = name; + this.heuristics = heuristics; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java new file mode 100644 index 00000000000..408abe71290 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.estimators.RewriterSparsityEstimator; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterRule { + + private final RuleContext ctx; + private final String name; + private final RewriterStatement fromRoot; + private final RewriterStatement toRoot; + private List toRoots; + private final HashMap linksStmt1ToStmt2; // Contains the explicit links a transformation has (like instructions, (a+b)-c = a+(b-c), but '+' and '-' are the same instruction still [important if instructions have metadata]) + private final HashMap linksStmt2ToStmt1; + private final List>> applyStmt1ToStmt2; + private final List>> applyStmt2ToStmt1; + private final Function iff1to2; + private final Function iff2to1; + private final boolean unidirectional; + private final Consumer postProcessor; + private Set allowedMultiReferences = Collections.emptySet(); + private RewriterAssertions combinedAssertions; + private boolean allowCombinations = false; + private boolean requireCostCheck = false; + private RewriterStatement fromCost = null; + private List toCosts = null; + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1) { + this(ctx, name, fromRoot, toRoot, unidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, null, null, null, null, null); + } + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1, Function iff1to2, Function iff2to1, List>> apply1To2, List>> apply2To1) { + this(ctx, name, fromRoot, toRoot, unidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, iff1to2, iff2to1, apply1To2, apply2To1, null); + } + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1, Function iff1to2, Function iff2to1, List>> apply1To2, List>> apply2To1, Consumer postProcessor) { + this.ctx = ctx; + this.name = name; + this.fromRoot = fromRoot; + this.toRoot = toRoot; + this.unidirectional = unidirectional; + this.linksStmt1ToStmt2 = linksStmt1ToStmt2; + this.linksStmt2ToStmt1 = linksStmt2ToStmt1; + this.iff1to2 = iff1to2; + this.iff2to1 = iff2to1; + this.applyStmt1ToStmt2 = apply1To2; + this.applyStmt2ToStmt1 = apply2To1; + this.postProcessor = postProcessor; + } + + // Determine if this rule can universally be applied or only in some conditions (e.g. certain dimensions / sparsity) + public boolean determineConditionalApplicability() { + RewriterAssertions assertions = new RewriterAssertions(ctx); + RewriterAssertionUtils.buildImplicitAssertion(fromRoot, assertions, fromRoot, ctx); + for (RewriterStatement root : getStmt2AsList()) + RewriterAssertionUtils.buildImplicitAssertion(root, assertions, root, ctx); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(fromRoot, getStmt2(), assertions, ctx, false, -1, false); + + requireCostCheck = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, false, true, 20); + + if (!requireCostCheck) + return false; + + List roots = toRoots == null ? List.of(toRoot) : toRoots; + + boolean integrateSparsityInCosts = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 20); + + MutableObject assertionRef = new MutableObject<>(assertions); + fromCost = RewriterCostEstimator.getRawCostFunction(fromRoot, ctx, assertionRef, !integrateSparsityInCosts); + toCosts = getStmt2AsList().stream().map(root -> RewriterCostEstimator.getRawCostFunction(root, ctx, assertionRef, !integrateSparsityInCosts)).collect(Collectors.toList()); + + fromCost = RewriterSparsityEstimator.rollupSparsities(fromCost, RewriterSparsityEstimator.estimateAllNNZ(fromRoot, ctx), ctx); + toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(roots.get(i), ctx), ctx)).collect(Collectors.toList()); + + return requireCostCheck; + } + + public boolean requiresCostCheck() { + return requireCostCheck; + } + + public RewriterStatement getStmt1Cost() { + return fromCost; + } + + public RewriterStatement getStmt2Cost() { + return toCosts.get(0); + } + + public List getStmt2Costs() { + return toCosts; + } + + public void buildCombinedAssertions() { + combinedAssertions = RewriterAssertionUtils.buildImplicitAssertions(fromRoot, ctx); + if (toRoot != null) + RewriterAssertionUtils.buildImplicitAssertions(toRoot, combinedAssertions, ctx); + else { + for (RewriterStatement root : toRoots) + RewriterAssertionUtils.buildImplicitAssertions(root, combinedAssertions, ctx); + } + } + + public RewriterAssertions getCombinedAssertions() { + if (combinedAssertions == null) + buildCombinedAssertions(); + + return combinedAssertions; + } + + public void setAllowedMultiReferences(Set allowed, boolean allowCombinations) { + this.allowedMultiReferences = allowed; + this.allowCombinations = allowCombinations; + } + + /** + * Overwrites the rule as a conditional rule + * @param targets all possible target statements + */ + public void setConditional(List targets) { + toRoots = targets; + } + + public boolean isConditionalMultiRule() { + return toRoots != null; + } + + public List getConditionalMultiRuleTargets() { + return toRoots; + } + + public String getName() { + return name; + } + + public RewriterStatement getStmt1() { + return fromRoot; + } + + /** + * Returns the target statement. + * @return the target statement; in case of a multi-rule, this will return the first option + */ + public RewriterStatement getStmt2() { + return toRoot != null ? toRoot : toRoots.get(0); + } + + public List getStmt2AsList() { + return toRoot != null ? List.of(toRoot) : toRoots; + } + + public boolean isUnidirectional() { + return unidirectional; + } + + public HashMap getForwardLinks() { + return linksStmt1ToStmt2; + } + + public HashMap getBackwardLinks() { + return linksStmt2ToStmt1; + } + + public RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean forward, boolean inplace) { + return apply(match, rootNode, forward, inplace, false); + } + + public RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean forward, boolean inplace, boolean updateTypes) { + return forward ? applyForward(match, rootNode, inplace, updateTypes) : applyBackward(match, rootNode, inplace, updateTypes); + } + + public RewriterStatement applyForward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes) { + return applyForward(match, rootNode, inplace, updateTypes, new MutableObject<>(null)); + } + + public RewriterStatement applyForward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes, MutableObject> modificationHandle) { + if (inplace) + throw new NotImplementedException("Inplace operations have been removed"); + RewriterStatement out = apply(match, rootNode, toRoot, modificationHandle, applyStmt1ToStmt2 == null ? Collections.emptyList() : applyStmt1ToStmt2); + if (updateTypes) + updateTypes(out, ctx); + return out; + } + + public RewriterStatement applyBackward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes) { + return applyBackward(match, rootNode, inplace, updateTypes, new MutableObject<>(null)); + } + + public RewriterStatement applyBackward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes, MutableObject> modificationHandle) { + if (inplace) + throw new NotImplementedException("Inplace operations have been removed"); + RewriterStatement out = apply(match, rootNode, fromRoot, modificationHandle, applyStmt2ToStmt1 == null ? Collections.emptyList() : applyStmt2ToStmt1); + if (updateTypes) + updateTypes(out, ctx); + return out; + } + + public RewriterStatement.MatchingSubexpression matchSingleStmt1(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, boolean allowImplicitTypeConversions) { + RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt1(), true, true, false, true, true, false, true, false, false, allowImplicitTypeConversions, linksStmt1ToStmt2); + mCtx.currentStatement = stmt; + boolean match = getStmt1().match(mCtx); + + if (match) { + RewriterStatement.MatchingSubexpression matchExpr = mCtx.toMatch(); + + if (iff1to2 == null || iff1to2.apply(matchExpr)) + return matchExpr; + } + + return null; + } + + public RewriterStatement.MatchingSubexpression matchSingleStmt2(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, boolean allowImplicitTypeConversions) { + RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt2(), true, true, false, true, true, false, true, false, false, allowImplicitTypeConversions, linksStmt2ToStmt1); + mCtx.currentStatement = stmt; + boolean match = getStmt2().match(mCtx); + + if (match) { + RewriterStatement.MatchingSubexpression matchExpr = mCtx.toMatch(); + + if (iff2to1 == null || iff2to1.apply(matchExpr)) + return matchExpr; + } + + return null; + } + + public void updateTypes(RewriterStatement root, final RuleContext ctx) { + root.forEachPostOrder((cur, pred) -> { + cur.refreshReturnType(ctx); + }, true); + } + + private RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootInstruction, RewriterStatement dest, MutableObject> modificationHandle, List>> applyFunction) { + if (match.getPredecessor().isRoot()) { + final Map createdObjects = new HashMap<>(); + RewriterStatement cpy = dest.nestedCopyOrInject(createdObjects, obj -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj2 -> null); + createdObjects.put(assoc, assocCpy); + } + + return assocCpy; + } + + return null; + }); + + RewriterStatement tmp = cpy.simplify(ctx); + if (tmp != null) + cpy = tmp; + + match.setNewExprRoot(cpy); + + RewriterStatement oldRootCpy = createdObjects.get(match.getExpressionRoot()); + RewriterAssertions assertions = null; + + if (oldRootCpy != null) { + assertions = (RewriterAssertions) oldRootCpy.getMeta("_assertions"); + oldRootCpy.unsafeRemoveMeta("_assertions"); + } else if (match.getExpressionRoot().getMeta("_assertions") != null) { + assertions = ((RewriterAssertions) match.getExpressionRoot().getMeta("_assertions")).nestedCopyOrInject(createdObjects, (obj, p, pIdx) -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj2 -> null); + createdObjects.put(assoc, assocCpy); + } + + return assocCpy; + } + + return null; + }, match.getNewExprRoot()); + match.getExpressionRoot().unsafeRemoveMeta("_assertions"); + } + + if (assertions != null) { + if (!cpy.isLiteral()) + cpy.unsafePutMeta("_assertions", assertions); + } + + match.getLinks().forEach(lnk -> lnk.newStmt.replaceAll(createdObjects::get)); + match.getLinks().forEach(lnk -> lnk.transferFunction.accept(lnk)); + applyFunction.forEach(t -> t._2.accept(createdObjects.get(t._1), match)); + + if (postProcessor != null) + postProcessor.accept(cpy); + + if (ctx.metaPropagator != null) { + RewriterStatement mNew = ctx.metaPropagator.apply(cpy); + + if (mNew != cpy) { + mNew.unsafePutMeta("_assertions", cpy.getMeta("_assertions")); + cpy.unsafeRemoveMeta("_assertions"); + cpy = mNew; + } + } + + cpy.prepareForHashing(); + cpy.recomputeHashCodes(ctx); + + modificationHandle.setValue(new Tuple3<>(cpy, null, -1)); + + return cpy; + } + + final Map createdObjects = new HashMap<>(); + RewriterStatement cpy2 = rootInstruction.nestedCopyOrInject(createdObjects, (obj2, parent, pIdx) -> { + if (obj2.equals(match.getMatchRoot())) { + RewriterStatement cpy = dest.nestedCopyOrInject(createdObjects, obj -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj3 -> null); + createdObjects.put(assoc, assocCpy); + } + return assocCpy; + } + return null; + }); + createdObjects.put(obj2, cpy); + modificationHandle.setValue(new Tuple3<>(cpy, parent, pIdx)); + return cpy; + } + return null; + }); + RewriterStatement tmp = cpy2.simplify(ctx); + if (tmp != null) + cpy2 = tmp; + + match.setNewExprRoot(cpy2); + + match.getLinks().forEach(lnk -> lnk.newStmt.replaceAll(createdObjects::get)); + cpy2.prepareForHashing(); + match.getLinks().forEach(lnk -> lnk.transferFunction.accept(lnk)); + applyFunction.forEach(t -> t._2.accept(createdObjects.get(t._1), match)); + + if (postProcessor != null) + postProcessor.accept(cpy2); + + if (ctx.metaPropagator != null) { + RewriterStatement mNew = ctx.metaPropagator.apply(cpy2); + + if (mNew != cpy2) { + mNew.unsafePutMeta("_assertions", cpy2.getMeta("_assertions")); + cpy2.unsafeRemoveMeta("_assertions"); + cpy2 = mNew; + } + } + + cpy2.prepareForHashing(); + cpy2.recomputeHashCodes(ctx); + + return cpy2; + } + + public String toString() { + if (isUnidirectional()) + if (isConditionalMultiRule()) + return fromRoot.toParsableString(ctx) + " => {" + toRoots.stream().map(stmt -> stmt.toParsableString(ctx)).collect(Collectors.joining("; ")) + "}"; + else + return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx); + else + return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx); + } + + public String toParsableString(final RuleContext ctx) { + Map> varDefs = new HashMap<>(); + StringBuilder sb = new StringBuilder(); + Map refs = new HashMap<>(); + int refIdx = fromRoot.toParsableString(sb, refs, 0, varDefs, allowedMultiReferences, ctx); + String stmt1 = sb.toString(); + sb = new StringBuilder(); + if (toRoot != null) { + toRoot.toParsableString(sb, refs, refIdx, varDefs, allowedMultiReferences, ctx); + } else { + for (RewriterStatement mToRoot : toRoots) { + mToRoot.toParsableString(sb, refs, refIdx, varDefs, allowedMultiReferences, ctx); + sb.append('\n'); + } + } + String stmt2 = sb.toString(); + String multiRefDefs = ""; + + if (!allowedMultiReferences.isEmpty()) { + multiRefDefs = "AllowedMultiRefs:" + allowedMultiReferences.stream().map(stmt -> "$" + refs.get(stmt)).collect(Collectors.joining(",")) + "\nAllowCombinations:" + allowCombinations + "\n"; + } + + String defs = RewriterStatement.parsableDefinitions(varDefs); + + if (toRoot != null) + return multiRefDefs + defs + "\n" + stmt1 + "\n=>\n" + stmt2; + else + return multiRefDefs + defs + "\n" + stmt1 + "\n=>\n{\n" + stmt2 + "}"; + } + + public static class LinkObject { + public List stmt; + public Consumer transferFunction; + + public LinkObject() { + stmt = new ArrayList<>(2); + } + + public LinkObject(List stmt, Consumer transferFunction) { + this.stmt = stmt; + this.transferFunction = transferFunction; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < stmt.size(); i++) { + if (i != 0) + sb.append(", "); + sb.append(stmt.get(i)); + } + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + return o instanceof LinkObject && ((LinkObject)o).stmt == stmt; + } + + @Override + public int hashCode() { + return stmt.hashCode(); + } + } + + public static class ExplicitLink { + public final RewriterStatement oldStmt; + public List newStmt; + public final Consumer transferFunction; + + public ExplicitLink(RewriterStatement oldStmt, List newStmt, Consumer transferFunction) { + this.oldStmt = oldStmt; + this.newStmt = new ArrayList<>(newStmt); + this.transferFunction = transferFunction; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java new file mode 100644 index 00000000000..078d79e2c95 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; + +public class RewriterRuleBuilder { + private final RuleContext ctx; + private String ruleName = "?"; + private ArrayList instrSeq = new ArrayList<>(); + private ArrayList mappingSeq = new ArrayList<>(); + private HashMap globalIds = new HashMap<>(); + private HashMap instrSeqIds = new HashMap<>(); + private HashMap mappingSeqIds = new HashMap<>(); + private HashMap linksStmt1ToStmt2 = new HashMap<>(); + private ArrayList>> applyStmt1ToStmt2 = new ArrayList<>(); + private HashMap linksStmt2ToStmt1 = new HashMap<>(); + private ArrayList>> applyStmt2ToStmt1 = new ArrayList<>(); + private RewriterStatement fromRoot = null; + private RewriterStatement toRoot = null; + private List multiRuleRoots = null; + private Function iff1to2 = null; + private Function iff2to1 = null; + private boolean isUnidirectional = false; + private boolean buildSingleDAG = false; + + private RewriterStatement currentStatement = null; + private boolean mappingState = false; + + private boolean canBeModified = true; + + private Set allowedMultiReferences = Collections.emptySet(); + private boolean allowCombinations = false; + + public RewriterRuleBuilder(final RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterRuleBuilder(final RuleContext ctx, String ruleName) { + this.ctx = ctx; + this.ruleName = ruleName; + } + + public RewriterRuleBuilder iff(Function iff, boolean forward) { + if (buildSingleDAG) + throw new IllegalArgumentException(); + + if (forward) + iff1to2 = iff; + else + iff2to1 = iff; + + return this; + } + + public RewriterRuleBuilder parseGlobalVars(String globalVarDefinition) { + if (!canBeModified) + throw new IllegalArgumentException(); + RewriterUtils.parseDataTypes(globalVarDefinition, globalIds, ctx); + return this; + } + + public RewriterRuleBuilder intLiteral(String id, int value) { + return intLiteral(id, value, "global"); + } + + public RewriterRuleBuilder intLiteral(String id, int value, String scope) { + switch (scope) { + case "global": + globalIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + case "from": + instrSeqIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + case "to": + mappingSeqIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + } + + return this; + } + + public RewriterRuleBuilder parseGlobalStatementAsVariable(String varName, String expr) { + return parseGlobalStatementAsVariable(varName, expr, new HashMap<>()); + } + + public RewriterRuleBuilder parseGlobalStatementAsVariable(String varName, String expr, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + + RewriterStatement parsed = RewriterUtils.parseExpression(expr, refMap, globalIds, ctx); + parsed.consolidate(ctx); + globalIds.put(varName, parsed); + return this; + } + + public RewriterRuleBuilder withParsedStatement(String stmt) { + return withParsedStatement(stmt, new HashMap<>()); + } + + public RewriterRuleBuilder withParsedStatement(String stmt, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + fromRoot = RewriterUtils.parseExpression(stmt, refMap, globalIds, ctx); + fromRoot.forEachPreOrderWithDuplicates(el -> { + instrSeqIds.put(el.getId(), el); + return true; + }); + return this; + } + + public RewriterRuleBuilder toParsedStatement(String stmt) { + return toParsedStatement(stmt, new HashMap<>()); + } + + public RewriterRuleBuilder toParsedStatement(String stmt, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + mappingState = true; + toRoot = RewriterUtils.parseExpression(stmt, refMap, globalIds, ctx); + toRoot.forEachPreOrderWithDuplicates(el -> { + mappingSeqIds.put(el.getId(), el); + return true; + }); + return this; + } + + public RewriterRuleBuilder prepare() { + if (!canBeModified) + return this; + if (buildSingleDAG) { + getCurrentInstruction().consolidate(ctx); + fromRoot.prepareForHashing(); + fromRoot.recomputeHashCodes(ctx); + canBeModified = false; + } else { + if (getCurrentInstruction() != null) + getCurrentInstruction().consolidate(ctx); + fromRoot.prepareForHashing(); + if (toRoot != null) + toRoot.prepareForHashing(); + else + multiRuleRoots.forEach(RewriterStatement::prepareForHashing); + fromRoot.recomputeHashCodes(ctx); + if (toRoot != null) + toRoot.recomputeHashCodes(ctx); + else + multiRuleRoots.forEach(rt -> rt.recomputeHashCodes(ctx)); + canBeModified = false; + } + + return this; + } + + public RewriterRule build() { + if (buildSingleDAG) + throw new IllegalArgumentException("Cannot build a rule if DAG was specified"); + if (!mappingState) + throw new IllegalArgumentException("No mapping expression"); + if (fromRoot == null) + throw new IllegalArgumentException("From-root statement cannot be null"); + if (toRoot == null && multiRuleRoots == null) + throw new IllegalArgumentException("To-root statement cannot be null"); + if (getCurrentInstruction() != null) + getCurrentInstruction().consolidate(ctx); + prepare(); + RewriterRule rule = new RewriterRule(ctx, ruleName, fromRoot, toRoot, isUnidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, iff1to2, iff2to1, applyStmt1ToStmt2, applyStmt2ToStmt1); + rule.setAllowedMultiReferences(allowedMultiReferences, allowCombinations); + if (multiRuleRoots != null) + rule.setConditional(multiRuleRoots); + return rule; + } + + public RewriterStatement buildDAG() { + if (!buildSingleDAG) + throw new IllegalArgumentException("Cannot build a DAG if rule was specified"); + prepare(); + return fromRoot; + } + + public RewriterRuleBuilder asDAGBuilder() { + buildSingleDAG = true; + return this; + } + + public RewriterRuleBuilder setUnidirectional(boolean unidirectional) { + this.isUnidirectional = unidirectional; + return this; + } + + public RewriterStatement getCurrentInstruction() { + if (mappingState) + if (mappingSeq.size() > 0) + return mappingSeq.get(mappingSeq.size()-1); + else if (toRoot != null) + return toRoot; + else if (multiRuleRoots != null) + return multiRuleRoots.get(0); // Just as a dummy + else + throw new IllegalArgumentException("There is no current instruction in the mapping sequence"); + else + if (instrSeq.size() > 0) + return instrSeq.get(instrSeq.size()-1); + else if (fromRoot != null) + return fromRoot; + else + throw new IllegalArgumentException("There is no current instruction in the instruction sequence"); + } + + public RewriterDataType getCurrentOperand() { + if (currentStatement instanceof RewriterDataType) + return (RewriterDataType)currentStatement; + else + throw new IllegalArgumentException("The current operand is not a data type"); + } + + public RewriterRuleBuilder withDataType(String id, String type) { + withDataType(id, type, null); + return this; + } + + public RewriterRuleBuilder withDataType(String id, String type, Object literal) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (!instrSeq.isEmpty()) + throw new IllegalArgumentException("To define a single data type, the instruction sequence must be empty"); + fromRoot = new RewriterDataType().ofType(type).asLiteral(literal).as(id); + storeVar(fromRoot); + return this; + } + + public RewriterRuleBuilder withInstruction(String instr) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + if (instrSeq.size() > 0) + getCurrentInstruction().consolidate(ctx); + instrSeq.add(new RewriterInstruction().withInstruction(instr)); + return this; + } + + public RewriterRuleBuilder completeRule(RewriterStatement from, RewriterStatement to) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + this.fromRoot = from; + this.toRoot = to; + this.mappingState = true; + return this; + } + + public RewriterRuleBuilder completeConditionalRule(RewriterStatement from, List to) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + this.fromRoot = from; + this.multiRuleRoots = to; + this.mappingState = true; + return this; + } + + public RewriterRuleBuilder withAllowedMultiRefs(Set allowedMultiRefs, boolean allowCombinations) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (!mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + + this.allowedMultiReferences = allowedMultiRefs; + this.allowCombinations = allowCombinations; + return this; + } + + public RewriterRuleBuilder withOps(RewriterDataType... operands) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + ((RewriterInstruction)getCurrentInstruction()).withOps(operands); + currentStatement = null; + return this; + } + + public RewriterRuleBuilder addOp(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + RewriterDataType dt = new RewriterDataType().as(id); + storeVar(dt); + ((RewriterInstruction)getCurrentInstruction()).addOp(dt); + if (currentStatement != null) + currentStatement.consolidate(ctx); + currentStatement = dt; + return this; + } + + public RewriterRuleBuilder addDynamicOpListInstr(String id, String type, boolean fromInstr, String... ops) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + + if (fromInstr) + withInstruction("argList"); + else + toInstruction("argList"); + + if (ops.length == 0 && type.endsWith("...")) { + // Add one placeholder operand to implicitly determine the data type + addOp(UUID.randomUUID().toString()).ofType(type.substring(0, type.length()-3)); + } else { + for (String op : ops) + addExistingOp(op); + } + + as(id); + return this; + } + + public RewriterRuleBuilder asLiteral(Object literal) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().asLiteral(literal); + return this; + } + + public RewriterRuleBuilder as(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentInstruction().as(id); + currentVars().put(id, getCurrentInstruction()); + storeVar(getCurrentInstruction()); + return this; + } + + public RewriterRuleBuilder asRootInstruction() { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) { + if (toRoot != null) + throw new IllegalArgumentException("Cannot have more than one root instruction"); + toRoot = getCurrentInstruction().as("result"); + mappingSeqIds.put("result", toRoot); + } else { + if (fromRoot != null) + throw new IllegalArgumentException("Cannot have more than one root instruction"); + fromRoot = getCurrentInstruction().as("result"); + instrSeqIds.put("result", fromRoot); + } + return this; + } + + public RewriterRuleBuilder addExistingOp(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + RewriterStatement operand = findVar(id); + + if (operand == null) + throw new IllegalArgumentException("Operand with id '" + id + "' does not exist"); + + if (currentStatement != null) + currentStatement.consolidate(ctx); + + currentStatement = operand; + ((RewriterInstruction)getCurrentInstruction()).addOp(operand); + + return this; + } + + public RewriterRuleBuilder ofType(String type) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().ofType(type); + return this; + } + + public RewriterRuleBuilder instrMeta(String key, Object value) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentInstruction().putMeta(key, value); + return this; + } + + public RewriterRuleBuilder operandMeta(String key, Object value) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().putMeta(key, value); + return this; + } + + public RewriterRuleBuilder toInstruction(String instr) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (buildSingleDAG) + throw new IllegalArgumentException("Cannot create a mapping instruction when building a single DAG"); + getCurrentInstruction().consolidate(ctx); + mappingSeq.add(new RewriterInstruction().withInstruction(instr)); + mappingState = true; + return this; + } + + public RewriterRuleBuilder linkUnidirectional(String idFrom, String idTo, Consumer transferFunction, boolean forward) { + return linkManyUnidirectional(idFrom, List.of(idTo), transferFunction, forward); + } + + public RewriterRuleBuilder linkManyUnidirectional(String idFrom, List idsTo, Consumer transferFunction, boolean forward) { + prepare(); + RewriterStatement stmt1 = forward ? instrSeqIds.get(idFrom) : mappingSeqIds.get(idFrom); + if (stmt1 == null) + stmt1 = globalIds.get(idFrom); + if (stmt1 == null) + throw new IllegalArgumentException("Could not find instruction id: " + idFrom); + if (!stmt1.isConsolidated()) + stmt1.consolidate(ctx); + + List stmts2 = new ArrayList<>(); + + for (String idTo : idsTo) { + RewriterStatement stmt2 = forward ? mappingSeqIds.get(idTo) : instrSeqIds.get(idTo); + if (stmt2 == null) + stmt2 = globalIds.get(idTo); + if (stmt2 == null) + throw new IllegalArgumentException("Could not find instruction id: " + idTo); + if (!stmt2.isConsolidated()) + stmt2.consolidate(ctx); + + stmts2.add(stmt2); + } + + HashMap links = forward ? linksStmt1ToStmt2 : linksStmt2ToStmt1; + + RewriterRule.LinkObject lnk = new RewriterRule.LinkObject(stmts2, transferFunction); + + if (links.containsKey(stmt1) || links.containsValue(lnk)) + throw new IllegalArgumentException("Key or value already exists in explicit link map."); + + links.put(stmt1, lnk); + return this; + } + + public RewriterRuleBuilder link(String id, String id2, Consumer transferFunction) { + linkUnidirectional(id, id2, transferFunction, true); + linkUnidirectional(id2, id, transferFunction, false); + return this; + } + + public RewriterRuleBuilder apply(String id, Consumer applicationFunction, boolean forward) { + return apply(id, (stmt, match) -> applicationFunction.accept(stmt), forward); + } + + public RewriterRuleBuilder apply(String id, BiConsumer applicationFunction, boolean forward) { + prepare(); + RewriterStatement stmt1 = forward ? mappingSeqIds.get(id) : instrSeqIds.get(id); + if (stmt1 == null) + stmt1 = globalIds.get(id); + if (stmt1 == null) + throw new IllegalArgumentException("Could not find instruction id: " + id); + if (!stmt1.isConsolidated()) + stmt1.consolidate(ctx); + + if (forward) + applyStmt1ToStmt2.add(new Tuple2<>(stmt1, applicationFunction)); + else + applyStmt2ToStmt1.add(new Tuple2<>(stmt1, applicationFunction)); + + return this; + } + + public RewriterRuleBuilder toDataType(String id, String type) { + toDataType(id, type, null); + return this; + } + + public RewriterRuleBuilder toDataType(String id, String type, Object literal) { + if (!mappingSeq.isEmpty()) + throw new IllegalArgumentException("To define a single data type, the mapping sequence must be empty"); + toRoot = new RewriterDataType().ofType(type).asLiteral(literal).as(id); + storeVar(toRoot); + return this; + } + + private HashMap currentVars() { + return mappingState ? mappingSeqIds : instrSeqIds; + } + + private RewriterStatement findVar(String id) { + RewriterStatement stmt = null; + + if (mappingState) { + stmt = mappingSeqIds.get(id); + if (stmt != null) + return stmt; + } else { + stmt = instrSeqIds.get(id); + if (stmt != null) + return stmt; + } + return globalIds.get(id); + } + + private void storeVar(RewriterStatement var) { + if (var.getId() == null) + throw new IllegalArgumentException("The id of a statement cannot be null!"); + + if (mappingState) { + mappingSeqIds.put(var.getId(), var); + } else { + if (var instanceof RewriterDataType) + globalIds.put(var.getId(), var); + else + instrSeqIds.put(var.getId(), var); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java new file mode 100644 index 00000000000..0c5d4c99b98 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java @@ -0,0 +1,1445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.HashMap; +import java.util.List; +import java.util.UUID; + +import static org.apache.sysds.hops.rewriter.RewriterContextSettings.ALL_TYPES; +import static org.apache.sysds.hops.rewriter.RewriterContextSettings.SCALARS; + +public class RewriterRuleCollection { + public static void substituteEquivalentStatements(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + rules.add(new RewriterRuleBuilder(ctx, "as.scalar(A) => cast.FLOAT(A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("as.scalar(A)") + .toParsedStatement("cast.FLOAT(A)") + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "as.matrix(a) => cast.MATRIX(a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("as.matrix(a)") + .toParsedStatement("cast.MATRIX(a)") + .build() + ); + }); + + // Some meta operators + rules.add(new RewriterRuleBuilder(ctx, "rowVec(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rowVec(A)") + .toParsedStatement("[]($1:A, 1, 1, 1, ncol(A))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "colVec(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("colVec(A)") + .toParsedStatement("[](A, 1, nrow(A), 1, 1)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cellMat(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cellMat(A)") + .toParsedStatement("[](A, 1, 1, 1, 1)") + .build() + ); + + substituteFusedOps(rules, ctx); + } + + public static void substituteFusedOps(final List rules, final RuleContext ctx) { + // Now resolve fused operators + rules.add(new RewriterRuleBuilder(ctx, "1-*(A,B) => -(1, *(A, B))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_FLOAT:1.0") // We take a float as this framework is optimized for floats + .withParsedStatement("1-*(A, B)") + .toParsedStatement("-(1.0, *(A, B))") + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "log_nz(A) => *(!=(A, 0.0), log(A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") // We take a float as this framework is optimized for floats + .withParsedStatement("log_nz(A)") + .toParsedStatement("*(!=(A, 0.0), log(A))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sumSq(A) => sum(*(A,A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("sumSq(A)") + .toParsedStatement("sum(*(A,A))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "+*(A,s,Y) => +(A, *(s, Y))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,Y") + .parseGlobalVars("FLOAT:s") + .withParsedStatement("+*(A,s,Y)") + .toParsedStatement("+(A, *(s, Y))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "-*(A,s,Y) => -(A, *(s, Y))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,Y") + .parseGlobalVars("FLOAT:s") + .withParsedStatement("-*(A,s,Y)") + .toParsedStatement("-(A, *(s, Y))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sq(A) => *(A,A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("sq(A)") + .toParsedStatement("*(A, A)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_nnz(A) => sum(!=(A,0.0))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("_nnz(A)") + .toParsedStatement("sum(!=(A,0.0))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "*2(A) => +(A,A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("*2(A)") + .toParsedStatement("+(A,A)") + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "log_nz(A, a) => *(!=(A, 0.0), *(log(A), inv(log(a)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:a") // We take a float as this framework is optimized for floats + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("log_nz(A, a)") + .toParsedStatement("*(!=(A, 0.0), *(log(A), inv(log(a))))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "log(A, a) => *(log(A), inv(log(a)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:a") + .withParsedStatement("log(A, a)") + .toParsedStatement("*(log(A), inv(log(a)))") + .build() + ); + }); + } + + public static void eliminateMultipleCasts(final List rules, final RuleContext ctx) { + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(cast.TYPE(A)) => cast.TYPE(A)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast.MATRIX(cast.MATRIX(a))") + .toParsedStatement("cast.MATRIX(a)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(a::TYPE) => a") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast." + t + "(a)") + .toParsedStatement("a") + .build() + ); + + SCALARS.forEach(t2 -> { + SCALARS.forEach(t3 -> { + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(+(a, b)) => ...") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .parseGlobalVars(t3 + ":b") + .withParsedStatement("cast." + t + "(+(a,b))") + .toParsedStatement("+(cast." + t + "(a), cast." + t + "(b))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(*(a, b)) => ...") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .parseGlobalVars(t3 + ":b") + .withParsedStatement("cast." + t + "(*(a,b))") + .toParsedStatement("*(cast." + t + "(a), cast." + t + "(b))") + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(cast.TYPE(A)) => cast.TYPE(A)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast." + t2 + "(cast." + t2 + "(a))") + .toParsedStatement("cast." + t2 + "(a)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.SCALAR(cast.MATRIX(a)) => a") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .withParsedStatement("cast." + t + "(cast.MATRIX(a))") + .toParsedStatement("cast." + t + "(a)") + .build() + ); + }); + }); + } + + public static void canonicalizeAlgebraicStatements(final List rules, boolean allowInversionCanonicalization, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a,b) => +(a,-(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(a, b)", hooks) + .toParsedStatement("+(a, -(b))", hooks) + .build() + ); + + if (allowInversionCanonicalization) { + rules.add(new RewriterRuleBuilder(ctx, "/(a,b) => *(a, inv(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("/(a, b)", hooks) + .toParsedStatement("*(a, inv(b))", hooks) + .build() + ); + } + + rules.add(new RewriterRuleBuilder(ctx, "-(+(a, b)) => +(-(a), -(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(+(a, b))", hooks) + .toParsedStatement("$1:+(-(a), -(b))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "-(-(a)) => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(-(a))", hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "length(A) => nrow(A) * ncol(A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("length(A)", hooks) + .toParsedStatement("*(nrow(A), ncol(A))", hooks) + .build() + ); + + for (String t : ALL_TYPES) { + rules.add(new RewriterRuleBuilder(ctx, "-(inv(a)) => inv(-(a))") + .setUnidirectional(true) + .parseGlobalVars(t + ":A") + .withParsedStatement("-(inv(A))", hooks) + .toParsedStatement("inv(-(A))", hooks) + .build() + ); + } + + rules.add(new RewriterRuleBuilder(ctx, "-(sum(A)) => sum(-(A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("-(sum(A))", hooks) + .toParsedStatement("sum(-(A))", hooks) + .build() + ); + } + + public static void canonicalizeBooleanStatements(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, ">(a, b) => <(b, a)") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(">(a, b)", hooks) + .toParsedStatement("<(b, a)", hooks) + .build() + ); + + // These hold only for boolean expressions + /*rules.add(new RewriterRuleBuilder(ctx, "!(!(a)) = a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("!(!(a))", hooks) + .toParsedStatement("a", hooks) + .build() + );*/ + + rules.add(new RewriterRuleBuilder(ctx, "<=(a, b) => |(<(a, b), ==(a, b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("<=(a, b)", hooks) + .toParsedStatement("|(<(a, b), ==(a, b))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, ">=(a, b) => |(<(b, a), ==(b, a))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(">=(a, b)", hooks) + .toParsedStatement("|(<(b, a), ==(b, a))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "!(&(a, b)) => |(!(a), !(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("!(&(a, b))", hooks) + .toParsedStatement("|(!(a), !(b))", hooks) + .build() + ); + + List.of("&(a, b)", "&(b, a)").forEach(exp -> { + List.of("|(" + exp + ", a)", "|(a, " + exp + ")").forEach(tExpr -> { + rules.add(new RewriterRuleBuilder(ctx, tExpr + " => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(tExpr, hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + }); + + List.of("|(a, b)", "|(b, a)").forEach(exp -> { + List.of("&(" + exp + ", a)", "&(a, " + exp + ")").forEach(tExpr -> { + rules.add(new RewriterRuleBuilder(ctx, tExpr + " => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(tExpr, hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + }); + + rules.add(new RewriterRuleBuilder(ctx, "|(<(b, a), <(a, b)) => b != a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("|(<(b, a), <(a, b))", hooks) + .toParsedStatement("!=(b, a)", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "&(<(b, a), <(a, b)) => FALSE") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("&(<(b, a), <(a, b))", hooks) + .toParsedStatement("FALSE", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "!(!=(a, b)) => ==(a, b)") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("!(!=(a, b))", hooks) + .toParsedStatement("==(a, b)", hooks) + .build() + ); + + /*rules.add(new RewriterRuleBuilder(ctx, "==(a, b) => isZero(+(a, -(b)))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("!(!=(a, b))", hooks) + .toParsedStatement("==(a, b)", hooks) + .build() + );*/ + }); + } + + // E.g. expand A * B -> _m($1:_idx(), 1, nrow(A), _m($2:_idx(), 1, nrow(B), A[$1, $2] * B[$1, $2])) + public static void expandStreamingExpressions(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // cast.MATRIX + rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.MATRIX(a)", hooks) + .toParsedStatement("$4:_m(1, 1, a)", hooks) + .build() + ); + + // cast.FLOAT + rules.add(new RewriterRuleBuilder(ctx, "") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.FLOAT(A)", hooks) + .toParsedStatement("[](A, 1, 1)", hooks) + .build() + ); + + // Const + rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("const(A, a)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), a)", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getChild(0).unsafePutMeta("ownerId", id); + }, true) // Assumes it will never collide + .build() + ); + + // Diag + rules.add(new RewriterRuleBuilder(ctx, "Expand diag matrix") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("diag(A)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), $5:ifelse(==($1,$2), [](A, $1, $2), 0.0))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getChild(0).unsafePutMeta("ownerId", id); + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), aRef.getNRow(), match.getNewExprRoot()); + }, true) // Assumes it will never collide + .build() + ); + + + // Matrix Multiplication + rules.add(new RewriterRuleBuilder(ctx, "Expand matrix product") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("%*%(A, B)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(B)), sum($5:_m($3:_idx(1, ncol(A)), 1, *([](A, $1, $3), [](B, $3, $2)))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(3).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + + RewriterStatement aRef = stmt.getChild(0, 1, 0); + RewriterStatement bRef = stmt.getChild(1, 1, 0); + RewriterAssertions assertions = match.getNewExprRoot().getAssertions(ctx); + assertions.addEqualityAssertion(aRef.getNCol(), bRef.getNRow(), match.getNewExprRoot()); + assertions.update(match.getNewExprRoot()); + }, true) // Assumes it will never collide + .apply(hooks.get(5).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) // Assumes it will never collide + .build() + ); + + // E.g. A + B + rules.add(new RewriterRuleBuilder(ctx, "Expand Element Wise Instruction") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("$1:ElementWiseInstruction(A,B)", hooks) + .toParsedStatement("$7:_m($2:_idx(1, $5:nrow(A)), $3:_idx(1, $6:ncol(A)), $4:ElementWiseInstruction([](A, $2, $3), [](B, $2, $3)))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(3).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(7).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + + // Now we assert that nrow(A) = nrow(B) and ncol(A) = ncol(B) + RewriterStatement aRef = stmt.getChild(2, 0, 0); + RewriterStatement bRef = stmt.getChild(2, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), bRef.getNRow(), match.getNewExprRoot()); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), bRef.getNCol(), match.getNewExprRoot()); + }, true) // Assumes it will never collide + .build() + ); + + List.of("$2:_m(i, j, v1), v2", "v1, $2:_m(i, j, v2)").forEach(s -> { + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v1,v2") + .withParsedStatement("$1:ElementWiseInstruction(" + s + ")", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v1, v2))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build() + ); + }); + + // Trace(A) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("trace(A)", hooks) + .toParsedStatement("sum($3:_m($1:_idx(1, $2:nrow(A)), 1, [](A, $1, $1)))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("dontExpand", true), true) + .apply(hooks.get(3).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + + // Assert that the matrix is squared + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), aRef.getNCol(), match.getNewExprRoot()); + }, true) + .build() + ); + + // t(A) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("t(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, ncol(A)), $2:_idx(1, nrow(A)), [](A, $2, $1))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rev(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, -(+(ncol(A), 1), $1), $2))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // rand(rows, cols, min, max) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("INT:n,m") + .parseGlobalVars("FLOAT:a,b") + .withParsedStatement("rand(n, m, a, b)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, $4:*(+(b, -(a)), rand(argList($1,$2)))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // sum(A) = sum(_m($1:_idx(1, nrow(A)), 1, sum(_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("sum(A)", hooks) + .toParsedStatement("sum($3:_m($1:_idx(1, nrow(A)), 1, sum($4:_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2)))))", hooks) + .iff(match -> { + RewriterStatement meta = (RewriterStatement) match.getMatchRoot().getOperands().get(0).getMeta("ncol"); + + if (meta == null) + throw new IllegalArgumentException("Column meta should not be null: " + match.getMatchRoot().getOperands().get(0).toString(ctx)); + + return !meta.isLiteral() || ((long)meta.getLiteral()) != 1; + }, true) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // rowSums(A) -> _m($1:_idx(1, nrow(A)), 1, sum(_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rowSums(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), 1, sum($4:_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // colSums(A) -> _m($1:_idx(1, ncol(A)), 1, sum(_m($2:_idx(1, nrow(A)), 1, [](A, $2, $1))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("colSums(A)", hooks) + .toParsedStatement("$3:_m(1, $1:_idx(1, ncol(A)), sum($4:_m($2:_idx(1, nrow(A)), 1, [](A, $2, $1))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("INT:l") + .withParsedStatement("_idx(l, l)", hooks) + .toParsedStatement("l", hooks) + .build() + ); + + // Scalars dependent on matrix to index streams + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("sum(A)", hooks) + .toParsedStatement("sum($3:_idxExpr($1:_idx(1, nrow(A)), $4:_idxExpr($2:_idx(1, ncol(A)), [](A, $1, $2))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // diag(A) -> _m($1:_idx(1, nrow(A)), 1, [](A, $1, $1)) + /*rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("diag(A)", hooks) + .toParsedStatement("$2:_m($1:_idx(1, nrow(A)), 1, [](A, $1, $1))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), aRef.getNCol(), match.getNewExprRoot()); + }, true) + .build() + );*/ + + // cast.MATRIX(a) => _m(1, 1, a) + for (String t : List.of("INT", "BOOL", "FLOAT")) { + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.MATRIX(a)", hooks) + .toParsedStatement("$2:_m(1, 1, a)", hooks) + .apply(hooks.get(2).getId(), (stmt, match) -> stmt.unsafePutMeta("ownerId", UUID.randomUUID()), true) + .build() + ); + } + } + + public static void expandArbitraryMatrices(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + // This must be the last rule in the heuristic as it handles any matrix that has not been written as a stream + // A -> _m() + rules.add(new RewriterRuleBuilder(ctx, "Expand arbitrary matrix expression") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("A", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, $1, $2))", hooks) + .iff(match -> match.getMatchRoot().getMeta("dontExpand") == null && !(match.getMatchRoot().isInstruction() && match.getMatchRoot().trueInstruction().equals("_m")), true) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + RewriterStatement A = stmt.getChild(0, 1, 0); + A.unsafePutMeta("dontExpand", true); + if (A.getNRow().isInstruction() && A.getNRow().trueInstruction().equals("nrow") && A.getNRow().getChild(0) == stmt) + A.getNRow().getOperands().set(0, A); + if (A.getNCol().isInstruction() && A.getNCol().trueInstruction().equals("ncol") && A.getNCol().getChild(0) == stmt) + A.getNCol().getOperands().set(0, A); + }, true) + .build() + ); + } + + // TODO: Big issue when having multiple references to the same sub-dag + public static void pushdownStreamSelections(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // ifelse merging + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b,c,d") + .parseGlobalVars("INT:l1,l2") + .withParsedStatement("$1:ElementWiseInstruction(ifelse(==(l1, l2), a, b), ifelse(==(l1, l2), c, d))", hooks) + .toParsedStatement("ifelse(==(l1, l2), $2:ElementWiseInstruction(a, c), $3:ElementWiseInstruction(b, d))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + + SCALARS.forEach(t -> { + SCALARS.forEach(t2 -> { + // redundant ifelse elimination + rules.add(new RewriterRuleBuilder(ctx, "Remove redundant ifelse") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":c,d,e") + .parseGlobalVars(t + ":a,b") + .withParsedStatement("ifelse(==(a, b), ifelse(==(a, b), c, e), d)", hooks) + .toParsedStatement("ifelse(==(a, b), c, d)", hooks) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Remove redundant ifelse") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":c,d,e") + .parseGlobalVars(t + ":a,b") + .withParsedStatement("ifelse(==(a, b), d, ifelse(==(a, b), c, e))", hooks) + .toParsedStatement("ifelse(==(a, b), d, e)", hooks) + .build() + ); + + // ifelse expression pullup + rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c") + .parseGlobalVars(t2 + ":d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("$1:ElementWiseInstruction(ifelse(b, a, c), d)", hooks) + .toParsedStatement("ifelse(b, $2:ElementWiseInstruction(a, d), $3:ElementWiseInstruction(c, d))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c") + .parseGlobalVars(t2 + ":d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("$1:ElementWiseInstruction(d, ifelse(b, a, c))", hooks) + .toParsedStatement("ifelse(b, $2:ElementWiseInstruction(d, a), $3:ElementWiseInstruction(d, c))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "Ifelse branch merge") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c,d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("ifelse(b, a, a)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "Fold true statement") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_BOOL:TRUE") + .withParsedStatement("==(a,a)", hooks) + .toParsedStatement("TRUE", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "Eliminate unnecessary branches") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b") + .parseGlobalVars("LITERAL_BOOL:TRUE") + .withParsedStatement("ifelse(TRUE, a, b)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Eliminate unnecessary branches") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("ifelse(FALSE, a, b)", hooks) + .toParsedStatement("b", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("INT:l") + .withParsedStatement("_idx(l, l)", hooks) + .toParsedStatement("l", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Eliminate scalar matrices") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("as.scalar(v)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Element selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(h, i, v), l, m)", hooks) + .toParsedStatement("$3:as.scalar($2:_m(l, m, v))", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands().get(0).getOperands(); + return (ops.get(0).isInstruction() + && ops.get(0).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) + || (ops.get(1).isInstruction() + && ops.get(1).trueTypedInstruction(ctx).equals("_idx(INT,INT)")); + }, true) + .linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> { + RewriterStatement.transferMeta(lnk); + + for (int idx = 0; idx < 2; idx++) { + RewriterStatement oldRef = lnk.oldStmt.getChild(idx); + + if (!oldRef.isInstruction() || !oldRef.trueTypedInstruction(ctx).equals("_idx(INT,INT)")) + continue; + + UUID oldRefId = (UUID)oldRef.getMeta("idxId"); + + RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx); + + RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + if (idxId.equals(oldRefId)) + return newRef; + } + + return null; + }); + + if (newOne != null) + lnk.newStmt.get(0).getOperands().set(2, newOne); + } + }, true) + .apply(hooks.get(3).getId(), stmt -> { + stmt.getOperands().set(0, stmt.getChild(0, 2)); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Scalar matrix selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(1, 1, v), j, k)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(h, i, v), j, k, l, m)", hooks) + .toParsedStatement("$2:_m(_idx(1, +(+(k, 1), -(j))), _idx(1, +(+(m, 1), -(l))), v)", hooks) // Assuming that selections are valid + .linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> { + RewriterStatement.transferMeta(lnk); + + for (int idx = 0; idx < 2; idx++) { + RewriterStatement oldRef = lnk.oldStmt.getOperands().get(idx); + RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx); + RewriterStatement mStmtC = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef.getChild(1, 1, 0), RewriterStatement.literal(ctx, -1L)).consolidate(ctx); + RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx); + final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx); + + UUID oldRefId = (UUID)oldRef.getMeta("idxId"); + + RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + if (idxId.equals(oldRefId)) + return newStmt; + } + + return null; + }); + + if (newOne != null) + lnk.newStmt.get(0).getOperands().set(2, newOne); + } + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idx(a,a) => a") + .setUnidirectional(true) + .parseGlobalVars("INT:a") + .withParsedStatement("_idx(a,a)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i::, v) => v") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("_idxExpr(i, v)", hooks) + .toParsedStatement("v", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands(); + + boolean matching = (!ops.get(0).isInstruction() || !ops.get(0).trueInstruction().equals("_idx") || ops.get(0).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")) + && (!ops.get(1).isInstruction() || !ops.get(1).trueInstruction().equals("_idx") || ops.get(1).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")); + + return matching; + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i::, v) => v") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("_idxExpr(i, v)", hooks) + .toParsedStatement("v", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands(); + + boolean matching = (!ops.get(0).isInstruction() || !ops.get(0).trueInstruction().equals("_idx") || ops.get(0).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")) + && (!ops.get(1).isInstruction() || !ops.get(1).trueInstruction().equals("_idx") || ops.get(1).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")); + + return matching; + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i, sum(...)) => sum(_idxExpr(i, ...))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:_idxExpr(i, sum(v))", hooks) + .toParsedStatement("sum($2:_idxExpr(i, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i, sum(...)) => sum(_idxExpr(i, ...))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("$1:_idxExpr(i, sum(v))", hooks) + .toParsedStatement("sum($2:_idxExpr(i, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + RewriterUtils.buildBinaryPermutations(List.of("FLOAT"), (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "*(sum(_idxExpr(i, ...)), sum(_idxExpr(j, ...))) => _idxExpr(i, _idxExpr(j, sum(*(...)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars(t1 + ":v1") + .parseGlobalVars(t2 + ":v2") + .withParsedStatement("$1:*(sum($2:_idxExpr(i, v1)), sum($3:_idxExpr(j, v2)))", hooks) + .toParsedStatement("sum($4:_idxExpr(i, $5:_idxExpr(j, $6:*(v1, v2))))", hooks) + .link(hooks.get(1).getId(), hooks.get(6).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(5).getId(), RewriterStatement::transferMeta) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "sum(sum(v)) => sum(v)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("sum(sum(v))", hooks) + .toParsedStatement("sum(v)", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum(sum(v)) => sum(v)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("sum(sum(v))", hooks) + .toParsedStatement("sum(v)", hooks) + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "sum(v::" + t + ") => v::" + t) + .setUnidirectional(true) + .parseGlobalVars(t + ":v") + .withParsedStatement("sum(v)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "[](UnaryElementWiseOperator(A), i, j) => UnaryElementWiseOperator([](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:UnaryElementWiseOperator(A), i, j)", hooks) + .toParsedStatement("$2:UnaryElementWiseOperator([](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseUnary.FLOAT(A), i, j) => ElementWiseUnary.FLOAT([](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseUnary.FLOAT(A), i, j)", hooks) + .toParsedStatement("$2:ElementWiseUnary.FLOAT([](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + for (String t : ALL_TYPES) { + if (t.equals("MATRIX")) { + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, v), b) => _m(i, j, ElementWiseInstruction(v, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":B") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), B)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, [](B, i, j)))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), (stmt, match) -> { + // Then we an infer that the two matrices have the same dimensions + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNCol(), stmt.getChild(2, 1, 0).getNCol(), match.getNewExprRoot()); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNRow(), stmt.getChild(2, 1, 0).getNRow(), match.getNewExprRoot()); + }, true) + .build() + ); + + continue; + } + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, A), b) => _m(i, j, ElementWiseInstruction(A, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), b)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, b))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(A, v), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(A, v), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction([](A, i, j), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(v, A), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(v, A), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction(v, [](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + } + } + + // This expands the statements to a common canonical form + public static void canonicalExpandAfterFlattening(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(A))) => -(sum($2:_idxExpr(indices, A)))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("INT...:indices") + .withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks) + .toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(a))) => -(sum($2:_idxExpr(indices, a)))") + .setUnidirectional(true) + .parseGlobalVars("INT:a") + .parseGlobalVars("INT...:indices") + .withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks) + .toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum(_idxExpr(indices, +(ops))) => +(argList(sum(_idxExpr(indices, op1)), sum(_idxExpr(...)), ...))") + .setUnidirectional(true) + .parseGlobalVars("INT...:indices") + .parseGlobalVars("FLOAT...:ops") + .withParsedStatement("sum($1:_idxExpr(indices, +(ops)))", hooks) + .toParsedStatement("$4:+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), newArgList -> { + RewriterStatement oldArgList = newArgList.getChild(0, 0, 1, 0); + newArgList.getChild(0, 0).getOperands().set(1, oldArgList.getChild(0)); + + for (int i = 1; i < oldArgList.getOperands().size(); i++) { + RewriterStatement newIdxExpr = newArgList.getChild(0, 0).copyNode(); + newIdxExpr.getOperands().set(1, oldArgList.getChild(i)); + RewriterStatement newSum = new RewriterInstruction() + .as(UUID.randomUUID().toString()) + .withInstruction("sum") + .withOps(newIdxExpr); + RewriterUtils.copyIndexList(newIdxExpr); + newIdxExpr.refreshReturnType(ctx); + newSum.consolidate(ctx); + newArgList.getOperands().add(newSum); + } + + newArgList.refreshReturnType(ctx); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + stmt.refreshReturnType(ctx); + }, true) + .build() + ); + } + + public static void flattenedAlgebraRewrites(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // Minus pushdown + rules.add(new RewriterRuleBuilder(ctx, "-(+(...)) => +(-(el1), -(el2), ...)") + .setUnidirectional(true) + .parseGlobalVars("FLOAT...:ops") + .withParsedStatement("-(+(ops))", hooks) + .toParsedStatement("$1:+(ops)", hooks) // Temporary + .apply(hooks.get(1).getId(), (stmt, match) -> { + RewriterStatement argList = stmt.getChild(0); + + for (int i = 0; i < argList.getOperands().size(); i++) { + RewriterInstruction newStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(argList.getOperands().get(i)); + newStmt.consolidate(ctx); + argList.getOperands().set(i, newStmt); + } + + RewriterUtils.tryFlattenNestedOperatorPatterns(ctx, match.getNewExprRoot()); + }, true) + .build() + ); + } + + public static List buildElementWiseAlgebraicCanonicalization(final List rules, final RuleContext ctx) { + RewriterUtils.buildTernaryPermutations(List.of("FLOAT", "INT", "BOOL"), (t1, t2, t3) -> { + rules.add(new RewriterRuleBuilder(ctx, "*(+(a, b), c) => +(*(a, c), *(b, c))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars(t3 + ":c") + .withParsedStatement("*(+(a, b), c)") + .toParsedStatement("+(*(a, c), *(b, c))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "*(c, +(a, b)) => +(*(c, a), *(c, b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars(t3 + ":c") + .withParsedStatement("*(c, +(a, b))") + .toParsedStatement("+(*(c, a), *(c, b))") + .build() + ); + }); + + /*List.of("FLOAT", "INT").forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_" + t + ":-1") + .withParsedStatement("-(a)") + .toParsedStatement("*(-1, a)") + .build() + ); + });*/ + + return rules; + } + + public static List replaceNegation(final List rules, final RuleContext ctx) { + List.of("FLOAT", "INT").forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_" + t + ":-1") + .withParsedStatement("-(a)") + .toParsedStatement("*(-1, a)") + .build() + ); + }); + + return rules; + } + + @Deprecated + public static void streamifyExpressions(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + ALL_TYPES.forEach(t -> { + if (t.equals("MATRIX")) + return; + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:ElementWiseInstruction($3:_m(i, j, v), b)", hooks) + .toParsedStatement("$4:_m(i, j, $2:ElementWiseInstruction(v, b))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:ElementWiseInstruction(b, $3:_m(i, j, v))", hooks) + .toParsedStatement("$4:_m(i, j, $2:ElementWiseInstruction(b, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .build()); + }); + + + } + + public static void flattenOperations(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(List.of("INT", "INT..."), (t1, t2) -> { + for (String t3 : List.of("FLOAT", "FLOAT*", "INT", "INT*", "BOOL", "BOOL*")) { + rules.add(new RewriterRuleBuilder(ctx, "Flatten nested index expression") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":i") + .parseGlobalVars(t2 + ":j") + .parseGlobalVars(t3 + ":v") + .withParsedStatement("$1:_idxExpr(i, $2:_idxExpr(j, v))", hooks) + .toParsedStatement("$3:_idxExpr(argList(i, j), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), (stmt, match) -> { + UUID newOwnerId = (UUID) stmt.getMeta("ownerId"); + + if (newOwnerId == null) + throw new IllegalArgumentException(); + + if (!stmt.getChild(0, 1).isLiteral()) + stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId); + }, true) + .build()); + + if (t1.equals("INT")) { + // This must be executed after the rule above + rules.add(new RewriterRuleBuilder(ctx, "Flatten nested index expression") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":i") + .parseGlobalVars(t3 + ":v") + .withParsedStatement("$1:_idxExpr(i, v)", hooks) + .toParsedStatement("$3:_idxExpr(argList(i), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + } + } + }); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator(A,B)", hooks) + .toParsedStatement("$2:FusedOperator(argList(A,B))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + "...:A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator($2:FusedOperator(A), B)", hooks) + .toParsedStatement("$3:FusedOperator(argList(A, B))", hooks) + .iff(match -> { + return match.getMatchRoot().trueInstruction().equals(match.getMatchRoot().getOperands().get(0).trueInstruction()); + }, true) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + "...:A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator(B, $2:FusedOperator(A))", hooks) + .toParsedStatement("$3:FusedOperator(argList(B, A))", hooks) + .iff(match -> { + return match.getMatchRoot().trueInstruction().equals(match.getMatchRoot().getOperands().get(0).trueInstruction()); + }, true) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + }); + + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java new file mode 100644 index 00000000000..28fc7bf6028 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java @@ -0,0 +1,537 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.dml.DMLCodeGenerator; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterRuleCreator { + + private RuleContext ctx; + private RewriterRuleSet ruleSet; + private List activeRules; + + public RewriterRuleCreator(final RuleContext ctx) { + this.ctx = ctx; + activeRules = Collections.synchronizedList(new LinkedList<>()); + ruleSet = new RewriterRuleSet(ctx, activeRules); + } + + public synchronized void forEachRule(Consumer consumer) { + activeRules.forEach(consumer); + } + + public boolean registerRule(RewriterRule rule, Function canonicalFormConverter, final RuleContext ctx) { + try { + return registerRule(rule, RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx), RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx), false, canonicalFormConverter); + } catch (Exception e) { + System.err.println("Error while registering a rule: " + rule); + e.printStackTrace(); + return false; + } + } + + public synchronized boolean registerRule(RewriterRule rule, long preCost, long postCost, boolean validateCorrectness, Function canonicalFormCreator) { + // First, we check if an existing rule already applies an equivalent rewrite (cost wise) + RewriterStatement toTest = rule.getStmt1().nestedCopy(false); + + RewriterStatement newStmt = rule.getStmt2().nestedCopy(false); + + boolean converged = false; + boolean changed = false; + + List appliedRules = new ArrayList<>(); + + for (int i = 0; i < 500; i++) { + RewriterRuleSet.ApplicableRule applicableRule = ruleSet.acceleratedFindFirst(newStmt, true); + + if (applicableRule == null) { + converged = true; + break; // Then we converged + } + + newStmt = applicableRule.rule.apply(applicableRule.matches.get(0), newStmt, applicableRule.forward, false); + RewriterUtils.mergeArgLists(newStmt, ctx); + newStmt = RewriterUtils.foldConstants(newStmt, ctx); + appliedRules.add(applicableRule.rule); + changed = true; + } + + if (!converged) + throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true) + "\n" + String.join("\n", appliedRules.subList(appliedRules.size()-5, appliedRules.size()).stream().map(rl -> rl.toParsableString(ctx)).collect(Collectors.toList()))); + + appliedRules.clear(); + + for (int i = 0; i < 500; i++) { + RewriterRuleSet.ApplicableRule applicableRule = ruleSet.acceleratedFindFirst(toTest, true); + + if (applicableRule == null) { + converged = true; + break; // Then we converged + } + + toTest = applicableRule.rule.apply(applicableRule.matches.get(0), toTest, applicableRule.forward, false); + + RewriterUtils.mergeArgLists(toTest, ctx); + toTest = RewriterUtils.foldConstants(toTest, ctx); + appliedRules.add(applicableRule.rule); + changed = true; + } + + if (!converged) + throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true) + "\n" + String.join("\n", appliedRules.stream().map(rl -> rl.toParsableString(ctx)).collect(Collectors.toList()))); + + if (newStmt != rule.getStmt2()) { + // Then the mapping has changed, and we need to + try { + postCost = RewriterCostEstimator.estimateCost(newStmt, ctx); + } catch (Exception e) { + System.err.println("Err in cost from orig: " + rule.getStmt2().toParsableString(ctx)); + System.err.println("NewStmt: " + newStmt.toParsableString(ctx)); + e.printStackTrace(); + return false; + } + } + + if (changed) { + long existingPostCost; + + try { + existingPostCost = RewriterCostEstimator.estimateCost(toTest, ctx); + } catch (Exception e) { + System.err.println("Err in cost from orig: " + rule.getStmt1().toParsableString(ctx)); + System.err.println("ToTest: " + toTest.toParsableString(ctx)); + System.err.println("AppliedRules: " + appliedRules); + e.printStackTrace(); + return false; + } + + if (existingPostCost <= postCost || preCost >= postCost) + return false; // Then this rule is not beneficial + } + + // We might have to rebuild the rule + if (changed || newStmt != rule.getStmt2()) { + try { + rule = createRule(toTest, newStmt, canonicalFormCreator.apply(toTest), canonicalFormCreator.apply(newStmt), ctx); + } catch (Exception e) { + System.err.println("Failed to create: " + toTest.toParsableString(ctx) + " => " + newStmt.toParsableString(ctx)); + } + } + + + if (validateCorrectness) { + // Now, we validate the rule by executing it in the system + if (!validateRuleCorrectnessAndGains(rule, ctx)) + return false; // Then, either the rule is incorrect or is already implemented + } + + //System.out.println("Rule is correct!"); + + RewriterRuleSet probingSet = new RewriterRuleSet(ctx, List.of(rule)); + List rulesToRemove = new ArrayList<>(); + List rulesThatMustComeBefore = new ArrayList<>(); + + // Check for interactions between different rules + for (RewriterRule existingRule : activeRules) { + RewriterStatement mProbe = existingRule.getStmt1(); + RewriterRuleSet.ApplicableRule applicableRule = probingSet.acceleratedFindFirst(mProbe); + + if (applicableRule != null) { + // Then we have to take a deeper look into the interaction between the rules + // Either the new rule achieves a better result -> the old rule can be eliminated + // Or the new rule finds a worse rewrite for the existing rule -> Then the existing rule must be kept and be applied before the new rule + mProbe = mProbe.nestedCopy(true); + + for (int i = 0; i < 20; i++) { + applicableRule = probingSet.acceleratedFindFirst(mProbe); + + if (i == 19) + throw new IllegalArgumentException("The following rule created a conflict with another rule:\nNew one:\n" + rule + "\t[Cost: " + preCost + " => " + postCost + "]\nExisting:\n" + existingRule + "\t[Cost: " + existingRule.getStmt1().getCost(ctx) + " => " + existingRule.getStmt2().getCost(ctx) + "]"); + if (applicableRule != null) + mProbe = applicableRule.rule.apply(applicableRule.matches.get(0), mProbe, applicableRule.forward, false); + else + break; + } + + long newCost = mProbe.getCost(ctx); + long existingRuleNewCost = existingRule.getStmt2().getCost(ctx); + + if (newCost == -1 || existingRuleNewCost == -1) + throw new IllegalArgumentException("The rule set or the new rule resulted in an invalid cost:\nNew one:\n" + rule + "\nExisting:\n" + existingRule); + + if (newCost <= existingRuleNewCost) { + // Then we remove the old rule + rulesToRemove.add(existingRule); + } else { + // Then the existing rule is still legitimate and must come before the new rule as it is more specific + rulesThatMustComeBefore.add(existingRule); + } + } + } + + // Check if rule is expansive (e.g. expands itself leading to an infinite loop) + RewriterRuleSet testSet = new RewriterRuleSet(ctx, List.of(rule)); + testSet.accelerate(); + RewriterStatement mProbe = rule.getStmt2(); + if (testSet.acceleratedFindFirst(mProbe) != null) + throw new IllegalArgumentException("Expansive rule detected!"); + + + activeRules.removeAll(rulesToRemove); + + // Now, we include the rule to the system + // TODO: Further checks are needed, especially if the new heuristic converges in all cases + activeRules.add(rule); + + ruleSet.accelerate(); + + return true; + } + + public RewriterRuleSet getRuleSet() { + return ruleSet; + } + + public void throwOutInvalidRules(boolean correctness, boolean relevance) { + if (!correctness && !relevance) + return; + + activeRules.removeIf(rule -> (correctness && !validateRuleCorrectness(rule, ctx)) || (relevance && !validateRuleApplicability(rule, ctx))); + ruleSet.accelerate(); + } + + + + + + + ///// STATIC METHODS ///// + + // This runs the rule from expressions + public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final RuleContext ctx) { + return validateRuleCorrectness(rule, ctx) && validateRuleApplicability(rule, ctx); + } + + public static boolean validateRuleCorrectness(RewriterRule rule, final RuleContext ctx) { + RewriterUtils.renameIllegalVarnames(ctx, rule.getStmt1(), rule.getStmt2()); + String sessionId = UUID.randomUUID().toString(); + String code = DMLCodeGenerator.generateRuleValidationDML(rule, sessionId, ctx); + + MutableBoolean isValid = new MutableBoolean(false); + boolean successful = DMLExecutor.executeCode(code, DMLCodeGenerator.ruleValidationScript(rule.toParsableString(ctx), sessionId, isValid::setValue)); + + if (!isValid.booleanValue()) { + String errStr = "An invalid rule was found: " + rule + "\n\tReason: " + (successful ? "Assertion" : "Error"); + + if (!successful && !DMLExecutor.getLastErr().isEmpty()) + errStr += " (" + DMLExecutor.getLastErr().get(0) + ")"; + + DMLExecutor.println(errStr); + } + + return isValid.booleanValue(); + } + + public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx) { + return validateRuleApplicability(rule, ctx, false, null); + } + + public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx, boolean print, @Nullable Function injectedRewriteClass) { + RewriterStatement _mstmt = rule.getStmt1(); + RewriterStatement _mstmt2 = rule.getStmt2(); + if (ctx.metaPropagator != null) { + ctx.metaPropagator.apply(_mstmt); + ctx.metaPropagator.apply(_mstmt2); + } + + final RewriterStatement stmt1 = RewriterUtils.unfuseOperators(_mstmt, ctx); + + Set vars = DMLCodeGenerator.getVariables(stmt1); + Set varNames = vars.stream().map(RewriterStatement::getId).collect(Collectors.toSet()); + String code2Header = DMLCodeGenerator.generateDMLVariables(vars); + String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(stmt1); + + boolean isMatrix = stmt1.getResultingDataType(ctx).equals("MATRIX"); + + if (isMatrix) + code2 += "\nprint(lineage(result))"; + else + code2 += "\nprint(lineage(as.matrix(result)))"; + + MutableBoolean isRelevant = new MutableBoolean(false); + + final RewriterStatement expectedStmt = injectedRewriteClass != null ? _mstmt2 : _mstmt; + + RewriterRuntimeUtils.attachHopInterceptor(prog -> { + Hop hop; + + if (isMatrix) + hop = prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0); + else + hop = prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0).getInput(0); + + RewriterStatement stmt = RewriterRuntimeUtils.buildDAGFromHop(hop, 1000, true, ctx); + + if (stmt == null) + return false; + + Map nameAssocs = new HashMap<>(); + // Find the variables that are actually leafs in the original rule + stmt.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (varNames.contains(child.getId())) { + RewriterStatement assoc = nameAssocs.get(child.getId()); + + if (assoc == null) { + assoc = new RewriterDataType().as(child.getId()).ofType(child.getResultingDataType(ctx)).consolidate(ctx); + + Long ncol = (Long) child.getMeta("_actualNCol"); + Long nrow = (Long) child.getMeta("_actualNRow"); + + if (ncol != null) + assoc.unsafePutMeta("_actualNCol", ncol); + + if (nrow != null) + assoc.unsafePutMeta("_actualNRow", nrow); + + nameAssocs.put(child.getId(), assoc); + } + + cur.getOperands().set(i, assoc); + } + } + + return true; + }, false); + + stmt = RewriterRuntimeUtils.populateDataCharacteristics(stmt, ctx); + stmt = ctx.metaPropagator.apply(stmt); + + stmt = stmt.nestedCopyOrInject(new HashMap<>(), mstmt -> { + if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow"))) + return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS); + return null; + }); + + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + + Map createdObjects = new HashMap<>(); + + RewriterStatement stmt1ReplaceNCols = expectedStmt.nestedCopyOrInject(createdObjects, mstmt -> { + if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow"))) + return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS); + return null; + }); + + stmt1ReplaceNCols.prepareForHashing(); + stmt1ReplaceNCols.recomputeHashCodes(ctx); + + Set mVars = vars.stream().map(createdObjects::get).filter(Objects::nonNull).collect(Collectors.toSet()); + + if (print) { + DMLExecutor.println("Observed statement: " + stmt.toParsableString(ctx)); + DMLExecutor.println("Expected statement: " + stmt1ReplaceNCols.toParsableString(ctx)); + } + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, stmt, stmt1ReplaceNCols); + if (stmt1ReplaceNCols.match(mCtx)) { + // Check if also the right variables are associated + boolean assocsMatching = true; + if (mCtx.getDependencyMap() != null) { + for (RewriterStatement var : mVars) { + RewriterStatement assoc = mCtx.getDependencyMap().get(var.isInstruction() && !var.trueInstruction().equals("const") ? var.getChild(0) : var); + + if (assoc == null) + throw new IllegalArgumentException("Association is null!"); + + if (!assoc.getId().equals(var.getId())) { + assocsMatching = false; + break; + } + } + } + + if (assocsMatching) { + // Then the rule matches, meaning that the statement is not rewritten by SystemDS + isRelevant.setValue(true); + } + } + + // TODO: Maybe we can still rewrite the new graph if it still has less cost + + // TODO: Evaluate cost and if our rule can still be applied + return injectedRewriteClass != null; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement + }); + + MutableBoolean wasApplied = new MutableBoolean(true); + + if (injectedRewriteClass != null) { + String ruleStr = rule.toString(); + wasApplied.setValue(false); + DMLExecutor.executeCode(code2, s -> { + if (s.equals("Applying rewrite: " + ruleStr)) { + wasApplied.setValue(true); + } + }, injectedRewriteClass); + } else { + DMLExecutor.executeCode(code2, true); + } + + RewriterRuntimeUtils.detachHopInterceptor(); + + return isRelevant.booleanValue() && wasApplied.booleanValue(); + } + + public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { + Tuple2 commonForm = createCommonForm(from, to, canonicalForm1, canonicalForm2, ctx); + from = commonForm._1; + to = commonForm._2; + + return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); + } + + public static RewriterRule createRuleFromCommonStatements(RewriterStatement from, RewriterStatement to, final RuleContext ctx) { + return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); + } + + public static RewriterRule createConditionalRuleFromCommonStatements(RewriterStatement from, List to, final RuleContext ctx) { + return new RewriterRuleBuilder(ctx, "Autogenerated conditional rule").setUnidirectional(true).completeConditionalRule(from, to).build(); + } + + public static Tuple2 createCommonForm(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { + from = from.nestedCopy(true); + Map assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx); + // Now, we replace all variables with a common element + from.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child instanceof RewriterDataType && !child.isLiteral()) { + RewriterStatement newRef = assocs.get(child); + + if (newRef != null) + cur.getOperands().set(i, newRef); + } + } + + return true; + }, false); + + from = ctx.metaPropagator.apply(from); + return new Tuple2<>(from, to); + } + + private static Map getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, RewriterStatement canonicalFormTo, final RuleContext ctx) { + Map fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true, ctx); + Map toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true, ctx); + + RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo, canonicalFormFrom); + canonicalFormFrom.match(matcher); + + Map assocs = new HashMap<>(); + matcher.getDependencyMap().forEach((k, v) -> { + if (k.isLiteral()) + return; + + RewriterStatement newKey = fromCanonicalLink.get(k); + RewriterStatement newValue = toCanonicalLink.get(v); + + if (newKey == null || newValue == null) + return; + + assocs.put(newKey, newValue); + }); + + return assocs; + } + + private static Random rd = new Random(); + private static Map getAssociationToCanonicalForm(RewriterStatement stmt, RewriterStatement canonicalForm, boolean reversed, final RuleContext ctx) { + // We identify all associations by their names + // If there are name collisions, this does not work + Map namedVariables = new HashMap<>(); + stmt.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + if (namedVariables.put(cur.getId(), cur) != null) + throw new IllegalArgumentException("Duplicate variable name: " + cur.toParsableString(RuleContext.currentContext) + "\nEntire statement:\n" + stmt.toParsableString(ctx) + "\nRaw: " + stmt); + }, false); + + Map assoc = new DualHashBidiMap<>(); + + canonicalForm.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + RewriterStatement ref = namedVariables.get(cur.getId()); + + if (ref == null) { + assoc.put(ref, ref); + } + + if (reversed) + assoc.put(cur, ref); + else + assoc.put(ref, cur); + }, false); + + namedVariables.values().forEach(ref -> { + if (reversed) { + if (!assoc.containsValue(ref)) + ref.rename("u_" + rd.nextInt(100000)); + } else { + if (!assoc.containsKey(ref)) + ref.rename("u_" + rd.nextInt(100000)); + } + }); + + return assoc; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java new file mode 100644 index 00000000000..468f7fd2ad2 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.codegen.RewriterCodeGen; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterRuleSet { + + public static class ApplicableRule { + public final ArrayList matches; + public final RewriterRule rule; + public final boolean forward; + + public ApplicableRule(ArrayList matches, RewriterRule rule, boolean forward) { + this.matches = matches; + this.rule = rule; + this.forward = forward; + } + + public String toString(final RuleContext ctx) { + StringBuilder builder = new StringBuilder(); + builder.append("Rule: " + rule + "\n\n"); + int ctr = 1; + for (RewriterStatement.MatchingSubexpression match : matches) { + builder.append("Match " + ctr++ + ": \n"); + builder.append(" " + match.getMatchRoot() + " = " + (forward ? rule.getStmt1() : rule.getStmt2()) + "\n\n"); + for (Map.Entry entry : match.getAssocs().entrySet()) { + builder.append(" - " + entry.getKey() + "::" + (ctx == null ? "?" : entry.getKey().getResultingDataType(ctx)) + " -> " + entry.getValue().getId() + "::" + (ctx == null ? "?" : entry.getValue().getResultingDataType(ctx)) + "\n"); + } + builder.append("\n"); + } + + return builder.toString(); + } + + @Override + public String toString() { + return toString(null); + } + } + + private RuleContext ctx; + private List rules; + private Map>> accelerator; + + public RewriterRuleSet(RuleContext ctx, List rules) { + this.ctx = ctx; + this.rules = rules; + accelerate(); + } + + public RuleContext getContext() { + return ctx; + } + + public void determineConditionalApplicability() { + rules.forEach(RewriterRule::determineConditionalApplicability); + } + + public void forEachRule(BiConsumer consumer) { + rules.forEach(r -> consumer.accept(r, ctx)); + } + + public List getRules() { + return rules; + } + + public ApplicableRule acceleratedFindFirst(RewriterStatement root) { + return acceleratedFindFirst(root, false); + } + + public ApplicableRule acceleratedFindFirst(RewriterStatement root, boolean allowImplicitTypeConversions) { + List match = acceleratedRecursiveMatch(root, true, allowImplicitTypeConversions); + if (match.isEmpty()) + return null; + else + return match.get(0); + } + + public List acceleratedRecursiveMatch(RewriterStatement root, boolean findFirst, boolean allowImplicitTypeConversions) { + List> matches = new ArrayList<>(); + MutableObject> dependencyMap = new MutableObject<>(new HashMap<>()); + MutableObject> links = new MutableObject<>(new ArrayList<>()); + MutableObject> linkObjects = new MutableObject<>(new HashMap<>()); + + root.forEachPreOrder((el, pred) -> { + String typedStr = el.isInstruction() ? el.trueTypedInstruction(allowImplicitTypeConversions, ctx) : RewriterUtils.convertImplicitly(el.getResultingDataType(ctx), allowImplicitTypeConversions); + Set props = el instanceof RewriterInstruction ? ((RewriterInstruction)el).getProperties(ctx) : Collections.emptySet(); + boolean found = acceleratedMatch(root, el, matches, typedStr, RewriterUtils.convertImplicitly(el.getResultingDataType(ctx), allowImplicitTypeConversions), props, pred, dependencyMap, links, linkObjects, findFirst, allowImplicitTypeConversions); + return !findFirst || !found; + }, true); + + Map, ApplicableRule> uniqueRules = new HashMap<>(); + + for (Tuple3 match : matches) { + Tuple2 t = new Tuple2<>(match._1(), match._2()); + + if (uniqueRules.containsKey(t)) + uniqueRules.get(t).matches.add(match._3()); + else { + ArrayList list = new ArrayList<>(); + list.add(match._3()); + uniqueRules.put(t, new ApplicableRule(list, match._1(), match._2())); + } + } + + return new ArrayList<>(uniqueRules.values()); + } + + public boolean acceleratedMatch(RewriterStatement exprRoot, RewriterStatement stmt, List> appRules, String realTypedInstr, String realType, Set properties, RewriterStatement.RewriterPredecessor pred, MutableObject> dependencyMap, MutableObject> links, MutableObject> linkObjects, boolean findFirst, boolean allowImplicitTypeConversions) { + List> potentialMatches; + boolean foundMatch = false; + + if (realTypedInstr != null) { + potentialMatches = accelerator.get(realTypedInstr); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + } + + potentialMatches = accelerator.get(realType); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + + if (properties != null) { + for (String props : properties) { + potentialMatches = accelerator.get(props); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + } + } + + return foundMatch; + } + + private boolean checkPotentialMatches(RewriterStatement stmt, List> potentialMatches, List> appRules, RewriterStatement.RewriterPredecessor pred, MutableObject> dependencyMap, MutableObject> links, MutableObject> linkObjects, RewriterStatement exprRoot, boolean findFirst, boolean allowImplicitTypeConversions) { + boolean anyMatch = false; + for (Tuple2 m : potentialMatches) { + RewriterStatement.MatchingSubexpression match; + + if (m._2()) { + match = m._1().matchSingleStmt1(exprRoot, pred, stmt, allowImplicitTypeConversions); + } else { + match = m._1().matchSingleStmt2(exprRoot, pred, stmt, allowImplicitTypeConversions); + } + + if (match != null) { + appRules.add(new Tuple3<>(m._1(), m._2(), match)); + dependencyMap.setValue(new HashMap<>()); + links.setValue(new ArrayList<>()); + linkObjects.setValue(new HashMap<>()); + + if (findFirst) + return true; + + anyMatch = true; + } else { + dependencyMap.getValue().clear(); + links.getValue().clear(); + linkObjects.getValue().clear(); + } + } + + return anyMatch; + } + + // Look for intersecting roots and try to find them once + public void accelerate() { + accelerator = new HashMap<>(); + for (RewriterRule rule : rules) { + accelerate(rule, true); + if (!rule.isUnidirectional()) + accelerate(rule, false); + } + } + + private void accelerate(RewriterRule rule, boolean forward) { + RewriterStatement stmt = forward ? rule.getStmt1() : rule.getStmt2(); + String t = stmt.isInstruction() ? stmt.trueTypedInstruction(ctx) : stmt.getResultingDataType(ctx); + List> l = accelerator.get(t); + + if (l == null) { + l = new ArrayList<>(); + accelerator.put(t, l); + } + + l.add(new Tuple2<>(rule, forward)); + } + + @Override + public String toString() { + return serialize(); + } + + public String serialize() { + StringBuilder sb = new StringBuilder(); + + for (RewriterRule rule : rules) { + try { + sb.append("::RULE\n"); + sb.append(rule.toParsableString(ctx)); + sb.append("\n\n"); + } catch (Exception e) { + e.printStackTrace(); + } + } + + return sb.toString(); + } + + public Set generateCodeAndTest(boolean optimize, boolean print) { + String javaCode = toJavaCode("MGeneratedRewriteClass", optimize, false, true, true); + Function f = RewriterCodeGen.compile(javaCode, "MGeneratedRewriteClass"); + + if (f == null) + return null; // Then, the code could not compile + + Set removed = new HashSet<>(); + + for (int i = 0; i < rules.size(); i++) { + if (!RewriterRuleCreator.validateRuleApplicability(rules.get(i), ctx, print, f)) { + System.out.println("Faulty rule: " + rules.get(i)); + removed.add(rules.get(i)); + } + } + + return removed; + } + + public static RewriterRuleSet deserialize(String data, final RuleContext ctx) { + return deserialize(data.split("\n"), ctx); + } + + public static RewriterRuleSet deserialize(List data, final RuleContext ctx) { + return deserialize(data.toArray(String[]::new), ctx); + } + + public static RewriterRuleSet deserialize(String[] data, final RuleContext ctx) { + List currentLines = new ArrayList<>(); + List rules = new ArrayList<>(); + + for (int i = 0; i < data.length; i++) { + if (data[i].equals("::RULE")) { + if (!currentLines.isEmpty()) { + try { + rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + } catch (Exception e) { + System.err.println("An error occurred while parsing the rule:\n" + String.join("\n", currentLines)); + e.printStackTrace(); + } + currentLines.clear(); + } + } else { + currentLines.add(data[i]); + } + } + + if (!currentLines.isEmpty()) { + rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + currentLines.clear(); + } + + for (RewriterRule rule : rules) { + try { + rule.determineConditionalApplicability(); + } catch (Exception e) { + System.err.println("Error while determining the conditional ability of " + rule.toString()); + e.printStackTrace(); + } + } + + return new RewriterRuleSet(ctx, rules); + } + + public String toJavaCode(String className, boolean optimize, boolean includePackageInfo, boolean printErrors, boolean maintainStatistics) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, 2, includePackageInfo, ctx, true, printErrors, maintainStatistics); + } + + public String toJavaCode(String className, boolean optimize) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, 2, true, ctx, true, true, false); + } + + public String toJavaCode(String className, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, boolean printErrors, boolean maintainStatistics) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, maxOptimizationDepth, includePackageInfo, ctx, true, printErrors, maintainStatistics); + } + + public Function compile(String className, boolean printErrors) { + try { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.compileRewrites(className, mRules, ctx, true, printErrors); + } catch (Exception e) { + if (printErrors) + e.printStackTrace(); + + return null; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java new file mode 100644 index 00000000000..5351cecdd68 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class CodeGenUtils { + // Function to access child statement (which are not neccessarily through .getInput(n)) + public static String getChildAccessor(String parentVar, RewriterStatement stmt, int childIdx) { + switch (stmt.trueInstruction()) { + case "const": + if (childIdx != 1) + return null; + + if (stmt.getChild(1).isLiteral() && Math.abs(stmt.getChild(1).floatLiteral()) == 0.0) + return "new LiteralOp(0.0D)"; // as this might be nnz = 0 and not DataGenOp + return "((DataGenOp)" + parentVar + ").getConstantValue()"; + } + + return parentVar + ".getInput(" + childIdx + ")"; + } + + public static String getSpecialOpCheck(RewriterStatement stmt, final RuleContext ctx, String hopVar) { + if (!stmt.isInstruction()) + return null; + switch (stmt.trueInstruction()) { + case "%*%": + return "HopRewriteUtils.isMatrixMultiply(" + hopVar + ")"; + case "const": + if (stmt.getChild(1).isLiteral()) { + if (Math.abs(stmt.getChild(1).floatLiteral()) == 0.0) // Then this also holds for nnz=0 + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ", " + stmt.getChild(1).floatLiteral() + ") || " + hopVar + ".getNnz() == 0"; + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ", " + stmt.getChild(1).floatLiteral() + ")"; + } else + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ")"; + } + + return null; + } + + public static String getAdditionalCheck(RewriterStatement stmt, final RuleContext ctx, String hopVar) { + if (!stmt.isInstruction()) + return null; + + switch (stmt.trueInstruction()) { + case "rowSums": + return hopVar + ".getDirection() == Types.Direction.Row"; + case "colSums": + return hopVar + ".getDirection() == Types.Direction.Col"; + case "sum": + return hopVar + ".getDirection() == Types.Direction.RowCol"; + } + + return null; + } + + public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) { + if (stmt.getOperands().size() == 1) { + // Handle unary ops + switch (stmt.trueInstruction()) { + case "t": + return "Types.ReOrgOp.TRANS"; + case "rev": + return "Types.ReOrgOp.REV"; + case "!": + return "Types.OpOp1.NOT"; + case "sqrt": + return "Types.OpOp1.SQRT"; + //case "sq": + // return "Types.OpOp1.POW2"; // POW2 does not seem to work in all cases when applying the rewrite (e.g., LinearLogRegTest) + case "log": + return "Types.OpOp1.LOG"; + case "log_nz": + return "Types.OpOp1.LOG_NZ"; + case "abs": + return "Types.OpOp1.ABS"; + case "round": + return "Types.OpOp1.ROUND"; + case "exp": + return "Types.OpOp1.EXP"; + case "rowSums": + case "colSums": + case "sum": + return "Types.AggOp.SUM"; + case "sumSq": + return "Types.AggOp.SUM_SQ"; + case "trace": + return "Types.AggOp.TRACE"; + case "*2": + return "Types.OpOp1.MULT2"; + case "cast.MATRIX": + return "Types.OpOp1.CAST_AS_MATRIX"; + case "cast.FLOAT": + return "Types.OpOp1.CAST_AS_SCALAR"; + case "const": + return "Types.OpOpDG.RAND"; + case "nrow": + return "Types.OpOp1.NROW"; + case "ncol": + return "Types.OpOp1.NCOL"; + case "length": + return "Types.OpOp1.LENGTH"; + } + } else if (stmt.getOperands().size() == 2) { + switch (stmt.trueInstruction()) { + case "+": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.PLUS"; + case "-": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MINUS"; + case "*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MULT"; + case "/": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.DIV"; + case "min": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MIN"; + case "max": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MAX"; + case "!=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.NOTEQUAL"; + case "==": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.EQUAL"; + case ">": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.GREATER"; + case ">=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.GREATEREQUAL"; + case "<": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.LESS"; + case "<=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.LESSEQUAL"; + case "&": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.AND"; + case "|": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.OR"; + case "^": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.POW"; + + case "RBind": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.RBIND"; + case "CBind": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.CBIND"; + case "1-*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MINUS1_MULT"; + case "log_nz": + if (stmt.getOperands().size() != 1) + throw new IllegalArgumentException(); + + return "Types.OpOp1.LOG_NZ"; + + case "%*%": + return "true"; // This should be resolved by the custom handler function + } + } else { + switch (stmt.trueInstruction()) { + case "+*": + if (stmt.getOperands().size() != 3) + throw new IllegalArgumentException(); + + return "Types.OpOp3.PLUS_MULT"; + case "-*": + if (stmt.getOperands().size() != 3) + throw new IllegalArgumentException(); + + return "Types.OpOp3.MINUS_MULT"; + case "literal.FLOAT": + return null; // There is no opcheck on literals + } + } + + throw new NotImplementedException(stmt.trueInstruction()); + } + + /** + * + * @param stmt the statement + * @param ctx the context + * @return a list of operand indices that must be matched + */ + public static List matchingDimRequirement(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.trueInstruction()) { + case "1-*": + return List.of(0, 1); + case "+*": + case "-*": + return List.of(0, 2); + default: + return Collections.emptyList(); + } + } + + public static boolean opRequiresBinaryBroadcastingMatch(RewriterStatement stmt, final RuleContext ctx) { + return getOpClass(stmt, ctx).equals("BinaryOp") && stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX") && stmt.getChild(1).getResultingDataType(ctx).equals("MATRIX"); + } + + public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.trueInstruction()) { + case "!": + case "sqrt": + case "log": + case "log_nz": + case "abs": + case "round": + case "*2": + case "cast.MATRIX": + case "cast.FLOAT": + case "nrow": + case "ncol": + case "length": + //case "sq": // SQ does not appear to work in some cases + case "exp": + return "UnaryOp"; + + case "rowSums": + case "colSums": + case "sum": + case "sumSq": + case "trace": + return "AggUnaryOp"; + + case "+": + case "-": + case "*": + case "/": + case "min": + case "max": + case "!=": + case "==": + case ">": + case ">=": + case "<": + case "<=": + case "&": + case "|": + case "^": + case "RBind": + case "CBind": + case "1-*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "BinaryOp"; + + case "%*%": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "AggBinaryOp"; + + case "t": + case "rev": + return "ReorgOp"; + + case "+*": + case "-*": + return "TernaryOp"; + + case "const": + return "DataGenOp"; + + case "literal.FLOAT": + case "literal.INT": + case "literal.BOOL": + return "LiteralOp"; + } + + throw new NotImplementedException(stmt.trueTypedInstruction(ctx)); + } + + public static String[] getReturnType(RewriterStatement stmt, final RuleContext ctx) { + return getReturnType(stmt.getResultingDataType(ctx)); + } + + public static String[] getReturnType(String typeStr) { + switch (typeStr) { + case "FLOAT": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.FP64", "Types.ValueType.FP32" }; + case "INT": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.INT64", "Types.ValueType.INT32" }; + case "BOOL": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.BOOLEAN" }; + case "MATRIX": + return new String[] { "Types.DataType.MATRIX" }; + } + + throw new NotImplementedException(typeStr); + } + + public static String literalGetterFunction(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.getResultingDataType(ctx)) { + case "INT": + return "getLongValue()"; + case "FLOAT": + return "getDoubleValue()"; + case "BOOL": + return "getBooleanValue()"; + } + + throw new IllegalArgumentException(); + } + + public static String getHopConstructor(RewriterStatement cur, RewriterAssertions assertions, Map varNameMapping, final RuleContext ctx, String... children) { + String opClass = getOpClass(cur, ctx); + String opCode = null; + + for (int i = 0; i < children.length; i++) + if (children[i] == null) + throw new IllegalArgumentException("The argument " + i + " is null: " + cur.toParsableString(ctx)); + + // Special instructions + switch (cur.trueInstruction()) { + case "%*%": + if (children.length != 2) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createMatrixMultiply(" + children[0] + ", " + children[1] + ")"; + + case "t": + if (children.length != 1) + throw new IllegalArgumentException(); + return "HopRewriteUtils.createTranspose(" + children[0] + ")"; + + case "rev": + if (children.length != 1) + throw new IllegalArgumentException(); + return "HopRewriteUtils.createReorg(" + children[0] + ", Types.ReOrgOp.REV)"; + + case "rowSums": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.Row)"; + + case "colSums": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.Col)"; + + case "sum": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.RowCol)"; + + case "sumSq": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM_SQ, Types.Direction.RowCol)"; + case "trace": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.TRACE, Types.Direction.RowCol)"; + + case "ncol": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createUnary(" + children[0] + ", Types.OpOp1.NCOL)"; + + case "nrow": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createUnary(" + children[0] + ", Types.OpOp1.NROW)"; + + case "const": + String referredVarName = varNameMapping.get(cur.getChild(0)); + String nrowContent; + String ncolContent; + + if (referredVarName == null) { + Optional nrowLiteral = cur.getNRow().isLiteral() ? Optional.of(cur.getNRow()) : Optional.empty(); + Optional ncolLiteral = cur.getNCol().isLiteral() ? Optional.of(cur.getNCol()) : Optional.empty(); + + RewriterAssertions.RewriterAssertion nrowAssertion = assertions.getAssertionObj(cur.getNRow()); + RewriterAssertions.RewriterAssertion ncolAssertion = assertions.getAssertionObj(cur.getNCol()); + + nrowLiteral = nrowAssertion == null ? nrowLiteral : nrowAssertion.getLiteral(); + ncolLiteral = ncolAssertion == null ? ncolLiteral : ncolAssertion.getLiteral(); + + + if (nrowLiteral.isPresent()) { + nrowContent = "new LiteralOp(" + nrowLiteral.get().getLiteral().toString() + ")"; + } else { + // Find the first + nrowContent = null; + + if (nrowAssertion == null) + throw new IllegalArgumentException(); + + for (RewriterStatement stmt : nrowAssertion.getEClass()) { + String mappedName = varNameMapping.get(stmt); + + if (mappedName != null) { + nrowContent = getHopConstructor(stmt, assertions, varNameMapping, ctx, mappedName); + if (nrowContent != null) + break; + } + } + + if (nrowContent == null) + throw new IllegalArgumentException(nrowAssertion.toString()); + } + + if (ncolLiteral.isPresent()) { + ncolContent = "new LiteralOp(" + ncolLiteral.get().getLiteral().toString() + ")"; + } else { + // Find the first + ncolContent = null; + + if (ncolAssertion == null) + throw new IllegalArgumentException(); + + for (RewriterStatement stmt : ncolAssertion.getEClass()) { + String mappedName = varNameMapping.get(stmt); + + if (mappedName != null) { + ncolContent = getHopConstructor(stmt, assertions, varNameMapping, ctx, mappedName); + break; + } + } + + if (ncolContent == null) + throw new IllegalArgumentException(); + } + } else { + nrowContent = getHopConstructor(cur.getChild(0).getNRow(), assertions, varNameMapping, ctx, referredVarName); + ncolContent = getHopConstructor(cur.getChild(0).getNCol(), assertions, varNameMapping, ctx, referredVarName); + } + + return "((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(" + nrowContent + "," + ncolContent + "," + cur.getChild(1).getLiteral() + "D))"; + } + + switch (opClass) { + case "UnaryOp": + if (children.length != 1) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createUnary(" + children[0] + ", " + opCode + ")"; + case "BinaryOp": + if (children.length != 2) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createAutoGeneratedBinary(" + children[0] + ", " + children[1] + ", " + opCode + ")"; + case "TernaryOp": + if (children.length != 3) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createTernary(" + children[0] + ", " + children[1] + ", " + children[2] + "," + opCode + ")"; + } + + throw new NotImplementedException(cur.trueTypedInstruction(ctx)); + } + +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java new file mode 100644 index 00000000000..b46fb0e62b0 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; + +public class ConstantFoldingUtils { + static final double EPS = 1e-20; + + public static BiFunction foldingBiFunction(String op, String type) { + switch (op) { + case "+": + if (type.equals("FLOAT")) + return (num, stmt) -> foldSumFloat(num == null ? 0.0 : (double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> foldSumInt(num == null ? 0L : (long)num, stmt); + else + throw new UnsupportedOperationException(); + case "*": + if (type.equals("FLOAT")) + return (num, stmt) -> foldMulFloat(num == null ? 1.0D : (double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> foldMulInt(num == null ? 1L : (long)num, stmt); + else + throw new UnsupportedOperationException(); + case "min": + if (type.equals("FLOAT")) + return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMinFloat((double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMinInt((long)num, stmt); + break; + case "max": + if (type.equals("FLOAT")) + return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMaxFloat((double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMaxInt((long)num, stmt); + break; + } + + throw new UnsupportedOperationException(); + } + + public static boolean isNeutralElement(Object num, String op) { + switch (op) { + case "+": + return num.equals(0L) || num.equals(0.0D); + case "*": + return num.equals(1L) || num.equals(1.0D); + } + + return false; + } + + public static boolean isNegNeutral(Object num, String op) { + if (num == null) + return false; + + switch (op) { + case "*": + return num.equals(-1L) || num.equals(-1.0D); + } + + return false; + } + + public static boolean cancelOutNary(String op, List stmts) { + Set toRemove = new HashSet<>(); + switch (op) { + case "+": + for (int i = 0; i < stmts.size(); i++) { + RewriterStatement stmt1 = stmts.get(i); + for (int j = i+1; j < stmts.size(); j++) { + RewriterStatement stmt2 = stmts.get(j); + + if (stmt1.isInstruction() && stmt1.trueInstruction().equals("-") && stmt1.getChild(0).equals(stmt2) + || (stmt2.isInstruction() && stmt2.trueInstruction().equals("-") && stmt2.getChild(0).equals(stmt1))) { + if (!toRemove.contains(i) && !toRemove.contains(j)) { + toRemove.add(i); + toRemove.add(j); + } + } + + } + } + case "*": + for (int i = 0; i < stmts.size(); i++) { + RewriterStatement stmt1 = stmts.get(i); + for (int j = i+1; j < stmts.size(); j++) { + RewriterStatement stmt2 = stmts.get(j); + + if (stmt1.isInstruction() && stmt1.trueInstruction().equals("inv") && stmt1.getChild(0).equals(stmt2) + || (stmt2.isInstruction() && stmt2.trueInstruction().equals("inv") && stmt2.getChild(0).equals(stmt1))) { + if (!toRemove.contains(i) && !toRemove.contains(j)) { + toRemove.add(i); + toRemove.add(j); + } + } + + } + } + } + + if (toRemove.isEmpty()) + return false; + + List oldCpy = new ArrayList<>(stmts); + stmts.clear(); + + for (int i = 0; i < oldCpy.size(); i++) { + if (!toRemove.contains(i)) + stmts.add(oldCpy.get(i)); + } + + return true; + } + + // This function does not handle NaNs + public static RewriterStatement overwritesLiteral(Number num, String op, final RuleContext ctx) { + if (op.equals("*") && Math.abs(num.doubleValue()) < EPS) { + if (num instanceof Double) + return RewriterStatement.literal(ctx, 0.0); + else + return RewriterStatement.literal(ctx, 0L); + } + + return null; + } + + public static double foldSumFloat(double num, RewriterStatement next) { + return num + next.floatLiteral(); + } + + public static long foldSumInt(long num, RewriterStatement next) { + return num + next.intLiteral(false); + } + + public static double foldMulFloat(double num, RewriterStatement next) { + return num * next.floatLiteral(); + } + + public static long foldMulInt(long num, RewriterStatement next) { + return num * next.intLiteral(false); + } + + public static double foldMinFloat(double num, RewriterStatement next) { + return Math.min(num, next.floatLiteral()); + } + + public static long foldMinInt(long num, RewriterStatement next) { + return Math.min(num, next.intLiteral(false)); + } + + public static double foldMaxFloat(double num, RewriterStatement next) { + return Math.max(num, next.floatLiteral()); + } + + public static long foldMaxInt(long num, RewriterStatement next) { + return Math.max(num, next.intLiteral(false)); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java new file mode 100644 index 00000000000..daaafa71612 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java @@ -0,0 +1,618 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +public class RewriterSearchUtils { + public static final List ALL_TYPES = List.of("MATRIX", "FLOAT"); + public static final List SCALAR = List.of("FLOAT"); + public static final List MATRIX = List.of("MATRIX"); + + public static Operand[] instructionAlphabet = new Operand[] { + null, + new Operand("+", 2, ALL_TYPES, ALL_TYPES), + //new Operand("+", 2, MATRIX, SCALAR), + //new Operand("+", 2, MATRIX, MATRIX), + + new Operand("-", 2, ALL_TYPES, ALL_TYPES), + //new Operand("-", 2, MATRIX, SCALAR), + //new Operand("-", 2, MATRIX, MATRIX), + + new Operand("*", 2, ALL_TYPES, ALL_TYPES), + //new Operand("*", 2, MATRIX, SCALAR), + //new Operand("*", 2, MATRIX, MATRIX), + + new Operand("/", 2, ALL_TYPES, ALL_TYPES), + //new Operand("/", 2, MATRIX, SCALAR), + //new Operand("/", 2, MATRIX, MATRIX), + + new Operand("%*%", 2, MATRIX, MATRIX), + + new Operand("sum", 1, MATRIX), + new Operand("*sum", 2, MATRIX, ALL_TYPES), // To have a bigger search space for this instruction combination + new Operand("t", 1, MATRIX), + new Operand("rev", 1, MATRIX), + new Operand("diag", 1, MATRIX), + new Operand("trace", 1, MATRIX), + new Operand("rowSums", 1, MATRIX), + new Operand("colSums", 1, MATRIX), + new Operand("max", 1, MATRIX), + new Operand("min", 1, MATRIX), + new Operand("ncol", 0, true, MATRIX), + new Operand("nrow", 0, true, MATRIX), + new Operand("length", 0, true, MATRIX), + + new Operand("!=", 2, ALL_TYPES, ALL_TYPES), + new Operand("!=0", 1, MATRIX), + new Operand("0!=", 1, MATRIX), + + new Operand("cast.MATRIX",1, SCALAR), + new Operand("cast.FLOAT", 1, MATRIX), + + new Operand("1-*", 2, MATRIX, MATRIX), + new Operand("+*", 3, MATRIX, SCALAR, MATRIX), + new Operand("-*", 3, MATRIX, SCALAR, MATRIX), + new Operand("*2", 1, MATRIX), + new Operand("_nnz", 1, MATRIX), + new Operand("sumSq", 1, MATRIX), + new Operand("sq", 1, MATRIX), + //new Operand("log", 1, MATRIX), + + // constant stuff + new Operand("c_1+", 1, ALL_TYPES), + new Operand("c_+1", 1, ALL_TYPES), + new Operand("c_1-", 1, ALL_TYPES), + new Operand("c_-1", 1, ALL_TYPES), + + // ncol / nrow / length stuff + new Operand("c_length*", 1, ALL_TYPES), + new Operand("c_ncol*", 1, ALL_TYPES), + new Operand("c_nrow*", 1, ALL_TYPES), + + new Operand("log_nz", 1, MATRIX), + + // Placeholder operators + new Operand("zero", 0, true), + new Operand("one", 0, true) + }; + + private static String[] varNames = new String[] { + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M" + }; + + private static RuleContext ctx; + + public static int getMaxSearchNumberForNumOps(int numOps) { + int out = 1; + for (int i = 0; i < numOps; i++) + out *= instructionAlphabet.length; + + return out; + } + + public static void rename(RewriterStatement stmt) { + Set namedVars = new HashSet<>(); + + stmt.forEachPostOrder((cur, pred) -> { + if (!cur.isInstruction() && !cur.isLiteral()) { + if (!namedVars.contains(cur)) { + if (cur.getResultingDataType(ctx).equals("MATRIX")) + cur.rename(varNames[namedVars.size()]); + else + cur.rename(varNames[namedVars.size()].toLowerCase()); + + namedVars.add(cur); + } + } + }, false); + } + + // To include structures like row/column vectors etc. + public static List buildAssertionVariations(RewriterStatement root, final RuleContext ctx) { + List interestingLeaves = new ArrayList<>(); + root.forEachPreOrder(cur -> { + if (!cur.isInstruction() && !cur.isLiteral() && cur.getResultingDataType(ctx).equals("MATRIX")) + interestingLeaves.add(cur); + return true; + }, true); + + if (interestingLeaves.isEmpty()) + return Collections.emptyList(); + + List out = new ArrayList<>(); + + for (int i = 0; i < interestingLeaves.size(); i++) { + RewriterStatement from = interestingLeaves.get(i); + RewriterStatement rv = createVectorizedStatement(root, from, true); + if (ctx.metaPropagator != null) + rv = ctx.metaPropagator.apply(rv); + out.add(rv); + RewriterStatement cv = createVectorizedStatement(root, from, false); + if (ctx.metaPropagator != null) + cv = ctx.metaPropagator.apply(cv); + out.add(cv); + + for (int j = i + 1; j < interestingLeaves.size(); j++) { + RewriterStatement to = interestingLeaves.get(i); + Map map = new HashMap<>(); + map.put(from, false); + map.put(to, false); + out.add(createVectorizedStatements(root, map)); + map.put(from, true); + out.add(createVectorizedStatements(root, map)); + map.put(to, true); + out.add(createVectorizedStatements(root, map)); + map.put(from, false); + out.add(createVectorizedStatements(root, map)); + } + } + + // Serialize and parse again as there may still be duplicate references + out = out.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + + if (ctx.metaPropagator != null) + return out.stream().map(stmt -> ctx.metaPropagator.apply(stmt)).collect(Collectors.toList()); + + return out; + } + + private static RewriterStatement createVector(RewriterStatement of, boolean rowVector, Map createdObjects) { + // TODO: Why is it necessary to discard the old DataType? + RewriterStatement mCpy = createdObjects.get(of); + + if (mCpy == null) { + mCpy = new RewriterDataType().as(of.getId()).ofType(of.getResultingDataType(ctx)).consolidate(ctx); + createdObjects.put(of, mCpy); + } + //RewriterStatement nRowCol = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction(rowVector ? "nrow" : "ncol").withOps(mCpy).consolidate(ctx); + //createdObjects.put(of, mCpy); + return new RewriterInstruction() + .as(of.getId()) + .withInstruction(rowVector ? "rowVec" : "colVec") + .withOps(mCpy) + .consolidate(ctx); + } + + private static RewriterStatement createVectorizedStatement(RewriterStatement root, RewriterStatement of, boolean rowVector) { + HashMap createdObjects = new HashMap<>(); + RewriterStatement out = root.nestedCopyOrInject(createdObjects, stmt -> { + if (stmt.equals(of)) + return createVector(of, rowVector, createdObjects); + + return null; + }); + + return out; + } + + private static RewriterStatement createVectorizedStatements(RewriterStatement root, Map of) { + HashMap createdObjects = new HashMap<>(); + + RewriterStatement out = root.nestedCopyOrInject(createdObjects, stmt -> { + if (!stmt.isInstruction() && !stmt.isLiteral() && stmt.getResultingDataType(ctx).equals("MATRIX")) { + Boolean rowVector = of.get(stmt); + + if (rowVector != null) + return createVector(stmt, rowVector, createdObjects); + } + + return null; + }); + + return out; + } + + // Builds variations of the same graph (e.g. +(A,B) -> +(A,A)) + public static List buildVariations(RewriterStatement root, final RuleContext ctx) { + List interestingLeaves = new ArrayList<>(); + root.forEachPreOrder(cur -> { + if (!cur.isInstruction() && !cur.isLiteral() && cur.getResultingDataType(ctx).equals("MATRIX")) + interestingLeaves.add(cur); + return true; + }, true); + + if (interestingLeaves.size() < 2) + return Collections.emptyList(); + + List out = new ArrayList<>(); + + for (int i = 0; i < interestingLeaves.size(); i++) { + RewriterStatement to = interestingLeaves.get(i); + for (int j = i + 1; j < interestingLeaves.size(); j++) { + RewriterStatement from = interestingLeaves.get(j); + HashMap createdObjects = new HashMap<>(); + RewriterStatement toCpy = new RewriterDataType().as(to.getId()).ofType(to.getResultingDataType(ctx)).consolidate(ctx); + createdObjects.put(from, toCpy); + createdObjects.put(to, toCpy); + RewriterStatement cpy = root.nestedCopyOrInject(createdObjects, stmt -> null); + if (ctx.metaPropagator != null) + cpy = ctx.metaPropagator.apply(cpy); + out.add(cpy); + } + } + + // Serialize and parse again as there may still be duplicate references + out = out.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + + return out; + } + + public static List buildAllPossibleDAGs(List operands, final RuleContext ctx, boolean rename) { + if (operands == null) + return Collections.emptyList(); + + RewriterSearchUtils.ctx = ctx; + + List allStmts = recursivelyFindAllCombinations(operands, null, ALL_TYPES); + + if (rename) + allStmts.forEach(RewriterSearchUtils::rename); + + if (ctx.metaPropagator != null) + allStmts = allStmts.stream().map(stmt -> ctx.metaPropagator.apply(stmt)).collect(Collectors.toList()); + + // Serialize and parse all statements as there are still duplicate references + return allStmts.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + } + + private static List recursivelyFindAllCombinations(List operands, Operand parent, List supportedTypes) { + if (operands.isEmpty()) + return supportedTypes.stream().map(t -> new RewriterDataType().as(UUID.randomUUID().toString()).ofType(t).consolidate(ctx)).collect(Collectors.toList()); + + // Check if op is a placeholder + Operand op = operands.get(0); + + if (op.isLeaf && operands.size() > 1) + return Collections.emptyList(); + + if (op.op.equals("zero") || op.op.equals("one")) { + List l = new ArrayList<>(2); + if (op.op.equals("zero")) { + if (supportedTypes.contains("FLOAT")) + l.add(RewriterStatement.literal(ctx, 0.0D)); + if (supportedTypes.contains("MATRIX")) + l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 0.0D)).consolidate(ctx)); + } else { + if (supportedTypes.contains("FLOAT")) + l.add(RewriterStatement.literal(ctx, 1.0D)); + + if (supportedTypes.contains("MATRIX")) + l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 1.0D)).consolidate(ctx)); + } + + return l; + } + + int nOps = operands.get(0).numArgs; + + if (nOps == 0) { + return List.of(buildStmt(op, null)); + } + + int[] slices = new int[Math.max(nOps-1, 0)]; + + List possibleStmts = new ArrayList<>(); + + forEachSlice(1, 0, operands.size()+1, slices, () -> { + List> cartesianBuilder = new ArrayList<>(); + + for (int i = 0; i < nOps; i++) { + int lIdx = i == 0 ? 1 : slices[i-1]; + int uIdx = i == slices.length ? operands.size() : slices[i]; + + List view; + if (lIdx == uIdx) + view = Collections.emptyList(); + else + view = operands.subList(lIdx, uIdx); + + List combs = recursivelyFindAllCombinations(view, op, op.supportedTypes[i]); + + if (combs.isEmpty()) + return; // Then no subgraph can be created from that order + + cartesianBuilder.add(combs); + } + + RewriterStatement[] stack = new RewriterStatement[nOps]; + RewriterUtils.cartesianProduct(cartesianBuilder, stack, mStack -> { + try { + for (int i = 0; i < stack.length; i++) + if (!op.supportedTypes[i].contains(stack[i].getResultingDataType(ctx))) + return true; + + RewriterStatement stmt = buildStmt(operands.get(0), stack); + possibleStmts.add(stmt); + } catch (Exception e) { + // Might fail as there could be wrong types + } + return true; // Should continue + }); + }); + + return possibleStmts; + } + + private static RewriterStatement buildStmt(Operand op, RewriterStatement[] stack) { + RewriterInstruction stmt = new RewriterInstruction().as(UUID.randomUUID().toString()); + switch (op.op) { + case "!=0": { + stmt.withInstruction("!=").addOp(stack[0]).addOp(RewriterStatement.literal(ctx, 0.0D)); + break; + } + case "0!=": { + stmt.withInstruction("!=").addOp(RewriterStatement.literal(ctx, 0.0D)).addOp(stack[0]); + break; + } + case "ncol": + case "nrow": + case "length": { + String actualOp = op.op; + stmt.withInstruction(actualOp).withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)).consolidate(ctx); + break; + } + case "fncol": + case "fnrow": + case "flength": { + String actualOp = op.op.substring(1); + stmt.withInstruction(actualOp).withOps(stack).consolidate(ctx); + stmt = (RewriterInstruction) RewriterStatement.castFloat(ctx, stmt); + break; + } + case "*sum": { + RewriterStatement old = stmt.withInstruction("sum").withOps(stack[0]).consolidate(ctx); + stmt = new RewriterInstruction("*", ctx, old, stack[1]); + break; + } + case "c_1+": { + stmt = new RewriterInstruction("+", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]); + break; + } + case "c_+1": { + stmt = new RewriterInstruction("+", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D)); + break; + } + case "c_1-": { + stmt = new RewriterInstruction("-", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]); + break; + } + case "c_-1": { + stmt = new RewriterInstruction("-", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D)); + break; + } + case "c_length*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("length", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + case "c_nrow*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("nrow", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + case "c_col*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("ncol", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + default: { + stmt.withInstruction(op.op).withOps(stack); + break; + } + } + + stmt.consolidate(ctx); + return stmt; + } + + private static void forEachSlice(int startIdx, int pos, int maxIdx, int[] slices, Runnable trigger) { + if (pos >= slices.length) { + trigger.run(); + return; + } + + for (int idx = startIdx; idx < maxIdx; idx++) { + slices[pos] = idx; + + if (pos != slices.length-1) { + forEachSlice(idx, pos+1, maxIdx, slices, trigger); + } else { + trigger.run(); + } + } + } + + public static List decodeOrderedStatements(int stmt) { + int[] instructions = fromBaseNNumber(stmt, instructionAlphabet.length); + List out = new ArrayList<>(instructions.length); + + for (int i = 0; i < instructions.length; i++) { + Operand toAdd = instructionAlphabet[instructions[i]]; + if (toAdd == null) + return null; + out.add(toAdd); + } + + return out; + } + + public static int[] fromBaseNNumber(int l, int n) { + if (l == 0) + return new int[0]; + + int numDigits = (int)(Math.log(l) / Math.log(n)) + 1; + int[] digits = new int[numDigits]; + + for (int i = numDigits - 1; i >= 0; i--) { + digits[i] = l % n; + l = l / n; + } + + return digits; + } + + public static int toBaseNNumber(int[] digits, int n) { + if (digits.length == 0) + throw new IllegalArgumentException(); + + int multiplicator = 1; + int out = 0; + + for (int i = digits.length - 1; i >= 0; i--) { + out += multiplicator * digits[i]; + multiplicator *= n; + } + + return out; + } + + public static List mergeSubtreeCombinations(RewriterStatement stmt, List indices, List> mList, final RuleContext ctx, int maximumCombinations) { + if (indices.isEmpty()) + return List.of(stmt); + + List mergedTreeCombinations = new ArrayList<>(); + RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> { + RewriterStatement cpy = stmt.copyNode(); + for (int i = 0; i < stack.length; i++) + cpy.getOperands().set(indices.get(i), stack[i]); + cpy.consolidate(ctx); + cpy.prepareForHashing(); + cpy.recomputeHashCodes(ctx); + mergedTreeCombinations.add(cpy); + return mergedTreeCombinations.size() < maximumCombinations; + }); + + return mergedTreeCombinations; + } + + public static List generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) { + List l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations); + + if (ctx.metaPropagator != null) + l.forEach(subtree -> ctx.metaPropagator.apply(subtree)); + + return l.stream().map(subtree -> { + if (ctx.metaPropagator != null) + subtree = ctx.metaPropagator.apply(subtree); + + subtree.prepareForHashing(); + subtree.recomputeHashCodes(ctx); + // We return a copy of the tree as there are still duplicate references + return RewriterUtils.parse(subtree.toParsableString(ctx, true), ctx); + }).collect(Collectors.toList()); + } + + private static Random rd = new Random(); + + private static List generateSubtrees(RewriterStatement stmt, Map> visited, final RuleContext ctx, int maxCombinations) { + if (stmt == null) + return Collections.emptyList(); + + RewriterStatement is = stmt; + List alreadyVisited = visited.get(is); + + if (alreadyVisited != null) + return alreadyVisited; + + if (stmt.getOperands().size() == 0) + return List.of(stmt); + + // Scan if operand is not a DataType + List indices = new ArrayList<>(); + for (int i = 0; i < stmt.getOperands().size(); i++) { + if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral()) + indices.add(i); + } + + int n = indices.size(); + int totalSubsets = 1 << n; + + List mList = new ArrayList<>(); + + visited.put(is, mList); + + List> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList()); + List out = new ArrayList<>(); + + for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) { + List> mOptionCpy = new ArrayList<>(mOptions); + + for (int i = 0; i < n; i++) { + // Check if the i-th child is included in the current subset + if ((subsetMask & (1 << i)) == 0) { + String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx); + String namePrefix = "tmp"; + if (dt.equals("MATRIX")) + namePrefix = "M"; + else if (dt.equals("FLOAT")) + namePrefix = "f"; + else if (dt.equals("INT")) + namePrefix = "i"; + else if (dt.equals("BOOL")) + namePrefix = "b"; + RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt); + mT.consolidate(ctx); + mOptionCpy.set(i, List.of(mT)); + } + } + + out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations)); + if (out.size() > maxCombinations) { + System.out.println("Aborting early due to too many combinations"); + return out; + } + } + + return out; + } + + public static final class Operand { + public final String op; + public final int numArgs; + public final List[] supportedTypes; + public final boolean isLeaf; + + public Operand(String op, int numArgs, List... supportedTypes) { + this(op, numArgs, false, supportedTypes); + } + public Operand(String op, int numArgs, boolean isLeaf, List... supportedTypes) { + this.op = op; + this.numArgs = numArgs; + this.supportedTypes = supportedTypes; + this.isLeaf = isLeaf; + } + + public String toString() { + return op; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java new file mode 100644 index 00000000000..258b65002fb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java @@ -0,0 +1,1375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.logging.log4j.util.TriConsumer; +import org.apache.sysds.hops.rewriter.MetaPropagator; +import org.apache.sysds.hops.rewriter.RewriterContextSettings; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristic; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristics; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCollection; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterUtils { + protected static final Log LOG = LogFactory.getLog(RewriterUtils.class.getName()); + + public static final Pattern LONG_PATTERN = Pattern.compile("-?\\d+"); + public static final Pattern DOUBLE_PATTERN = Pattern.compile("-?\\d*\\.\\d+([eE][+-]?\\d+)?"); + public static final Pattern SPECIAL_FLOAT_PATTERN = Pattern.compile("Infinity|NaN"); + + public static String typedToUntypedInstruction(String instr) { + return instr.substring(0, instr.indexOf('(')); + } + + public static BiFunction binaryStringRepr(String op) { + return (stmt, ctx) -> { + List operands = stmt.getOperands(); + String op1Str = operands.get(0).toString(ctx); + if (operands.get(0) instanceof RewriterInstruction && operands.get(0).getOperands().size() > 1) + op1Str = "(" + op1Str + ")"; + String op2Str = operands.get(1).toString(ctx); + if (operands.get(1) instanceof RewriterInstruction && operands.get(1).getOperands().size() > 1) + op2Str = "(" + op2Str + ")"; + return op1Str + op + op2Str; + }; + } + + public static void mergeArgLists(RewriterStatement stmt, final RuleContext ctx) { + + stmt.forEachPreOrder(el -> { + tryFlattenNestedArgList(ctx, el, el, -1); + tryFlattenNestedOperatorPatterns(ctx, el); + el.refreshReturnType(ctx); + return true; + }, true); + + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + } + + public static boolean tryFlattenNestedArgList(final RuleContext ctx, RewriterStatement stmt, RewriterStatement root, int insertAt) { + if (!stmt.isArgumentList()) + return false; + + if (stmt == root) { + boolean anyMatch = false; + + for (int i = 0; i < stmt.getOperands().size(); i++) { + RewriterStatement op = stmt.getOperands().get(i); + if (tryFlattenNestedArgList(ctx, op, root, i)) { + stmt.getOperands().remove(i); + anyMatch = true; + } + } + + return anyMatch; + } + + String dt1 = root.getResultingDataType(ctx); + String dt2 = stmt.getResultingDataType(ctx); + + String convertibleDataType = convertibleType(dt1.substring(0, dt1.length()-3), dt2.substring(0, dt2.length()-3)); + + if (convertibleDataType == null) + return false; + + root.getOperands().addAll(insertAt+1, stmt.getOperands()); + + return true; + } + + public static void tryFlattenNestedOperatorPatterns(final RuleContext ctx, RewriterStatement stmt) { + if (!stmt.isInstruction()) + return; + + RewriterInstruction instr = (RewriterInstruction) stmt; + + if (instr.hasProperty("FusedOperator", ctx)) { + for (int i = 0; i < instr.getOperands().get(0).getOperands().size(); i++) + if (flattenNestedOperatorPatterns(ctx, instr.getOperands().get(0).getOperands().get(i), instr, i)) + i--; + } + } + + private static boolean flattenNestedOperatorPatterns(final RuleContext ctx, RewriterStatement stmt, RewriterInstruction rootInstr, int insertAt) { + if (stmt.isInstruction() && ((RewriterInstruction)stmt).hasProperty("FusedOperator", ctx) && stmt.trueInstruction().equals(rootInstr.trueInstruction())) { + RewriterStatement origArgList = rootInstr.getOperands().get(0); + RewriterStatement subArgList = stmt.getOperands().get(0); + + origArgList.getOperands().set(insertAt, subArgList.getOperands().get(0)); + origArgList.getOperands().addAll(insertAt+1, subArgList.getOperands().subList(1, subArgList.getOperands().size())); + + return true; + } + + return false; + } + + public static RewriterStatement parse(String expr, final RuleContext ctx) { + String[] split = expr.split("\n"); + return parse(split[split.length-1], ctx, Arrays.copyOfRange(split, 0, split.length-1)); + } + + public static RewriterRule parseRule(String expr, final RuleContext ctx) { + // Remove empty lines + expr = expr.replaceAll("\n\\s*\n", "\n"); + String[] split = expr.split("\n"); + Set allowedMultiRefs = Collections.emptySet(); + boolean allowCombinations = false; + boolean parsedExtendedHeader = false; + + if (split[0].startsWith("AllowedMultiRefs:")) { + split[0] = split[0].substring(17); + String[] sSplit = split[0].split(","); + allowedMultiRefs = Arrays.stream(sSplit).map(s -> Integer.parseInt(s.substring(1))).collect(Collectors.toSet()); + + if (!split[1].startsWith("AllowCombinations:")) + throw new IllegalArgumentException(); + + split[1] = split[1].substring(18); + allowCombinations = Boolean.parseBoolean(split[1]); + parsedExtendedHeader = true; + } + + int condIdxStart = -1; + for (int i = 2; i < split.length; i++) { + if (split[i].startsWith("{")) { + // Then we have a conditional rule + condIdxStart = i; + break; + } + } + + if (condIdxStart != -1) { + // Then we have a conditional rule + List toExprs = Arrays.asList(split).subList(condIdxStart+1, split.length-1); + return parseRule(split[condIdxStart-2], toExprs, allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, condIdxStart-2)); + } + + return parseRule(split[split.length-3], split[split.length-1], allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, split.length-3)); + } + + public static RewriterStatement parse(String expr, final RuleContext ctx, String... varDefinitions) { + return parse(expr, ctx, new HashMap<>(), varDefinitions); + } + + public static RewriterRule parseRule(String exprFrom, String exprTo, Set allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) { + return parseRule(exprFrom, exprTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, varDefinitions); + } + + public static RewriterRule parseRule(String exprFrom, List exprsTo, Set allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) { + return parseRule(exprFrom, exprsTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, true, varDefinitions); + } + + public static RewriterStatement parse(String expr, final RuleContext ctx, Map dataTypes, String... varDefinitions) { + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + RewriterStatement parsed = parseExpression(expr, new HashMap<>(), dataTypes, ctx); + if (ctx.metaPropagator == null) + return parsed; + else { + RewriterStatement out = ctx.metaPropagator.apply(parsed); + out.prepareForHashing(); + out.recomputeHashCodes(ctx); + return out; + } + } + + public static RewriterRule parseRule(String exprFrom, String exprTo, final RuleContext ctx, Map dataTypes, Set allowedMultiRefs, boolean allowCombinations, String... varDefinitions) { + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + HashMap mmap = new HashMap<>(); + + RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx); + RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx); + + if (ctx.metaPropagator != null) { + parsedFrom = ctx.metaPropagator.apply(parsedFrom); + parsedTo = ctx.metaPropagator.apply(parsedTo); + } + + return new RewriterRuleBuilder(ctx).completeRule(parsedFrom, parsedTo).withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations).setUnidirectional(true).build(); + } + + public static RewriterRule parseRule(String exprFrom, List exprsTo, final RuleContext ctx, Map dataTypes, Set allowedMultiRefs, boolean allowCombinations, boolean asConditional, String... varDefinitions) { + if (!asConditional && exprsTo.size() > 1) + throw new IllegalArgumentException(); + + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + HashMap mmap = new HashMap<>(); + + RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx); + if (ctx.metaPropagator != null) { + parsedFrom = ctx.metaPropagator.apply(parsedFrom); + } + + List parsedTos = new ArrayList<>(); + for (String exprTo : exprsTo) { + RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx); + + if (ctx.metaPropagator != null) { + parsedTo = ctx.metaPropagator.apply(parsedTo); + parsedTo.prepareForHashing(); + parsedTo.recomputeHashCodes(ctx); + } + + parsedTos.add(parsedTo); + } + + return new RewriterRuleBuilder(ctx) + .completeConditionalRule(parsedFrom, parsedTos) + .withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations) + .setUnidirectional(true).build(); + } + + /** + * Parses an expression + * @param expr the expression string + * @param refmap test + * @param dataTypes data type + * @param ctx context + * @return test + */ + public static RewriterStatement parseExpression(String expr, Map refmap, Map dataTypes, final RuleContext ctx) { + RuleContext.currentContext = ctx; + expr = expr.replaceAll("\\s+", ""); + MutableObject mexpr = new MutableObject<>(expr); + RewriterStatement stmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + stmt.prepareForHashing(); + stmt.consolidate(ctx); + return stmt; + } + + private static RewriterStatement doParseExpression(MutableObject mexpr, Map refmap, Map dataTypes, final RuleContext ctx) { + String expr = mexpr.getValue(); + if (expr.startsWith("$")) { + expr = expr.substring(1); + Pattern pattern = Pattern.compile("^\\d+"); + Matcher matcher = pattern.matcher(expr); + + if (matcher.find()) { + String number = matcher.group(); + int n = Integer.parseInt(number); + if (expr.charAt(matcher.end()) != ':') { + // Then we inject the common subexpression + String remainder = expr.substring(matcher.end()); + mexpr.setValue(remainder); + RewriterStatement var = refmap.get(n); + + if (var == null) + throw new IllegalArgumentException("Variable '$" + n + "' does not exist!"); + + return var; + } + String remainder = expr.substring(matcher.end() + 1); + mexpr.setValue(remainder); + RewriterStatement stmt = parseRawExpression(mexpr, refmap, dataTypes, ctx); + refmap.put(n, stmt); + return stmt; + } else { + throw new IllegalArgumentException("Expected a number"); + } + } else { + return parseRawExpression(mexpr, refmap, dataTypes, ctx); + } + } + + public static boolean parseDataTypes(String expr, Map dataTypes, final RuleContext ctx) { + RuleContext.currentContext = ctx; + Pattern pattern = Pattern.compile("([A-Za-z0-9]|_|\\.|\\*|\\?)([A-Za-z0-9]|_|\\.|\\*|-)*"); + Matcher matcher = pattern.matcher(expr); + + if (!matcher.find()) + return false; + + String dType = matcher.group(); + boolean intLiteral = dType.equals("LITERAL_INT"); + boolean boolLiteral = dType.equals("LITERAL_BOOL"); + boolean floatLiteral = dType.equals("LITERAL_FLOAT"); + + if (intLiteral) { + pattern = Pattern.compile("(-)?[0-9]+"); + } else if (boolLiteral) { + pattern = Pattern.compile("(TRUE|FALSE)"); + } else if (floatLiteral) { + pattern = Pattern.compile("((-)?([0-9]+(\\.[0-9]*)?(E(-)?[0-9]+)?|Infinity)|NaN)"); + } + + if (expr.charAt(matcher.end()) != ':') + return false; + + expr = expr.substring(matcher.end() + 1); + + matcher = pattern.matcher(expr); + + while (matcher.find()) { + String varName = matcher.group(); + + RewriterDataType dt; + + if (intLiteral) { + dt = new RewriterDataType().as(varName).ofType("INT").asLiteral(Long.parseLong(varName)); + } else if (boolLiteral) { + dt = new RewriterDataType().as(varName).ofType("BOOL").asLiteral(Boolean.parseBoolean(varName)); + } else if (floatLiteral) { + dt = new RewriterDataType().as(varName).ofType("FLOAT").asLiteral(Double.parseDouble(varName)); + } else { + dt = new RewriterDataType().as(varName).ofType(dType); + } + + dt.consolidate(ctx); + dataTypes.put(varName, dt); + + if (expr.length() == matcher.end()) + return true; + + if (expr.charAt(matcher.end()) != ',') + return false; + + expr = expr.substring(matcher.end()+1); + matcher = pattern.matcher(expr); + } + + return false; + } + + private static RewriterStatement parseRawExpression(MutableObject mexpr, Map refmap, Map dataTypes, final RuleContext ctx) { + String expr = mexpr.getValue(); + + Pattern pattern = Pattern.compile("^[^(),:]+"); + Matcher matcher = pattern.matcher(expr); + + if (matcher.find()) { + String token = matcher.group(); + String remainder = expr.substring(matcher.end()); + + if (remainder.isEmpty()) { + mexpr.setValue(remainder); + if (dataTypes.containsKey(token)) + return dataTypes.get(token); + throw new IllegalArgumentException("DataType: '" + token + "' doesn't exist"); + } + + + char nextChar = remainder.charAt(0); + + switch (nextChar) { + case '(': + // Then this is a function + if (remainder.charAt(1) == ')') { + RewriterInstruction mInstr = new RewriterInstruction().withInstruction(token).as(UUID.randomUUID().toString()); + handleSpecialInstructions(mInstr); + mInstr.consolidate(ctx); + mexpr.setValue(remainder.substring(2)); + return mInstr; + } else { + List opList = new ArrayList<>(); + mexpr.setValue(remainder.substring(1)); + RewriterStatement cstmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + opList.add(cstmt); + + while (mexpr.getValue().charAt(0) == ',') { + mexpr.setValue(mexpr.getValue().substring(1)); + cstmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + opList.add(cstmt); + } + + if (mexpr.getValue().charAt(0) != ')') + throw new IllegalArgumentException(mexpr.getValue()); + + mexpr.setValue(mexpr.getValue().substring(1)); + RewriterInstruction instr = new RewriterInstruction().withInstruction(token).withOps(opList.toArray(RewriterStatement[]::new)).as(UUID.randomUUID().toString()); + handleSpecialInstructions(instr); + instr.consolidate(ctx); + return instr; + } + case ')': + case ',': + mexpr.setValue(remainder); + if (dataTypes.containsKey(token)) + return dataTypes.get(token); + throw new IllegalArgumentException("DataType: '" + token + "' doesn't exist"); + default: + throw new NotImplementedException(); + } + } else { + throw new IllegalArgumentException(mexpr.getValue()); + } + } + + private static void handleSpecialInstructions(RewriterInstruction instr) { + if (instr.trueInstruction().equals("_m")) { + UUID ownerId = UUID.randomUUID(); + instr.unsafePutMeta("ownerId", ownerId); + + if (instr.getOperands().get(0).isInstruction() && instr.getOperands().get(0).trueInstruction().equals("_idx")) { + instr.getOperands().get(0).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(0).unsafePutMeta("idxId", UUID.randomUUID()); + } + + if (instr.getOperands().get(1).isInstruction() && instr.getOperands().get(1).trueInstruction().equals("_idx")) { + instr.getOperands().get(1).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(1).unsafePutMeta("idxId", UUID.randomUUID()); + } + } else if (instr.trueInstruction().equals("_idxExpr")) { + UUID ownerId = UUID.randomUUID(); + instr.unsafePutMeta("ownerId", ownerId); + + if (instr.getOperands().get(0).isInstruction() && instr.getOperands().get(0).trueInstruction().equals("_idx")) { + instr.getOperands().get(0).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(0).unsafePutMeta("idxId", UUID.randomUUID()); + } + } + } + + public static void buildBinaryAlgebraInstructions(StringBuilder sb, String instr, List instructions) { + for (String arg1 : instructions) { + for (String arg2 : instructions) { + sb.append(instr + "(" + arg1 + "," + arg2 + ")::"); + + if (arg1.equals("MATRIX") || arg2.equals("MATRIX")) + sb.append("MATRIX\n"); + else if (arg1.equals("FLOAT") || arg2.equals("FLOAT")) + sb.append("FLOAT\n"); + else + sb.append("INT\n"); + } + } + } + + public static void buildTernaryPermutations(List args, TriConsumer func) { + buildBinaryPermutations(args, (t1, t2) -> args.forEach(t3 -> func.accept(t1, t2, t3))); + } + + public static void buildBinaryPermutations(List args, BiConsumer func) { + buildBinaryPermutations(args, args, func); + } + + public static void buildBinaryPermutations(List args1, List args2, BiConsumer func) { + for (String arg1 : args1) + for (String arg2 : args2) + func.accept(arg1, arg2); + } + + public static String defaultTypeHierarchy(String t1, String t2) { + boolean is1ArgList = t1.endsWith("..."); + boolean is2ArgList = t2.endsWith("..."); + + if (is1ArgList) + t1 = t1.substring(0, t1.length() - 3); + + if (is2ArgList) + t2 = t2.substring(0, t2.length() - 3); + + if (t1.equals("BOOL") && t2.equals("BOOL")) + return "BOOL"; + if (t1.equals("INT") && (t2.equals("INT") || t2.equals("BOOL"))) + return "INT"; + + if (t2.equals("INT") && (t1.equals("INT") || t1.equals("BOOL"))) + return "INT"; + + if (!t1.equals("MATRIX") && !t2.equals("MATRIX")) + return "FLOAT"; + return "MATRIX"; + } + + public static String convertibleType(String t1, String t2) { + if (t1.equals("MATRIX") && t2.equals("MATRIX")) + return "MATRIX"; + + if (t1.equals("MATRIX") || t2.equals("MATRIX")) + return null; // Then it is not convertible + + if (!List.of("FLOAT", "INT", "BOOL").contains(t1) || !List.of("FLOAT", "INT", "BOOL").contains(t2)) + return null; + + if (t1.equals("FLOAT") || t2.equals("FLOAT")) + return "FLOAT"; // This is the most "general" type + + if (t1.equals("INT") || t2.equals("INT")) + return "INT"; + + return "BOOL"; + } + + public static String convertImplicitly(String type, boolean allowTypeConversions) { + if (!allowTypeConversions) + return type; + return convertImplicitly(type); + } + + public static String convertImplicitly(String type) { + if (type == null) + return null; + + if (type.equals("INT") || type.equals("BOOL")) + return "FLOAT"; + return type; + } + + public static void putAsBinaryPrintable(String instr, List types, HashMap> printFunctions, BiFunction function) { + for (String type1 : types) + for (String type2 : types) + printFunctions.put(instr + "(" + type1 + "," + type2 + ")", function); + } + + public static void putAsDefaultBinaryPrintable(List instrs, List types, HashMap> funcs) { + for (String instr : instrs) + putAsBinaryPrintable(instr, types, funcs, binaryStringRepr(" " + instr + " ")); + } + + // Updates the references (including metadata UUIDs) for a copied _idxExpr(args(_idx(...),...),...) + public static void copyIndexList(RewriterStatement idxExprRoot) { + if (!idxExprRoot.isInstruction() || !idxExprRoot.trueInstruction().equals("_idxExpr")) + throw new IllegalArgumentException(); + + Map replacements = new HashMap<>(); + UUID newOwnerId = UUID.randomUUID(); + idxExprRoot.unsafePutMeta("ownerId", newOwnerId); + + RewriterStatement newArgList = idxExprRoot.getChild(0).copyNode(); + idxExprRoot.getOperands().set(0, newArgList); + + List operands = newArgList.getOperands(); + + for (int i = 0; i < operands.size(); i++) { + RewriterStatement idx = operands.get(i); + RewriterStatement cpy = idx.copyNode(); + UUID newId = UUID.randomUUID(); + cpy.unsafePutMeta("idxId", newId); + cpy.unsafePutMeta("ownerId", newOwnerId); + replacements.put((UUID)idx.getMeta("idxId"), cpy); + operands.set(i, cpy); + } + + RewriterStatement out = RewriterUtils.replaceReferenceAware(idxExprRoot.getChild(1), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + RewriterStatement newStmt = replacements.get(idxId); + if (newStmt != null) + return newStmt; + } + + return null; + }); + + if (out != null) + idxExprRoot.getOperands().set(1, out); + } + + public static RewriterStatement replaceReferenceAware(RewriterStatement root, Function comparer) { + return replaceReferenceAware(root, false, comparer, new HashMap<>()); + } + + // Replaces elements in a DAG. If a parent item has multiple references, the entire path is duplicated + public static RewriterStatement replaceReferenceAware(RewriterStatement root, boolean duplicateReferences, Function comparer, HashMap visited) { + if (visited.containsKey(root)) + return visited.get(root); + + RewriterStatement newOne = comparer.apply(root); + + if (newOne == root) + newOne = null; + + root = newOne != null ? newOne : root; + + if (newOne == null) + duplicateReferences |= root.refCtr > 1; + + if (root.getOperands() != null) { + for (int i = 0; i < root.getOperands().size(); i++) { + RewriterStatement newSub = replaceReferenceAware(root.getOperands().get(i), duplicateReferences, comparer, visited); + + if (newSub != null) { + if (duplicateReferences && newOne == null) { + root = root.copyNode(); + newOne = root; + } + + root.getOperands().set(i, newSub); + } + } + } + + return newOne; + } + + // Deduplicates the DAG (removes duplicate references with new nodes except for leaf data-types) + public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { + for (int i = 0; i < root.getOperands().size(); i++) { + RewriterStatement child = root.getChild(i); + if (child.isInstruction() && child.refCtr > 1) { + if (!child.trueInstruction().equals("_idx") + && !child.trueInstruction().equals("_m") + && !child.trueInstruction().equals("idxExpr") + && !child.trueInstruction().equals("rand") + && !child.trueInstruction().equals("_EClass")) { + RewriterStatement cpy = child.copyNode(); + root.getOperands().set(i, cpy); + child.refCtr--; + cpy.getOperands().forEach(op -> op.refCtr++); + } + } + + unfoldExpressions(child, ctx); + } + } + + public static boolean cartesianProduct(List> list, T[] stack, Function emitter) { + if (list.size() == 0) + return false; + + if (list.size() == 1) { + list.get(0).forEach(t -> { + stack[0] = t; + emitter.apply(stack); + }); + return true; + } + + return _cartesianProduct(0, list, stack, emitter, new MutableBoolean(true)); + } + + private static boolean _cartesianProduct(int index, List> sets, T[] currentStack, Function emitter, MutableBoolean doContinue) { + if (index >= sets.size()) { + if (!emitter.apply(currentStack)) + doContinue.setValue(false); + return true; + } + + int size = sets.get(index).size(); + boolean matchFound = false; + + for (int i = 0; i < size; i++) { + currentStack[index] = sets.get(index).get(i); + matchFound |= _cartesianProduct(index+1, sets, currentStack, emitter, doContinue); + + if (!doContinue.booleanValue()) + return matchFound; + } + + return matchFound; + } + + public static boolean isImplicitlyConvertible(String typeFrom, String typeTo) { + if (typeFrom.equals(typeTo)) + return true; + + if (typeFrom.equals("INT") && typeTo.equals("FLOAT")) + return true; + + return false; + } + + public static boolean compareLiterals(RewriterDataType lit1, RewriterDataType lit2, boolean allowImplicitTypeConversions) { + if (allowImplicitTypeConversions) + return lit1.getLiteral().equals(literalAs(lit1.getType(), lit2)); + return lit1.getLiteral().equals(lit2.getLiteral()); + } + + public static Object literalAs(String type, RewriterDataType literal) { + switch (type) { + case "FLOAT": + return literal.floatLiteral(); + case "INT": + return literal.intLiteral(false); + case "BOOL": + return literal.boolLiteral(); + default: + return null; + } + } + + public static RuleContext buildDefaultContext() { + RuleContext ctx = RewriterContextSettings.getDefaultContext(); + ctx.metaPropagator = new MetaPropagator(ctx); + return ctx; + } + + private static RuleContext lastCtx; + private static Function lastUnfuse; + public static RewriterStatement unfuseOperators(RewriterStatement stmt, final RuleContext ctx) { + return unfuseOperators(ctx).apply(stmt); + } + public static Function unfuseOperators(final RuleContext ctx) { + if (lastCtx == ctx) + return lastUnfuse; + + ArrayList unfuseRules = new ArrayList<>(); + RewriterRuleCollection.substituteFusedOps(unfuseRules, ctx); + RewriterHeuristic heur = new RewriterHeuristic(new RewriterRuleSet(ctx, unfuseRules)); + lastCtx = ctx; + lastUnfuse = heur::apply; + return lastUnfuse; + } + + public static Function buildCanonicalFormConverter(final RuleContext ctx, boolean debug) { + return buildCanonicalFormConverter(ctx, true, debug); + } + + public static Function buildCanonicalFormConverter(final RuleContext ctx, boolean allowInversionCanonicalization, boolean debug) { + ArrayList algebraicCanonicalizationRules = new ArrayList<>(); + RewriterRuleCollection.substituteEquivalentStatements(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.eliminateMultipleCasts(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.canonicalizeBooleanStatements(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.canonicalizeAlgebraicStatements(algebraicCanonicalizationRules, allowInversionCanonicalization, ctx); + RewriterRuleCollection.eliminateMultipleCasts(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(algebraicCanonicalizationRules, ctx); + RewriterHeuristic algebraicCanonicalization = new RewriterHeuristic(new RewriterRuleSet(ctx, algebraicCanonicalizationRules)); + + ArrayList expRules = new ArrayList<>(); + RewriterRuleCollection.expandStreamingExpressions(expRules, ctx); + RewriterHeuristic streamExpansion = new RewriterHeuristic(new RewriterRuleSet(ctx, expRules)); + + ArrayList expArbitraryMatricesRules = new ArrayList<>(); + RewriterRuleCollection.expandArbitraryMatrices(expArbitraryMatricesRules, ctx); + RewriterHeuristic expandArbitraryMatrices = new RewriterHeuristic(new RewriterRuleSet(ctx, expArbitraryMatricesRules)); + + ArrayList pd = new ArrayList<>(); + RewriterRuleCollection.pushdownStreamSelections(pd, ctx); + RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(pd, ctx); + RewriterRuleCollection.eliminateMultipleCasts(pd, ctx); + RewriterRuleCollection.canonicalizeBooleanStatements(pd, ctx); + RewriterRuleCollection.canonicalizeAlgebraicStatements(pd, allowInversionCanonicalization, ctx); + RewriterHeuristic streamSelectPushdown = new RewriterHeuristic(new RewriterRuleSet(ctx, pd)); + + ArrayList flatten = new ArrayList<>(); + RewriterRuleCollection.flattenOperations(flatten, ctx); + RewriterHeuristic flattenOperations = new RewriterHeuristic(new RewriterRuleSet(ctx, flatten)); + + RewriterHeuristics canonicalFormCreator = new RewriterHeuristics(); + canonicalFormCreator.add("ALGEBRAIC CANONICALIZATION", algebraicCanonicalization); + canonicalFormCreator.add("EXPAND STREAMING EXPRESSIONS", streamExpansion); + canonicalFormCreator.add("EXPAND ARBITRARY MATRICES", expandArbitraryMatrices); + canonicalFormCreator.add("PUSHDOWN STREAM SELECTIONS", streamSelectPushdown); + canonicalFormCreator.add("FOLD CONSTANTS", new RewriterHeuristic(t -> foldConstants(t, ctx))); + //canonicalFormCreator.add("CANON ALGB", new RewriterHeuristic(new RewriterRuleSet(ctx, RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(new ArrayList<>(), ctx)))); + canonicalFormCreator.add("REPLACE NEGATIONS", new RewriterHeuristic(new RewriterRuleSet(ctx, RewriterRuleCollection.replaceNegation(new ArrayList<>(), ctx)))); + canonicalFormCreator.add("PUSHDOWN STREAM SELECTIONS", streamSelectPushdown); + canonicalFormCreator.add("FLATTEN OPERATIONS", flattenOperations); + + ArrayList canonicalExpand = new ArrayList<>(); + RewriterRuleCollection.canonicalExpandAfterFlattening(canonicalExpand, ctx); + RewriterHeuristic canonicalExpandOps = new RewriterHeuristic(new RewriterRuleSet(ctx, canonicalExpand)); + + ArrayList flattenAlgebraicRewriteList = new ArrayList<>(); + RewriterRuleCollection.flattenedAlgebraRewrites(flattenAlgebraicRewriteList, ctx); + RewriterHeuristic flattenedAlgebraicRewrites = new RewriterHeuristic(new RewriterRuleSet(ctx, flattenAlgebraicRewriteList)); + + RewriterHeuristics afterFlattening = new RewriterHeuristics(); + afterFlattening.add("CANONICAL EXPAND", canonicalExpandOps); + afterFlattening.add("FLATTENED ALGEBRA REWRITES", flattenedAlgebraicRewrites); + + return stmt -> { + stmt = stmt.nestedCopy(true); + stmt = canonicalFormCreator.apply(stmt, (t, r) -> { + if (!debug) + return true; + + if (r != null) + System.out.println("Applying rule: " + r.getName()); + System.out.println(t.toParsableString(ctx)); + return true; + }, debug); + + for (int i = 0; i < 2; i++) { + RewriterUtils.mergeArgLists(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + } + RewriterUtils.mergeArgLists(stmt, ctx); + unfoldExpressions(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + cleanupUnecessaryIndexExpressions(stmt, ctx); + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + + stmt = afterFlattening.apply(stmt, (t, r) -> { + if (!debug) + return true; + + if (r != null) + System.out.println("Applying rule: " + r.getName()); + System.out.println(t.toParsableString(ctx)); + return true; + }, debug); + + stmt = foldConstants(stmt, ctx); + + for (int i = 0; i < 2; i++) { + RewriterUtils.mergeArgLists(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + } + RewriterUtils.mergeArgLists(stmt, ctx); + + stmt = stmt.getAssertions(ctx).cleanupEClasses(stmt); + unfoldExpressions(stmt, ctx); + stmt.prepareForHashing(); + + if (debug) + System.out.println("PRE1: " + stmt.toParsableString(ctx, false)); + + stmt.compress(); // To remove unnecessary metadata such as assertions that are not encoded in the graph + TopologicalSort.sort(stmt, ctx); + + if (debug) + System.out.println("FINAL1: " + stmt.toParsableString(ctx, false)); + + return stmt; + }; + } + + public static RewriterStatement pullOutConstants(RewriterStatement oldRoot, final RuleContext ctx) { + RewriterStatement newRoot = pullOutConstantsRecursively(oldRoot, ctx, new HashMap<>()); + + // Check if we have to move the assertions to new root + if (newRoot != oldRoot) + oldRoot.moveRootTo(newRoot); + + return newRoot; + } + + private static RewriterStatement pullOutConstantsRecursively(RewriterStatement cur, final RuleContext ctx, Map alreadyModified) { + if (!cur.isInstruction()) + return cur; + + RewriterStatement modified = alreadyModified.get(cur); + + if (modified != null) + return modified; + + alreadyModified.put(cur, cur); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, pullOutConstantsRecursively(cur.getChild(i), ctx, alreadyModified)); + + cur.updateMetaObjects(el -> pullOutConstantsRecursively(el, ctx, alreadyModified)); + + switch (cur.trueInstruction()) { + case "sum": + return tryPullOutSum(cur, ctx); + } + + return cur; + } + + private static RewriterStatement tryPullOutSum(RewriterStatement sum, final RuleContext ctx) { + // TODO: What happens on multi-index? Then, some unnecessary indices will currently not be pulled out + RewriterStatement idxExpr = sum.getChild(0); + UUID ownerId = (UUID) idxExpr.getMeta("ownerId"); + RewriterStatement sumBody = idxExpr.getChild(1); + + Map checked = new HashMap<>(); + + + if (!checkSubgraphDependency(sumBody, ownerId, checked)) { + // Then we have to remove the sum entirely + List indices = idxExpr.getChild(0).getOperands(); + List components = new ArrayList<>(); + + for (RewriterStatement idx : indices) { + if (idx.isLiteral()) + continue; + RewriterStatement idxFrom = idx.getChild(0); + RewriterStatement idxTo = idx.getChild(1); + RewriterStatement negation = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idxFrom)*/idxFrom).consolidate(ctx); + RewriterStatement add = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idxTo)*/idxTo, RewriterStatement.literal(ctx, 1.0D), negation); + components.add(add); + } + + RewriterStatement out = RewriterStatement.multiArgInstr(ctx, "*", sumBody); + out.getChild(0).getOperands().addAll(components); + return foldConstants(out, ctx); + } + + if (isDirectlyDependent(sumBody, ownerId)) + return sum; + + if (sumBody.trueInstruction().equals("*")) { + // We have to assume here, that this instruction is not referenced anywhere else in the graph + List argList = sumBody.getChild(0).getOperands(); + List toRemove = new ArrayList<>(argList.size()); + + for (RewriterStatement stmt : argList) { + if (!checkSubgraphDependency(stmt, ownerId, checked)) + toRemove.add(stmt); + } + + if (!toRemove.isEmpty()) { + argList.removeAll(toRemove); + + if (argList.size() == 1) { + idxExpr.getOperands().set(1, argList.get(0)); + } + + toRemove.add(sum); + + return RewriterStatement.multiArgInstr(ctx, "*", toRemove.toArray(RewriterStatement[]::new)); + } + } else if (sumBody.trueInstruction().equals("+")) { + // TODO: What about sum(+(A, *(a, B)))? We could pull out a + + // We have to assume here, that this instruction is not referenced anywhere else in the graph + List argList = sumBody.getChild(0).getOperands(); + List toRemove = new ArrayList<>(argList.size()); + + for (RewriterStatement stmt : argList) { + if (!checkSubgraphDependency(stmt, ownerId, checked)) + toRemove.add(stmt); + } + + if (!toRemove.isEmpty()) { + argList.removeAll(toRemove); + + if (argList.size() == 1) { + idxExpr.getOperands().set(1, argList.get(0)); + } + + RewriterStatement outerSum = RewriterStatement.multiArgInstr(ctx, "+", toRemove.toArray(RewriterStatement[]::new)); + List mul = new ArrayList<>(); + + for (RewriterStatement idx : idxExpr.getChild(0).getOperands()) { + RewriterStatement neg = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idx.getChild(0))*/idx.getChild(0)).consolidate(ctx); + RewriterStatement msum = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idx.getChild(1))*/idx.getChild(1), neg, RewriterStatement.literal(ctx, 1.0)); + mul.add(msum); + } + + mul.add(outerSum); + RewriterStatement mulStmt = RewriterStatement.multiArgInstr(ctx, "*", mul.toArray(RewriterStatement[]::new)); + + return RewriterStatement.multiArgInstr(ctx, "+", mulStmt, sum); + } + } + + return sum; + } + + // Returns true if the subgraph is dependent on the corresponding owner + private static boolean checkSubgraphDependency(RewriterStatement expr, UUID id, Map checked) { + Boolean b = checked.get(expr); + + if (b != null) + return b; + + if (expr.isInstruction() && expr.trueInstruction().equals("_idx")) { + UUID mid = (UUID) expr.getMeta("ownerId"); + boolean isDependent = id.equals(mid); + + if (isDependent) { + checked.put(expr, true); + return true; + } + } + + for (RewriterStatement stmt : expr.getOperands()) { + if (checkSubgraphDependency(stmt, id, checked)) { + checked.put(expr, true); + return true; + } + } + + checked.put(expr, false); + return false; + } + + private static boolean isDirectlyDependent(RewriterStatement child, UUID ownerId) { + if (child.isInstruction() && child.trueInstruction().equals("_idx")) { + UUID mid = (UUID) child.getMeta("_ownerId"); + return ownerId.equals(mid); + } + + return false; + } + + public static RewriterStatement foldConstants(RewriterStatement stmt, final RuleContext ctx) { + Map replaced = new HashMap<>(); + RewriterStatement ret = foldConstantsRecursively(stmt, ctx, replaced); + ret.prepareForHashing(); + ret.recomputeHashCodes(ctx); + return ret; + } + + private static RewriterStatement foldConstantsRecursively(RewriterStatement cur, final RuleContext ctx, Map alreadyFolded) { + if (!cur.isInstruction()) + return cur; + + RewriterStatement folded = alreadyFolded.get(cur); + + if (folded != null) + return folded; + + alreadyFolded.put(cur, cur); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, foldConstantsRecursively(cur.getChild(i), ctx, alreadyFolded)); + + cur.updateMetaObjects(el -> foldConstantsRecursively(el, ctx, alreadyFolded)); + + RewriterStatement ret = cur; + + switch (cur.trueInstruction()) { + case "+": + case "*": + case "min": + case "max": + ret = foldNaryReducible(cur, ctx); + break; + case "_EClass": + ret = foldEClass(cur, ctx); + break; + default: + if (cur.getOperands().size() == 1) + ret = foldUnary(cur, ctx); + break; + } + + ret.refreshReturnType(ctx); + alreadyFolded.put(cur, ret); + return ret; + } + + private static RewriterStatement foldEClass(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement lit = stmt.getLiteralStatement(); + if (lit != null) + return lit; + return stmt; + } + + private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final RuleContext ctx) { + List argList; + if (stmt.getChild(0).isArgumentList()) + argList = stmt.getChild(0).getOperands(); + else + argList = stmt.getOperands(); + + if (argList.isEmpty()) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.isInstruction() && (stmt.trueInstruction().equals("min") || stmt.trueInstruction().equals("max")) && argList.size() == 1 && !List.of("FLOAT", "INT", "BOOL").contains(argList.get(0).getResultingDataType(ctx))) + return stmt; + + if (argList.size() < 2) + return argList.get(0); + + int[] literals = IntStream.range(0, argList.size()).filter(i -> argList.get(i).isLiteral()).toArray(); + + if (literals.length == 1) { + Object literal = argList.get(literals[0]).getLiteral(); + if (literal instanceof Number) { + RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral((Number) literal, stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; + } + + // Check if is neutral element + if (ConstantFoldingUtils.isNeutralElement(argList.get(literals[0]).getLiteral(), stmt.trueInstruction())) { + RewriterStatement neutral = argList.get(literals[0]); + argList.remove(literals[0]); + + if (argList.size() == 1) + return argList.get(0); + else if (argList.isEmpty()) + return neutral; + } + } + + if (literals.length < 2) + return stmt; + + String rType = stmt.getResultingDataType(ctx); + + BiFunction foldingFunction = ConstantFoldingUtils.foldingBiFunction(stmt.trueInstruction(), rType); + + RewriterDataType foldedLiteral = new RewriterDataType(); + Number val = null; + + for (int literal : literals) + val = foldingFunction.apply(val, argList.get(literal)); + + + RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral(val, stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; + + foldedLiteral.as(val.toString()).ofType(rType).asLiteral(val).consolidate(ctx); + + argList.removeIf(RewriterStatement::isLiteral); + + if (argList.isEmpty() || !ConstantFoldingUtils.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction())) + argList.add(foldedLiteral); + + ConstantFoldingUtils.cancelOutNary(stmt.trueInstruction(), argList); + + if (argList.size() == 1) + return argList.get(0); + + return stmt; + } + + private static RewriterStatement foldUnary(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement child = stmt.getChild(0); + + if (!child.isLiteral()) + return stmt; + + boolean isFloat = stmt.getResultingDataType(ctx).equals("FLOAT"); + + switch (stmt.trueInstruction()) { + case "inv": + if (isFloat) + return RewriterStatement.literal(ctx, 1.0 / child.floatLiteral()); + else + return RewriterStatement.literal(ctx, 1L / child.intLiteral()); + case "-": + if (isFloat) + return RewriterStatement.literal(ctx, -child.floatLiteral()); + else + return RewriterStatement.literal(ctx, -child.intLiteral()); + } + + // Not implemented yet + return stmt; + } + + public static RewriterStatement cleanupUnecessaryIndexExpressions(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement mNew = cleanupIndexExprRecursively(stmt, ctx); + + if (mNew != null) + stmt.moveRootTo(mNew); + + recursivePostCleanup(mNew != null ? mNew : stmt); + + return mNew; + } + + private static RewriterStatement cleanupIndexExprRecursively(RewriterStatement cur, final RuleContext ctx) { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement mNew = cleanupIndexExprRecursively(cur.getChild(i), ctx); + + if (mNew != null) + cur.getOperands().set(i, mNew); + } + + return cleanupIndexExpr(cur); + } + + private static void recursivePostCleanup(RewriterStatement cur) { + for (RewriterStatement child : cur.getOperands()) + recursivePostCleanup(child); + + postCleanupIndexExpr(cur); + } + + private static RewriterStatement cleanupIndexExpr(RewriterStatement cur) { + if (!cur.isInstruction() || !cur.trueInstruction().equals("sum")) + return null; + + RewriterStatement base = cur; + cur = cur.getChild(0); + + if (!cur.isInstruction() || !cur.trueInstruction().equals("_idxExpr")) + return null; + + if (!cur.getChild(1).isInstruction() || !cur.getChild(1).trueInstruction().equals("ifelse") || !cur.getChild(1,2).isLiteral() || cur.getChild(1,2).floatLiteral() != 0.0D) + return null; + + RewriterStatement query = cur.getChild(1, 0); + + if (query.isInstruction() && query.trueInstruction().equals("==")) { + RewriterStatement idx1 = query.getChild(0); + RewriterStatement idx2 = query.getChild(1); + + if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) { + List indices = cur.getChild(0).getOperands(); + RewriterStatement indexFromUpperLevel = null; + if (idx1 == idx2) { + cur.getOperands().set(1, cur.getChild(1, 1)); + } else if (indices.contains(idx1)) { + boolean removed = indices.remove(idx2); + indexFromUpperLevel = removed ? null : idx2; + + if (removed) { + cur.getOperands().set(1, cur.getChild(1, 1)); + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + } + } else if (indices.contains(idx2)) { + indexFromUpperLevel = idx1; + } + + if (indexFromUpperLevel != null) { + cur.getOperands().set(1, cur.getChild(1, 1)); + final RewriterStatement fIdxUpperLevel = indexFromUpperLevel; + final RewriterStatement fIdxLowerLevel = idx1 == indexFromUpperLevel ? idx2 : idx1; + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(fIdxLowerLevel)) + cur2.getOperands().set(i, fIdxUpperLevel); + } + + return true; + }, true); + indices.remove(idx2); + } + + if (indices.isEmpty()) { + return cur.getChild(1); + } + } + } + + return base; + } + + // To unify ifelse (e.g. ifelse(a == b, a+b, a-b) => ifelse(a == b, a+a, a-b) + private static void postCleanupIndexExpr(RewriterStatement cur) { + if (!cur.isInstruction() || !cur.trueInstruction().equals("ifelse") || !cur.getChild(2).isLiteral() || cur.getChild(2).floatLiteral() != 0.0D) + return; + + RewriterStatement query = cur.getChild(0); + + if (query.isInstruction() && query.trueInstruction().equals("==")) { + RewriterStatement idx1 = query.getChild(0); + RewriterStatement idx2 = query.getChild(1); + + if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) { + // Then we just choose the first index + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + cur.getChild(2).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + } + } + } + + public static void renameIllegalVarnames(final RuleContext ctx, RewriterStatement... stmts) { + MutableInt matrixVarCtr = new MutableInt(0); + MutableInt scalarVarCtr = new MutableInt(0); + + Set varnames = new HashSet<>(); + for (RewriterStatement stmt : stmts) { + stmt.forEachPreOrder(cur -> { + if (cur.isInstruction()) + return true; + + varnames.add(cur.getId()); + return true; + }, false); + } + + for (RewriterStatement stmt : stmts) { + stmt.forEachPreOrder(cur -> { + if (cur.isInstruction() || cur.isLiteral()) + return true; + + boolean isMatrix = cur.getResultingDataType(ctx).equals("MATRIX"); + + if (cur.getId().equals("?")) { + cur.rename(getVarname(varnames, isMatrix ? matrixVarCtr : scalarVarCtr, isMatrix)); + return true; + } + + if (cur.getId().contains("_")) { + cur.rename(getVarname(varnames, isMatrix? matrixVarCtr : scalarVarCtr, isMatrix)); + } + + try { + UUID.fromString(cur.getId()); + // If it could parse, then we should rename + cur.rename(getVarname(varnames, isMatrix ? matrixVarCtr : scalarVarCtr, isMatrix)); + return true; + } catch (Exception e) { + // Then this is not a UUID + } + + return true; + }, false); + } + } + + private static String getVarname(Set existingNames, MutableInt mInt, boolean matrix) { + char origChar; + + if (matrix) + origChar = 'A'; + else + origChar = 'a'; + + char ch = (char)(origChar + mInt.getAndIncrement()); + + while (existingNames.contains(String.valueOf(ch))) + ch = (char)(origChar + mInt.getAndIncrement()); + + return String.valueOf(ch); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java new file mode 100644 index 00000000000..055e2691bfb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +public class StatementUtils { + public static RewriterStatement max(final RuleContext ctx, RewriterStatement... of) { + if (of.length == 1) + return of[0]; + + if (of.length == 2) + return new RewriterInstruction("max", ctx, of); + + throw new UnsupportedOperationException(); + } + + public static RewriterStatement min(final RuleContext ctx, RewriterStatement... of) { + if (of.length == 1) + return of[0]; + + if (of.length == 2) + return new RewriterInstruction("min", ctx, of); + + throw new UnsupportedOperationException(); + } + + public static RewriterStatement length(final RuleContext ctx, RewriterStatement matrix) { + if (!matrix.getResultingDataType(ctx).equals("MATRIX")) + throw new IllegalArgumentException(matrix.toParsableString(ctx)); + + return new RewriterInstruction("*", ctx, matrix.getNRow(), matrix.getNCol()); + } + + public static RewriterStatement add(final RuleContext ctx, RewriterStatement... terms) { + if (terms.length == 1) + return terms[0]; + + return new RewriterInstruction("+", ctx, new RewriterInstruction("argList", ctx, terms)); + } +} diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index e6fdf5db3cd..561a99a7d3a 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -47,6 +47,7 @@ import org.apache.sysds.utils.stats.RecompileStatistics; import org.apache.sysds.utils.stats.SparkStatistics; import org.apache.sysds.utils.stats.TransformStatistics; +import scala.Tuple2; import java.lang.management.CompilationMXBean; import java.lang.management.GarbageCollectorMXBean; @@ -338,6 +339,35 @@ public static void stopRunTimer() { public static long getRunTime() { return execEndTime - execStartTime; } + + private static HashMap appliedGeneratedRewrites = new HashMap<>(); + private static HashMap, Integer> appliedGeneratedRewritesCounts = new HashMap<>(); + private static boolean recordGeneratedRewrites = false; + private static String currentTestName = ""; + + public static void recordAppliedGeneratedRewrites(boolean record) { + recordGeneratedRewrites = record; + } + + public static void applyGeneratedRewrite(String rewrite) { + if (recordGeneratedRewrites) { + appliedGeneratedRewrites.compute(rewrite, (k, v) -> v == null ? 1 : v + 1); + if (!currentTestName.isEmpty()) + appliedGeneratedRewritesCounts.compute(new Tuple2<>(rewrite, currentTestName), (k, v) -> v == null ? 1 : v + 1); + } + } + + public static Map getAppliedRewrites() { + return appliedGeneratedRewrites; + } + + public static Map, Integer> getAdvancedAppliedRewrites() { + return appliedGeneratedRewritesCounts; + } + + public static void setCurrentTestName(String testName) { + currentTestName = testName; + } public static void reset() { diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 6b280301afb..8e496c189f8 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -31,6 +31,8 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -57,6 +59,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes.FType; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.compile.Dag; import org.apache.sysds.parser.ParseException; @@ -90,6 +94,7 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; +import scala.Tuple4; /** *

@@ -105,6 +110,51 @@ * */ public abstract class AutomatedTestBase { + protected static final boolean RECORD_GENERATED_REWRITES = false; + protected static final boolean ALLOW_GENERATED_REWRITES = false; + protected static final String BASE_DATA_DIR = null; + + + ///// THESE SHOULD NOT BE MODIFIED ///// + private static String currentTestName = ""; + + + static { + RewriterRuntimeUtils.setupIfNecessary(); + + if (RECORD_GENERATED_REWRITES) { + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + StringBuilder csvBuilder2 = new StringBuilder(); + csvBuilder2.append("Rewrite;Count\n"); + + Statistics.getAppliedRewrites().forEach((k, v) -> { + csvBuilder2.append(k); + csvBuilder2.append(';'); + csvBuilder2.append(v); + csvBuilder2.append('\n'); + }); + + StringBuilder csvBuilder3 = new StringBuilder(); + csvBuilder3.append("Rewrite;TestName;Count\n"); + + Statistics.getAdvancedAppliedRewrites().forEach((k, v) -> { + csvBuilder3.append(k._1); + csvBuilder3.append(';'); + csvBuilder3.append(k._2); + csvBuilder3.append(';'); + csvBuilder3.append(v); + csvBuilder3.append('\n'); + }); + + try { + Files.writeString(Paths.get(BASE_DATA_DIR + "applied_rewrites.csv"), csvBuilder2.toString()); + Files.writeString(Paths.get(BASE_DATA_DIR + "rewrite_info.csv"), csvBuilder3.toString()); + } catch (IOException e) { + e.printStackTrace(); + } + })); + } + } private static final Log LOG = LogFactory.getLog(AutomatedTestBase.class.getName()); @@ -1139,6 +1189,9 @@ protected void runRScript() { */ protected void runRScript(boolean newWay) { + if (RewriterRuntimeUtils.interceptAll) + return; + String executionFile = sourceDirectory + selectedTest + ".R"; if(fullRScriptName != null) executionFile = fullRScriptName; @@ -1388,6 +1441,21 @@ protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpecte String errMessage, int maxSparkInst) { try{ final List out = new ArrayList<>(); + + if (RECORD_GENERATED_REWRITES) { + if (currentTestName == null || !currentTestName.equals(this.getClass().getSimpleName())) { + currentTestName = this.getClass().getSimpleName(); + } + + Statistics.reset(); + RewriteAutomaticallyGenerated.totalTimeNanos = 0; + RewriteAutomaticallyGenerated.callCount = 0; + RewriteAutomaticallyGenerated.maxTimeNanos = -1; + + Statistics.recordAppliedGeneratedRewrites(true); + Statistics.setCurrentTestName(currentTestName); + } + Thread t = new Thread( () -> out.add(runTestWithTimeout(newWay, exceptionExpected, expectedException, errMessage, maxSparkInst)), "TestRunner_main"); @@ -1437,6 +1505,10 @@ private ByteArrayOutputStream runTestWithTimeout(boolean newWay, boolean excepti cleanupScratchSpace(); ArrayList args = new ArrayList<>(); + if (ALLOW_GENERATED_REWRITES) { + args.add("-applyGeneratedRewrites"); + } + // setup arguments to SystemDS if(DEBUG) { args.add("-Dsystemds.logging=trace"); diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java index 534b058425a..42bf618f5a2 100644 --- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java @@ -49,7 +49,7 @@ public L2SVMTest(int rows, int cols, double sp, boolean intercept) { numRecords = rows; numFeatures = cols; sparsity = sp; - intercept = this.intercept; + this.intercept = intercept; } @Parameters diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java new file mode 100644 index 00000000000..add648bbc62 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class RewriterNormalFormTests { + protected static final Log LOG = LogFactory.getLog(RewriterNormalFormTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X + @Test + public void testUnnecessaryVectorize() { + RewriterStatement stmt1 = RewriterUtils.parse("/(const(A, 1.0), A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("/(1.0, A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(1.0, A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseDatagenAndBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(rand(nrow(A), ncol(A), -1.0, 1.0), a)", ctx, "MATRIX:A", "FLOAT:a", "LITERAL_FLOAT:1.0,-1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), -(a), a)", ctx, "MATRIX:A", "FLOAT:a"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testFuseDatagenAndMinusOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(rand(nrow(A), ncol(A), -2.0, 1.0))", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0,-2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), -1.0, 2.0)", ctx, "MATRIX:A", "LITERAL_FLOAT:-1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testCanonicalizeMatrixMultScalarAdd() { + RewriterStatement stmt1 = RewriterUtils.parse("+(eps, %*%(A, t(B)))", ctx, "MATRIX:A,B", "FLOAT:eps"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(A, t(B)), eps)", ctx, "MATRIX:A,B", "FLOAT:eps"); + + assert match(stmt1, stmt2); + } + + @Test + public void testCanonicalizeMatrixMultScalarAdd2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(%*%(A, t(B)), eps)", ctx, "MATRIX:A,B", "FLOAT:eps"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(A, t(B)), -(eps))", ctx, "MATRIX:A,B", "FLOAT:eps"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(1.0, *(A,B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDistributiveBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, *(B,A))", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(-(1.0,B), A)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyBushyBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,*(B, %*%(C, colVec(D))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(*(A,B), %*%(C, colVec(D)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).match(); + } + + @Test + public void testSimplifyUnaryAggReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryAggregates() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(rowSums(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("as.scalar(*(A,a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(as.scalar(A),a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownUnaryAggTransposeOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("t(rowSums(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownCSETransposeScalarOperation() { + // Introduce a dummy instruction * as I don't support the assignment operator + RewriterStatement stmt1 = RewriterUtils.parse("*(t(A), t(sq(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(t(A), sq(t(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownSumBinaryMult() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(a,A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyTraceMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(A,B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifySlicedMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("[](%*%(A,B), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(rowVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testRemoveUnnecessaryReorgOperation2() { + RewriterStatement stmt1 = RewriterUtils.parse("rev(rev(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyTransposeAggBinBinaryChains() { + RewriterStatement stmt1 = RewriterUtils.parse("t(+(%*%(t(A),t(B)), C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(B,A), t(C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryMinus() { + RewriterStatement stmt1 = RewriterUtils.parse("-(-(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseLogNzUnaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(!=(A,0.0), log(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("log_nz(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseLogNzBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(!=(A,0.0), log(A, a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("log_nz(A, a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testSimplifyNotOverComparisons() { + RewriterStatement stmt1 = RewriterUtils.parse("!(>(A,B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("<=(A,B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + ///// DYNAMIC SIMPLIFICATIONS ////// + + @Test + public void testRemoveEmptyRightIndexing() { + // We do not directly support the specification of nnz, but we can emulate such a matrix by multiplying with 0 + RewriterStatement stmt1 = RewriterUtils.parse("[](*(A, 0.0), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("const(colVec(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryRightIndexing() { + RewriterStatement stmt1 = RewriterUtils.parse("[](colVec(A), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryReorgOperation3() { + RewriterStatement stmt1 = RewriterUtils.parse("t(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("cellMat(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testRemoveUnnecessaryOuterProduct() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, %*%(colVec(B), const(t(colVec(B)), 1.0)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, colVec(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryIfElseOperation() { + // Ifelse is not directly supported yet but only on scalars. Thus, we will our index expression syntax to reflect that statement + // Note that we "cheated" here a bit as we index using nrow(A) and ncol(A). We would not get a match if we used nrow(B)... + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), ifelse(TRUE, [](A, $1, $2), [](B, $1, $2)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseDatagenAndReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("t(rand(i, 1, 0.0, 1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(1, i, 0.0, 1.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyColwiseAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyRowwiseAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We don't have broadcasting semantics + @Test + public void testSimplifyColSumsMVMult() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(colVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(B)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We don't have broadcasting semantics + @Test + public void testSimplifyRowSumsMVMult() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(rowVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), t(rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyUnnecessaryAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyAggregate() { + // We emulate an empty matrix by multiplying by zero + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("0.0", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyReorgOperation() { + // We emulate an empty matrix by multiplying by zero + RewriterStatement stmt1 = RewriterUtils.parse("t(*(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(t(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // This is a hacky workaround + @Test + public void testSimplifyEmptyMatrixMult() { + // We emulate an empty matrix by multiplying by zero + // Note that we pass the dimension info of the matrix multiply to get the same e-class assertions + RewriterStatement stmt1 = RewriterUtils.parse("%*%(*(A, 0.0), B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(%*%(A, B), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + // We need to explicitly assert A and B + stmt2.givenThatEqual(stmt2.getChild(0, 1).getNRow(), stmt2.getChild(0, 0).getNCol(), ctx); + stmt2.recomputeAssertions(); + + assert match(stmt1, stmt2, true); + } + + @Test + public void testSimplifyEmptyMatrixMult2() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyScalarMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("*(colVec(A), as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDistributiveMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("+(%*%(A, B), %*%(A, C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(A, +(B, C)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // Note that we did not implement the overloaded diag(A) operation as we defined diag(A) as setting all other entries to zero (which is not how it is actually handled by SystemDS) + // In this case, we obtain the same rewrite, even though the diag operation is different + @Test + public void testSimplifySumDiagToTrace() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(diag(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // Note that we did not implement the overloaded diag(A) operation as we defined diag(A) as setting all other entries to zero (which is not how it is actually handled by SystemDS) + // In this case, we obtain the same equivalence, but in case of our implementation the rewrite would not be beneficial + @Test + public void testPushdownBinaryOperationOnDiag() { + RewriterStatement stmt1 = RewriterUtils.parse("*(diag(A), a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("diag(*(A, a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownSumOnAdditiveBinary() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("+(sum(A), sum(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + // We need to assert that the dimensions are the same, which we currently cannot do implicitly through an expression + stmt2.givenThatEqualDimensions(stmt2.getChild(0, 0), stmt2.getChild(1, 0), ctx); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDotProductSum() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(sq(colVec(A))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseSumSquared() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(sq(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("sumSq(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseAxpyBinaryOperationChain() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, *(a, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("+*(A, a, B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseAxpyBinaryOperationChain2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, *(a, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("-*(A, a, B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testReorderMinusMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(-(t(A)), B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("-(%*%(t(A), B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifySumMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(t(colSums(A)), rowSums(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation3() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testSimplifyScalarMVBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, colVec(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, as.scalar(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyNnzComputation() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(!=(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("_nnz(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("nrow([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("5", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation2() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("ncol([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("5", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation3() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("length([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("25", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,25", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + private boolean match(RewriterStatement stmt1, RewriterStatement stmt2) { + return match(stmt1, stmt2, false); + } + + private boolean match(RewriterStatement stmt1, RewriterStatement stmt2, boolean debug) { + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + return RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).debug(debug).match(); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java new file mode 100644 index 00000000000..3bab52ba8b0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; + +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import java.util.function.Function; + +public class RewriterRuleValidationTest { + + public static String RAW_FILE_PATH; // Must be specified + public static String FILE_PATH; // Must be specified + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + + //@Test + public void test() { + try { + List lines = Files.readAllLines(Paths.get(RAW_FILE_PATH)); + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + + int ctr = 0; + for (RewriterRule rule : ruleSet.getRules()) { + if (ctr % 10 == 0) + System.out.println("Done: " + ctr + " / " + ruleSet.getRules().size()); + + ctr++; + try { + System.out.println(rule.getStmt1().toParsableString(ctx) + " => " + rule.getStmt2().toParsableString(ctx)); + long preCost = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long postCost = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + System.out.println(ruleCreator.registerRule(rule, preCost, postCost, true, canonicalConverter)); + } catch (Exception e) { + e.printStackTrace(); + } + } + //System.out.println(ruleSet.toJavaCode("GeneratedRewriteClass", false)); + String serialized = ruleCreator.getRuleSet().serialize(); + //System.out.println(serialized); + + try (FileWriter writer = new FileWriter(FILE_PATH)) { + writer.write(serialized); + } catch (IOException ex) { + ex.printStackTrace(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java new file mode 100644 index 00000000000..58c324c7f22 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java @@ -0,0 +1,1751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterDatabase; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristic; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +public class RewriterStreamTests { + protected static final Log LOG = LogFactory.getLog(RewriterStreamTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testAdditionFloat1() { + RewriterStatement stmt = RewriterUtils.parse("+(+(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"), stmt)); + } + + @Test + public void testAdditionFloat2() { + RewriterStatement stmt = RewriterUtils.parse("+(1, +(a, b))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"), stmt)); + } + + @Test + public void testAdditionMatrix1() { + RewriterStatement stmt1 = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(+(B, 1), A)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSubtractionFloat1() { + RewriterStatement stmt = RewriterUtils.parse("+(-(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSubtractionFloat2() { + RewriterStatement stmt = RewriterUtils.parse("+(1, -(a, -(b, c)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b,c", "LITERAL_INT:0,1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + // Fusion will no longer be pursued + /*@Test + public void testFusedPlanMatrixGeneration() { + RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + } + + @Test + public void testFusedPlanAggregationGeneration() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(/(A, B), B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + } + + @Test + public void testFusedPlanAdvancedAggregationGeneration() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(t(A), B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + }*/ + + @Test + public void testReorgEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTraceEquivalence1() { + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(t(A), B))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testTraceEquivalence2() { + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testTraceEquivalence3() { + RewriterStatement stmt = RewriterUtils.parse("trace(*(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(diag(A), diag(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testAggEquivalence() { + RewriterStatement stmt = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(colSums(A), t(rowSums(B))))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSumEquality6() { + RewriterStatement stmt = RewriterUtils.parse("sum(+(B, sum(*(a, A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(+(B, *(a, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSumEquality() { + RewriterStatement stmt = RewriterUtils.parse("sum(+(B, sum(*(a, A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + //RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(+(B, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt3 = RewriterUtils.parse("sum(+(B, *(a, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + stmt = canonicalConverter.apply(stmt); + //stmt2 = canonicalConverter.apply(stmt2); + stmt3 = canonicalConverter.apply(stmt3); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt3.toParsableString(ctx, true)); + LOG.info("=========="); + //LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt3, stmt)); + //assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testArgListSelectionPushdown() { + RewriterStatement stmt = RewriterUtils.parse("[](+(A, 1), 1, 1)", ctx, "MATRIX:A", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](A, 1, 1), 1)", ctx, "MATRIX:A", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testDistributiveLaw1() { + RewriterStatement stmt = RewriterUtils.parse("*(+(a, b), c)", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, c), *(b, c))", ctx, "FLOAT:a,b,c"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testDistributiveLaw2() { + RewriterStatement stmt = RewriterUtils.parse("*(a, +(b, c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, b), *(a, c))", ctx, "FLOAT:a,b,c"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testEClassProperties() { + RewriterStatement stmt = RewriterUtils.parse("*(+(A, B), nrow(A))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("*(+(A, B), nrow(B))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testRealExamples1() { + RewriterStatement stmt1 = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(V), U)", ctx, "MATRIX:U,V"); + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + //TopologicalSort.sort(stmt1, ctx); + //TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test() { + RewriterStatement stmt = RewriterUtils.parse("t(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "FLOAT:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert !stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+(0.0,*(2,%*%(t(X),T)))", ctx, "MATRIX:T,X", "FLOAT:0.0", "INT:2"); + stmt = canonicalConverter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("+(+(A,X),t(X))", ctx, "MATRIX:X,A"); + stmt = canonicalConverter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx)); + } + + @Test + public void test4() { + RewriterDatabase db = new RewriterDatabase(); + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + db.insertEntry(ctx, stmt); + + assert !db.insertEntry(ctx, stmt2); + } + + @Test + public void testForFailure() { + RewriterStatement stmt = RewriterUtils.parse("[](hIndex,i,i,1,1)", ctx, "MATRIX:hIndex", "INT:i", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void testTypeConversions() { + RewriterStatement stmt1 = RewriterUtils.parse("+(TRUE, 1)", ctx, "LITERAL_BOOL:TRUE", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(1, 1)", ctx, "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testCSE() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+($1:*(a, b), $1)", ctx, "FLOAT:a,b"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + RewriterDatabase db = new RewriterDatabase(); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + db.insertEntry(ctx, stmt1); + + assert !db.insertEntry(ctx, stmt2); + } + + @Test + public void testExactMatch() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+($1:*(a, b), $1)", ctx, "FLOAT:a,b"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + assert stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2)); + } + + //@Test + public void testMinEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("min(min(A), min(B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("min(A, B)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(t(A))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + /*@Test + public void testSimpleAlgebra1() { + RewriterStatement stmt1 = RewriterUtils.parse("-(X, *(Y, X))", ctx, "MATRIX:X,Y"); + RewriterStatement stmt2 = RewriterUtils.parse("*(-(1, Y), X)", ctx, "MATRIX:X,Y", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + }*/ + + @Test + public void testSimpleAlgebra2() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(*(X, 7))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("*(diag(X), 7)", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + /*@Test + public void testSimpleAlgebra3() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(+(X, 7), Y))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("+(+(sum(X), 7), sum(Y))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + }*/ + + @Test + public void testSimpleAlgebra4() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(-(+(+(X, 7), Y)))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + RewriterStatement matX = RewriterUtils.parse("X", ctx, "MATRIX:X"); + RewriterStatement matY = RewriterUtils.parse("Y", ctx, "MATRIX:Y"); + Map vars = new HashMap<>(); + vars.put("X", matX); + vars.put("Y", matY); + RewriterStatement stmt2 = RewriterUtils.parse("-(+(sum(+(X, 7)), sum(Y)))", ctx, vars, "LITERAL_INT:7"); + stmt2.givenThatEqual(vars.get("X").getNCol(), vars.get("Y").getNCol(), stmt2, ctx); + stmt2.givenThatEqual(vars.get("X").getNRow(), vars.get("Y").getNRow(), stmt2, ctx); + stmt2 = stmt2.recomputeAssertions(); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleSumPullOut() { + RewriterStatement stmt1 = RewriterUtils.parse("-(sum(+(A, 7)))", ctx, "MATRIX:A", "LITERAL_FLOAT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(+(-(A), -7))", ctx, "MATRIX:A", "LITERAL_FLOAT:-7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleInverseEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("inv(A)", ctx, "MATRIX:A,B", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("-(inv(-(A)))", ctx, "MATRIX:A,B", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + //@Test + public void testBackrefInequality() { + // Some example where _backRef() is not the same as another one + // As we need to compare to the meta-data + assert false; + } + + @Test + public void myTest() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(-(X, 7))", ctx, "MATRIX:X,Y", "LITERAL_INT:1,7", "INT:a", "LITERAL_FLOAT:7.0"); + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void myTest2() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(_idxExpr(_idx(1, 7), -(a)))", ctx, "MATRIX:X,Y", "LITERAL_INT:1,7", "INT:a"); + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void myTest3() { + RewriterStatement stmt = RewriterUtils.parse("%*%(X,[](B,1,ncol(X),1,ncol(B)))", ctx, "MATRIX:X,B,intercept", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest4() { + RewriterStatement stmt = RewriterUtils.parse("*(CBind(t(KM),KM_cols_select),KM_cols_select)", ctx, "MATRIX:KM,KM_cols_select"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest5() { + RewriterStatement stmt = RewriterUtils.parse("*(CBind(A, A),A)", ctx, "MATRIX:A"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest6() { + RewriterStatement stmt = RewriterUtils.parse("rowSums(<=(D,minD))", ctx, "MATRIX:D,minD"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest7() { + String stmtStr = "MATRIX:combined\n" + + "FLOAT:int0,int496,int236,int618\n" + + "LITERAL_INT:1,2\n" + + "INT:parsertemp71754,int497,int280\n" + + "&(RBind(!=([](combined,1,-(parsertemp71754,int497),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined))),rand(1,1,int0,int496)),RBind(rand(1,1,int618,int236),!=([](combined,1,-(parsertemp71754,int280),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined)))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest8() { + String stmtStr = "MATRIX:prec_chol,X,mu\n" + + "INT:i,k\n" + + "LITERAL_INT:1,5\n" + + "%*%(X,[](prec_chol,1,*(i,ncol(X)),1,5))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest9() { + String stmtStr = "MATRIX:A,scale_X,shift_X,parsertemp282257,parsertemp282256,parsertemp282259,parsertemp282258\n" + + "INT:m_ext\n" + + "LITERAL_INT:1\n" + + "+(%*%(diag(scale_X),t(+(%*%(parsertemp282256,A),%*%(shift_X,A)))),%*%(shift_X,[](t(+(parsertemp282257,parsertemp282258)),m_ext,m_ext,1,nrow(parsertemp282259))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest10() { + String stmtStr = "MATRIX:P,minD,D,X\n" + + "/(%*%(t(/(<=(D,minD),rowSums(P))),X),t(colSums(/(<=(D,minD),rowSums(P)))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void testConstantFolding1() { + RewriterStatement stmt1 = RewriterUtils.parse("*(1, A)", ctx, "MATRIX:A", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, 0)", ctx, "MATRIX:A", "LITERAL_INT:0"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding3() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, *(3, -(1, 1)))", ctx, "MATRIX:A", "LITERAL_INT:1,3"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding4() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), 0, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testAdvancedEquivalence1() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, -7))", ctx, "MATRIX:A", "LITERAL_FLOAT:-7"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(-(A, 7))", ctx, "MATRIX:A", "LITERAL_FLOAT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("/(*(A, A), B)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("/(*(A, A), sum(B))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiagEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(diag(A))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("diag(A)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRIXInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, [](B, 1, nrow(A), 1, ncol(A)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(A, B)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void convergenceTest() { + String stmtStr = "MATRIX:dl_matrix\n" + + "INT:i,j,46307663-5c68-48ba-aa86-c1c36de45dbe\n" + + "LITERAL_INT:1,2\n" + + "[](dl_matrix,+(i,-(2)),-(i,2),1,1)"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void someTest() { + RewriterStatement stmt1 = RewriterUtils.parse("+([](%*%(A,B),151,151,1,ncol(B)),C)", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](C,151,151,1,ncol(B)),%*%(A,B))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void my_Test() { + RewriterStatement stmt1 = RewriterUtils.parse("[](A, 1, 1, 151, 151)", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testSumEquality2() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colSums(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(A))", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality3() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(colSums(A), rowSums(B)))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality4() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(*(colVec(A), colVec(A))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality5() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums([](A, 1, nrow(A), 1, 1))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("[](A, 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleConvergence() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(a)", ctx, "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testImplicitInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("+([](A,1, nrow(A), 1, 1), B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](A,1, nrow(A), 1, 1), [](B, 1, nrow(B), 1, 1))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTraceEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(t(S),R))", ctx, "MATRIX:S,R", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(S,R))", ctx, "MATRIX:S,R", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(A,*(b, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, %*%(A, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info(stmt1.getAssertions(ctx)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + LOG.info(stmt2.getAssertions(ctx)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMEquivalence2() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(*(t(rowVec(A)), colVec(B))))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), colVec(B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testColSumEquivalence4() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(A, b))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, colSums(A))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testColSumEquivalence5() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(A, b))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, colSums(A))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testZeroElimination() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,0.0)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMScalarPullout() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(*(A, b), B)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, %*%(A, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + assert cost2 == cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong() { + RewriterStatement stmt1 = RewriterUtils.parse("*(sum(colVec(A)),colSums(B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(colVec(A),colSums(B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong2() { + RewriterStatement stmt1 = RewriterUtils.parse("*(a,1.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + RewriterStatement newStmt = canonicalConverter.apply(stmt1); + LOG.info(newStmt); + LOG.info(stmt1); + //stmt2 = canonicalConverter.apply(stmt2); + + /*LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));*/ + } + + //@Test + public void testRev() { + RewriterStatement stmt1 = RewriterUtils.parse("rev(rev(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTrace() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(B,B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + stmt1.compress(); + stmt2.compress(); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused1() { + RewriterStatement stmt1 = RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("-(1.0, *(A, B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(a, 1-*(A, B))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("-(1.0, -(*(A, B), a))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused3() { + RewriterStatement stmt1 = RewriterUtils.parse("log_nz(A)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("*(!=(0.0, A), log(A))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused4() { + RewriterStatement stmt1 = RewriterUtils.parse("log_nz(A, a)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("*(!=(0.0, A), log(A, a))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused5() { + RewriterStatement stmt1 = RewriterUtils.parse("sq(1-*(A,A))", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testFused6() { + RewriterStatement stmt1 = RewriterUtils.parse("/(A,A)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("/(A,rev(A))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused7() { + RewriterStatement stmt1 = RewriterUtils.parse("+*(A,a,B)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, B), A)", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(!=(0.0, A))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_nnz(A)", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFusedCompilation() { + RewriterStatement stmt1 = RewriterUtils.parse("+(a,*2(1-*(B,B)))", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testSum() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,A))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRowSums() { + RewriterStatement stmt1 = RewriterUtils.parse("*(rowSums(/(a,C)),b)", ctx, "MATRIX:A,B,C", "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("rowSums(/(*(a,b),C))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRowSums2() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(A,+(B,1.0)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(rowSums(A), rowSums(*(B,A)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDistrib3() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,+(B,1.0))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(A, *(B,A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRev2() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(rev(A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,*(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c"); + RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(+(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiag1() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(+(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("+(diag(A), diag(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + assert cost1 > cost2; + } + + @Test + public void testDiag2() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(diag(A))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiag3() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(diag(A), diag(B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("*(diag(A), diag(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstFold() { + RewriterStatement stmt1 = RewriterUtils.parse("-(+(1.0,a), 1.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + //@Test + public void testConst() { + RewriterStatement stmt1 = RewriterUtils.parse("min(const(A, a))", ctx, "FLOAT:a", "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMin() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, min(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + + LOG.info("Cost1: " + cost1); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testBoolDiag() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(!=(A,A))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + + LOG.info("Cost1: " + cost1); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testWrong3() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, /(A,C))", ctx, "MATRIX:A,C"); + RewriterStatement stmt2 = RewriterUtils.parse("*(sum(*(C,A)), A)", ctx, "MATRIX:A,C"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong4() { + // TODO: Rule "Element selection pushdown" seems to be an issue here + RewriterStatement stmt1 = RewriterUtils.parse("/(A, rev(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("/(A, A)", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong5() { + RewriterStatement stmt1 = RewriterUtils.parse("*2(-(B,B))", ctx, "MATRIX:B"); + RewriterStatement stmt2 = RewriterUtils.parse("*2(-(a, B))", ctx, "MATRIX:B", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong6() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(+(A,A)), B)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(A), +(B, B))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testWrong7() { + RewriterStatement stmt1 = RewriterUtils.parse("*(+(B,B),A)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(A), +(B, B))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testConstInequivality() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(const(A, 0.0), A)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality7() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a, A))", ctx, "MATRIX:A", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(const(A,1.0))", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("length(A)", ctx, "MATRIX:A", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSparsityComparison() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(A, B),*(A, C))", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, +(B, C))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert can2.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testTEST() { + RewriterStatement stmt1 = RewriterUtils.parse("t(/(<=(A,B),rowSums(<=(C,B))))", ctx, "MATRIX:A,B,C,D,E", "LITERAL_FLOAT:1.0", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java new file mode 100644 index 00000000000..a34a73b3774 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class RewriterTopologySortTests { + protected static final Log LOG = LogFactory.getLog(RewriterTopologySortTests.class.getName()); + private static RuleContext ctx; + private static Function converter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + converter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testSimpleEquivalence1() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(a, c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), *(c, a))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence2() { + // Here, a and b are indistinguishable + // Thus, the topological sort has to decide a random but consistent order + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), *(b, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence3() { + RewriterStatement stmt = RewriterUtils.parse("+(-(*(a, b)), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), -(*(b, a)))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence4() { + RewriterStatement stmt = RewriterUtils.parse("+(*(-(a), b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, -(b)), *(b, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence5() { + RewriterStatement stmt = RewriterUtils.parse("+(1, 2)", ctx, "LITERAL_INT:1,2"); + RewriterStatement stmt2 = RewriterUtils.parse("+(2, 1)", ctx, "LITERAL_INT:1,2"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence6() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(*(a, b), c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(*(a, b), c), *(a, b))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence7() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(/(a, b), /(b, a)))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(/(a, b), /(b, a)), *(a, b))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence9() { + RewriterStatement stmt = RewriterUtils.parse("+(*(-(a), b), *(a, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, -(b)), *(a, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence10() { + RewriterStatement stmt = RewriterUtils.parse("+(argList(*(argList(a,b)),*(argList(a,inv(b),b,inv(a)))))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(*(argList(a,inv(b),b,inv(a))),*(argList(a,b))))", ctx, "FLOAT:a,b,c"); + TopologicalSort.sort(stmt, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void test4() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(A, A))", ctx, "MATRIX:A"); + stmt = converter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void test5() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(_idxExpr($1:_idx(1,_EClass(argList(nrow(A),nrow(B)))),*(argList([](B,$1,$1),[](A,$1,$1)))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(_idxExpr($1:_idx(1,_EClass(argList(nrow(B),nrow(A)))),*(argList([](B,$1,$1),[](A,$1,$1)))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + LOG.info(stmt1.toParsableString(ctx)); + LOG.info(stmt2.toParsableString(ctx)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex1() { + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(V),nrow(U)))),*(argList([](V,$3,$1),[](U,$3,$2))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(U),nrow(V)))),*(argList([](U,$3,$2),[](V,$3,$1))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex2() { + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(V),nrow(U)))),1.0)))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(U),nrow(V)))),1.0)))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex3() { + RewriterStatement stmt1 = RewriterUtils.parse("_m(ncol(V),ncol(U),as.float(_EClass(argList(nrow(V),nrow(U))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_m(ncol(V),ncol(U),as.float(_EClass(argList(nrow(U),nrow(V))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimple() { + RewriterStatement stmt = RewriterUtils.parse("*(argList(a, sum(b), a))", ctx, "FLOAT:a,b"); + TopologicalSort.sort(stmt, ctx); + + String parsableString = stmt.toParsableString(ctx); + LOG.info(parsableString); + assert "*(argList(a,a,sum(b)))".equals(parsableString); + } + + @Test + public void test2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(argList(_EClass(argList(1, ncol(A), ncol(B))), _EClass(argList(nrow(C),nrow(B),nrow(A)))))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(_EClass(argList(1, ncol(A), ncol(B))), _EClass(argList(nrow(A),nrow(C),nrow(B)))))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java new file mode 100644 index 00000000000..6f9db682bcb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +public class AssertionTests { + protected static final Log LOG = LogFactory.getLog(AssertionTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterAssertions assertion = new RewriterAssertions(ctx); + RewriterStatement stmt1 = RewriterUtils.parse("*(*(nrow(A), nrow(B)), *(nrow(C), nrow(A)))", ctx, "MATRIX:A,B,C"); + RewriterStatement nrowA = stmt1.getOperands().get(0).getOperands().get(0); + RewriterStatement nrowB = stmt1.getOperands().get(0).getOperands().get(1); + RewriterStatement nrowC = stmt1.getOperands().get(1).getOperands().get(0); + RewriterStatement nrowA2 = stmt1.getOperands().get(1).getOperands().get(1); + + assert assertion.addEqualityAssertion(nrowA, nrowC, stmt1); + LOG.info(assertion.getAssertions(nrowA)); + + assert !assertion.addEqualityAssertion(nrowA, nrowC, stmt1); + LOG.info(assertion.getAssertions(nrowC)); + + assert assertion.addEqualityAssertion(nrowC, nrowB, stmt1); + LOG.info(assertion.getAssertions(nrowC)); + + LOG.info(assertion.getAssertions(nrowA2)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java new file mode 100644 index 00000000000..481a896bad5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.Test; + +public class CodeExecutionTest { + protected static final Log LOG = LogFactory.getLog(CodeExecutionTest.class.getName()); + + @Test + public void test() { + String str = "X = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "Y = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "R = X*Y\n" + + "print(lineage(R))"; + DMLScript.APPLY_GENERATED_REWRITES = true; + DMLExecutor.executeCode(str, false, "-applyGeneratedRewrites"); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java new file mode 100644 index 00000000000..5f0c6da1e0f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.codegen.CodeGenCondition; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +public class CodeGenConditionTests { + protected static final Log LOG = LogFactory.getLog(CodeGenConditionTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "t(t(A))\n" + + "=>\n" + + "A"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule), 1, ctx); + } + + @Test + public void test2() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "t(t(A))\n" + + "=>\n" + + "A"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + String ruleStr2 = "MATRIX:A,B\n" + + "\n" + + "+(t(A), t(B))\n" + + "=>\n" + + "t(+(A, B))"; + + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + String ruleStr3 = "MATRIX:A,B\n" + + "\n" + + "%*%(t(A), t(B))\n" + + "=>\n" + + "t(%*%(B, A))"; + + RewriterRule rule3 = RewriterUtils.parseRule(ruleStr3, ctx); + + Map fNames = new HashMap<>(); + fNames.put(rule, "rule1"); + fNames.put(rule2, "rule2"); + fNames.put(rule3, "rule3"); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule, rule2, rule3), 1, ctx); + LOG.info(CodeGenCondition.getSelectionString(cgcs, 0, fNames, ctx)); + } + + @Test + public void test3() { + String ruleStr = "MATRIX:A\nFLOAT:b\n" + + "\n" + + "!=(-(b,rev(A)),A)\n" + + "=>\n" + + "!=(A,-(b,A))"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + String ruleStr2 = "MATRIX:A,B\n" + + "\n" + + "!=(-(B,rev(A)),A)\n" + + "=>\n" + + "!=(A,-(B,A))"; + + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + String ruleStr3 = "MATRIX:A,B,C\n" + + "\n" + + "+(*(A,C),*(A,B))\n" + + "=>\n" + + "*(A,+(B,C))"; + + RewriterRule rule3 = RewriterUtils.parseRule(ruleStr3, ctx); + + String ruleStr4 = "MATRIX:A,B,C\n" + + "\n" + + "+(*(A,C),*(B,A))\n" + + "=>\n" + + "*(A,+(B,C))"; + + RewriterRule rule4 = RewriterUtils.parseRule(ruleStr4, ctx); + + String ruleStr5 = "MATRIX:B,C\nFLOAT:a\n" + + "\n" + + "+(*(a,C),*(B,a))\n" + + "=>\n" + + "*(a,+(B,C))"; + + RewriterRule rule5 = RewriterUtils.parseRule(ruleStr5, ctx); + + Map fNames = new HashMap<>(); + fNames.put(rule, "rule1"); + fNames.put(rule2, "rule2"); + fNames.put(rule3, "rule3"); + fNames.put(rule4, "rule4"); + fNames.put(rule5, "rule5"); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule, rule2, rule3, rule4, rule5), 1, ctx); + LOG.info(cgcs); + LOG.info(CodeGenCondition.getSelectionString(cgcs, 0, fNames, ctx)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java new file mode 100644 index 00000000000..b439b92dd5c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.codegen.RewriterCodeGen; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.parser.DataIdentifier; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.function.Function; + +public class CodeGenTests { + protected static final Log LOG = LogFactory.getLog(CodeGenTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt1 = RewriterUtils.parse("+(1, 1)", ctx, "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("2", ctx, "LITERAL_INT:2"); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + Hop l = new LiteralOp(1); + Hop add = new BinaryOp("test", Types.DataType.SCALAR, Types.ValueType.INT64, Types.OpOp2.PLUS, l, l); + Hop result = f.apply(add); + + assert result instanceof LiteralOp && ((LiteralOp) result).getLongValue() == 2; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test2() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("+(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(+(A, B))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop add = new BinaryOp("test", Types.DataType.MATRIX, Types.ValueType.FP64, Types.OpOp2.PLUS, tA, tB); + Hop result = f.apply(add); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof BinaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test3() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("^(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(^(A, B))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop pow = new BinaryOp("test", Types.DataType.MATRIX, Types.ValueType.FP64, Types.OpOp2.POW, tA, tB); + Hop result = f.apply(pow); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof BinaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test4() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(%*%(B, A))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop matmul = HopRewriteUtils.createMatrixMultiply(tA, tB); + Hop result = f.apply(matmul); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && HopRewriteUtils.isMatrixMultiply(result.getInput(0)); + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test5() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(t(A))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(colSums(A))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop rowSums = HopRewriteUtils.createAggUnaryOp(tA, Types.AggOp.SUM, Types.Direction.Row); + Hop result = f.apply(rowSums); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof AggUnaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void generateExample() { + String ruleStr = "MATRIX:B\nFLOAT:a,c\n+(a,-(B,c))\n=>\n+(-(a,c),B)"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("Test", false, false, true, false); + LOG.info(code); + } + + @Test + public void generateExample2() { + String ruleStr = "MATRIX:A\n+(A,A)\n=>\n*2(A)"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("Test", false, false, true, false); + LOG.info(code); + } + + @Test + public void testConditional() { + String ruleStr = "MATRIX:Xm,tmp852\n" + + "FLOAT:tmp65855\n" + + "\n" + + "%*%(t(/(Xm,tmp65855)),tmp852)\n" + + "=>\n" + + "{\n" + + "%*%(t(Xm),/(tmp852,tmp65855))\n" + + "/(%*%(t(Xm),tmp852),tmp65855)\n" + + "t(/(%*%(t(tmp852),Xm),tmp65855))\n" + + "}"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + @Test + public void testLiteral() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "-(+(A, $1:literal.FLOAT()), $2:literal.FLOAT())\n" + + "=>\n" + + "+(A, -($1, $2))"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + @Test + public void testCFold() { + String ruleStr = "LITERAL_FLOAT:1.0,2.0\n" + + "\n" + + "+(1.0,1.0)\n" + + "=>\n" + + "2.0"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + //@Test + public void codeGen() { + List files = List.of("/Users/janniklindemann/Dev/Rewrite-Generator-Reproducibility/data/rules_end_to_end.dml"); + //List files = List.of(RewriteAutomaticallyGenerated.FILE_PATH_MB); + String targetPath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java"; + + try { + // This is to specify that the generated code should print to the console if it modifies the DAG + // This should be disabled when generating production code + RewriterCodeGen.DEBUG = false; + RewriterCodeGen.generateRewritesFromFiles(files, targetPath, true, 3, true, false, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java new file mode 100644 index 00000000000..dde5f991378 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +public class CostEstimates { + protected static final Log LOG = LogFactory.getLog(CostEstimates.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("%*%(+(A,B), C)", ctx, "MATRIX:A,B,C"); + MutableObject assertionRef = new MutableObject<>(); + long cost1 = RewriterCostEstimator.estimateCost(stmt, ctx, assertionRef); + LOG.info(cost1); + long cost2 = RewriterCostEstimator.estimateCost(stmt.getChild(0), ctx, assertionRef); + LOG.info(cost2); + assert cost2 < cost1; + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("*(+(1, 1), 2)", ctx, "LITERAL_INT:1,2"); + LOG.info(canonicalConverter.apply(stmt)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("_EClass(argList(1, ncol(X)))", ctx, "LITERAL_INT:1", "MATRIX:X"); + LOG.info(canonicalConverter.apply(stmt)); + } + + @Test + public void test4() { + RewriterStatement stmt1 = RewriterUtils.parse("t(%*%(+(A,B), C))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(C), t(+(A,B)))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + } + + @Test + public void test5() { + RewriterStatement stmt1 = RewriterUtils.parse("t(/(*(A, B), C))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("/(*(t(A), t(B)), t(C))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test6() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, B))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("+(sum(A), sum(B))", ctx, "MATRIX:A,B,C"); + stmt2.givenThatEqualDimensions(stmt2.getChild(0, 0), stmt2.getChild(1, 0), ctx); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost2)/cost1); + assert cost2 < cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test7() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(A))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("rowSums(colSums(A))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(diag(A), diag(B)))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(*(A, B))", ctx, "MATRIX:A,B,C"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test9() { + String stmt1Str = "MATRIX:WM\n" + + "FLOAT:m2X,c19b086e-34d2-46dd-9651-7b6d1d16e459\n" + + "LITERAL_FLOAT:1.0\n" + + "sqrt(*(m2X,/(sum(WM),-(c19b086e-34d2-46dd-9651-7b6d1d16e459,1.0))))"; + String stmt2Str = "MATRIX:1167aa9b-102a-4bae-9801-8b18d210f954\n" + + "FLOAT:m2,41d7e6fb-d4a7-45cf-89cb-cea8ecf3430a\n" + + "LITERAL_FLOAT:1.0\n" + + "sqrt(/(*(m2,sum(1167aa9b-102a-4bae-9801-8b18d210f954)),-(41d7e6fb-d4a7-45cf-89cb-cea8ecf3430a,1.0)))"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmt1Str, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmt2Str, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 == cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test10() { + String stmt1Str = "INT:num_records\n" + + "LITERAL_INT:3\n" + + "*(num_records,3)"; + String stmt2Str = "LITERAL_INT:3\n" + + "INT:run_index\n" + + "*(3,run_index)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmt1Str, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmt2Str, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 == cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test11() { + String stmtStr1 = "MATRIX:A,p_CG,z\n" + + "FLOAT:trust_delta_sq\n" + + "*(cast.FLOAT(A),cast.FLOAT(%*%(p_CG,z)))"; + String stmtStr2 = "MATRIX:A,p_CG,z\n" + + "FLOAT:trust_delta_sq\n" + + "*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(A))"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test12() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, 1),B)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, ncol(A)), B)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + assert cost1 < cost2; + } + + @Test + public void test13() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "[](rowSums(A), 1, nrow(A), 1, 1)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "rowSums(A)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + assert cost2 < cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test14() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + MutableObject assertionRef = new MutableObject<>(); + long maxCost = RewriterCostEstimator.estimateCost(stmt1, ctx, assertionRef); + Tuple2, Boolean> allowedCombinations = RewriterCostEstimator.determineSingleReferenceRequirement(stmt1, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), 0, maxCost, ctx); + LOG.info(allowedCombinations._1); + LOG.info("AllowCombinations: " + allowedCombinations._2); + assert allowedCombinations._1.size() == 1; + } + + @Test + public void test15() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(rowSums(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A"); + MutableObject assertionRef = new MutableObject<>(); + long maxCost = RewriterCostEstimator.estimateCost(stmt1, ctx, assertionRef); + long fullCost = RewriterCostEstimator.estimateCost(stmt2, ctx, assertionRef); + Tuple2, Boolean> allowedCombinations = RewriterCostEstimator.determineSingleReferenceRequirement(stmt1, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + LOG.info(allowedCombinations._1); + LOG.info("AllowCombinations: " + allowedCombinations._2); + assert allowedCombinations._1.isEmpty(); + } + + @Test + public void test16() { + RewriterStatement stmt1 = RewriterUtils.parse("+(colSums(A),[](B,1,1,1,ncol(B)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(colSums(A),colSums([](B,1,1,1,ncol(B))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + assert cost1 < cost2; + } + + @Test + public void test17() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A),B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(colSums(colVec(A)),B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 < cost2; + } + + @Test + public void test18() { + String ruleStr = + "MATRIX:tmp55220\n" + + "FLOAT:tmp23781\n" + + "\n" + + "/(t(tmp55220),tmp23781)\n" + + "=>\n" + + "t(/(tmp55220,tmp23781))"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List, Long, Long>> cmp = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 0, false); + + LOG.info(cmp); + long cost1 = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long cost2 = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 == cost2; + } + + @Test + public void test19() { + String ruleStr = + "MATRIX:tmp14587,tmp76084\n" + + "FLOAT:one_over_sqrt_two_pi\n" + + "\n" + + "*(tmp14587,/(one_over_sqrt_two_pi,tmp76084))\n" + + "=>\n" + + "/(*(one_over_sqrt_two_pi,tmp14587),tmp76084)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List, Long, Long>> cmp = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 0, false); + + LOG.info(cmp); + long cost1 = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long cost2 = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 == cost2; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java new file mode 100644 index 00000000000..46a6069a7c8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.dml.DMLCodeGenerator; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.UUID; +import java.util.function.Function; + +public class DMLCodeGenTest { + protected static final Log LOG = LogFactory.getLog(DMLCodeGenTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("trace(+(A, t(B)))", ctx, "MATRIX:A,B"); + LOG.info(DMLCodeGenerator.generateDML(stmt)); + } + + @Test + public void test2() { + String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA"; + String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + //RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2)); + String sessionId = UUID.randomUUID().toString(); + String validationScript = DMLCodeGenerator.generateRuleValidationDML(rule2, DMLCodeGenerator.EPS, sessionId, ctx); + LOG.info("Validation script:"); + LOG.info(validationScript); + MutableBoolean valid = new MutableBoolean(true); + DMLExecutor.executeCode(validationScript, line -> { + if (!line.startsWith(sessionId)) + return; + + if (!line.endsWith("valid: TRUE")) { + DMLExecutor.println("An invalid rule was found!"); + DMLExecutor.println(line); + valid.setValue(false); + } + }); + + LOG.info("Exiting..."); + assert valid.booleanValue(); + } + + @Test + public void test3() { + String ruleStr2 = "MATRIX:A,B\nt(*(A,t(B)))\n=>\n*(t(A),B)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test4() { + // Should already be implemented + String ruleStr2 = "MATRIX:A,B\nt(+(A,t(B)))\n=>\n+(t(A),B)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test5() { + String ruleStr2 = "MATRIX:A\nLITERAL_FLOAT:1,2\n-(+(1,A), 1)\n=>\n*(1,A)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + @Test + public void test6() { + String ruleStr2 = "MATRIX:?,B\nLITERAL_INT:1,2\n+(?,B)\n=>\n*(1,+(?,B))"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test7() { + String ruleStr2 = "MATRIX:?,B\nLITERAL_INT:1,2\n+(?,B)\n=>\n*(1,+(?,B))"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test8() { + String ruleStr = "MATRIX:8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,is_LT_infinite,flip_pos\n" + + "\n" + + "+(%*%(is_LT_infinite,flip_pos),%*%(8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,flip_pos))\n" + + "=>\n" + + "%*%(+(8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,is_LT_infinite),flip_pos)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule, ctx); + } + + @Test + public void testRev() { + String ruleStr = "MATRIX:A\n" + + "FLOAT:b\n" + + "\n" + + "rev(*(rev(A),b))\n" + + "=>\n" + + "*(A,b)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused1() { + String ruleStr = "MATRIX:A\nLITERAL_FLOAT:0.0\n" + + "sum(!=(0.0,A))\n" + + "=>\n" + + "_nnz(A)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused2() { + String ruleStr = "MATRIX:A,B\nLITERAL_FLOAT:0.0,1.0\n" + + "-(0.0, -(*(A,B), 1.0))\n" + + "=>\n" + + "1-*(A,B)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused3() { + String ruleStr = "MATRIX:A,B\nLITERAL_FLOAT:0.0,1.0\n" + + "+(-(A,B),A)\n" + + "=>\n" + + "-(*2(A), B)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused4() { + String ruleStr = "MATRIX:A,B,C\nLITERAL_FLOAT:0.0,1.0\n" + + "1-*(A, const(A, 0.0))\n" + + "=>\n" + + "const(A, 1.0)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(canonicalConverter.apply(rule.getStmt1()).toParsableString(ctx)); + LOG.info(canonicalConverter.apply(rule.getStmt2()).toParsableString(ctx)); + + //assert rule.getStmt1().match(RewriterStatement.MatcherContext.exactMatch(ctx, rule.getStmt2(), rule.getStmt1())); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + // As we have disabled operator fusion + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null); + } + + @Test + public void testFused5() { + String ruleStr = "MATRIX:A\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "sum(!=(0.0,A))\n" + + "=>\n" + + "_nnz(A)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(canonicalConverter.apply(rule.getStmt1()).toParsableString(ctx)); + LOG.info(canonicalConverter.apply(rule.getStmt2()).toParsableString(ctx)); + + //assert rule.getStmt1().match(RewriterStatement.MatcherContext.exactMatch(ctx, rule.getStmt2(), rule.getStmt1())); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java new file mode 100644 index 00000000000..778ef8ac7d1 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class MinimalDifference { + protected static final Log LOG = LogFactory.getLog(MinimalDifference.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("t(A)", ctx, "MATRIX:A"); + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmt2, stmt1); + stmt1.match(mCtx); + LOG.info("Minimal Difference: "); + LOG.info(mCtx.getFirstMismatch()._1.toParsableString(ctx)); + LOG.info(mCtx.getFirstMismatch()._2.toParsableString(ctx)); + } + + @Test + public void test2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, t(+(A, A)))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("-(A, t(*(2, A)))", ctx, "MATRIX:A", "LITERAL_INT:2"); + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmt2, stmt1); + stmt1.match(mCtx); + LOG.info("Minimal Difference: "); + LOG.info(mCtx.getFirstMismatch()._1.toParsableString(ctx)); + LOG.info(mCtx.getFirstMismatch()._2.toParsableString(ctx)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java new file mode 100644 index 00000000000..4014fc85b5c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +public class RewriterSearchUtilsTest { + protected static final Log LOG = LogFactory.getLog(RewriterSearchUtilsTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testDecode1() { + int l = 27; + int n = 5; + int[] digits = RewriterSearchUtils.fromBaseNNumber(l, n); + assert digits.length == 3 && digits[0] == 1 && digits[1] == 0 && digits[2] == 2; + } + + @Test + public void testDecode2() { + int l = 5; + int n = 5; + int[] digits = RewriterSearchUtils.fromBaseNNumber(l, n); + LOG.info(Arrays.toString(digits)); + assert digits.length == 2 && digits[0] == 1 && digits[1] == 0; + } + + @Test + public void testEncode1() { + int[] digits = new int[] { 1, 0, 2 }; + int[] digits2 = new int[] {4, 4, 4}; + int n = 5; + int l = RewriterSearchUtils.toBaseNNumber(digits, n); + int l2 = RewriterSearchUtils.toBaseNNumber(digits2, n); + LOG.info(l); + LOG.info(Integer.toBinaryString(l)); + LOG.info(l2); + LOG.info(Integer.toBinaryString(l2)); + assert l == 27; + } + + @Test + public void testRandomStatementGeneration() { + LOG.info(RewriterSearchUtils.getMaxSearchNumberForNumOps(3)); + int ctr = 0; + for (int i = 0; i < 20; i++) { + List ops = RewriterSearchUtils.decodeOrderedStatements(i); + //LOG.info("Idx: " + i); + //LOG.info(ops); + //LOG.info(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size()); + for (RewriterStatement stmt : RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true)) { + LOG.info("Base: " + stmt.toParsableString(ctx)); + for (RewriterStatement sstmt : RewriterSearchUtils.buildAssertionVariations(stmt, ctx)) { + canonicalConverter.apply(sstmt); + LOG.info(sstmt.toParsableString(ctx)); + //LOG.info("Raw: " + sstmt); + ctr++; + } + } + } + + LOG.info("Total DAGs: " + ctr); + } + + @Test + public void testRandomStatementGeneration2() { + int ctr = 0; + //for (int i = 0; i < 20; i++) { + List ops = List.of(RewriterSearchUtils.instructionAlphabet[3], RewriterSearchUtils.instructionAlphabet[16], RewriterSearchUtils.instructionAlphabet[6]); + //LOG.info("Idx: " + i); + //LOG.info(ops); + //LOG.info(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size()); + for (RewriterStatement stmt : RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true)) { + LOG.info("Base: " + stmt.toParsableString(ctx)); + for (RewriterStatement sstmt : RewriterSearchUtils.buildVariations(stmt, ctx)) { + canonicalConverter.apply(sstmt); + LOG.info(sstmt.toParsableString(ctx)); + //LOG.info("Raw: " + sstmt); + ctr++; + } + } + //} + + LOG.info("Total DAGs: " + ctr); + } + + @Test + public void test() { + RewriterStatement stmt = RewriterUtils.parse("+([](A, 1, 1, 1, 1), B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java new file mode 100644 index 00000000000..eabe5138258 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.function.Function; + +public class RuleCreationTests { + protected static final Log LOG = LogFactory.getLog(RuleCreationTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement from = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("%*%(t(U), V)", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + } + + @Test + public void test2() { + RewriterStatement from = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement to = RewriterUtils.parse("A", ctx, "MATRIX:A"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement testStmt = RewriterUtils.parse("t(t([](A, 1, ncol(A), 1, 1)))", ctx, "MATRIX:A", "LITERAL_INT:1"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + } + + @Test + public void validationTest1() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:b") + .withParsedStatement("sum(/(A, b))") + .toParsedStatement("/(sum(A), b)") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule, ctx); + } + + @Test + public void validationTest2() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("FLOAT:b") + .withParsedStatement("rowSums(colSums(%*%(A, B)))") + .toParsedStatement("%*%(colSums(A), rowSums(B))") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void validationTest3() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("cast.MATRIX(sum(rowVec(A)))") + .toParsedStatement("rowSums(rowVec(A))") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void test3() { + RewriterStatement from = RewriterUtils.parse("%*%(A,%*%(B,rowVec(C)))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("%*%(%*%(A,B),rowVec(C))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + } + + @Test + public void test4() { + RewriterStatement from = RewriterUtils.parse("*(a,0.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:0.0"); + RewriterStatement to = RewriterUtils.parse("0.0", ctx, "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterStatement from2 = RewriterUtils.parse("/(0.0,a)", ctx, "FLOAT:a", "LITERAL_FLOAT:0.0"); + RewriterStatement to2 = RewriterUtils.parse("0.0", ctx, "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm12 = canonicalConverter.apply(from2); + RewriterStatement canonicalForm22 = canonicalConverter.apply(to2); + + LOG.info("=========="); + LOG.info(canonicalForm12.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm22.toParsableString(ctx, true)); + + assert canonicalForm12.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm22, canonicalForm12)); + + RewriterRule rule2 = RewriterRuleCreator.createRule(from2, to2, canonicalForm12, canonicalForm22, ctx); + LOG.info(rule2); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule, rule2)); + + RewriterStatement testStmt = RewriterUtils.parse("/(*(a,0.0), b)", ctx, "FLOAT:a,b", "LITERAL_FLOAT:0.0"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + + testStmt = ar.rule.apply(ar.matches.get(0), testStmt, true, false); + + LOG.info("HERE"); + LOG.info(testStmt.toParsableString(ctx)); + + ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + + testStmt = ar.rule.apply(ar.matches.get(0), testStmt, true, false); + + LOG.info(testStmt); + } + + @Test + public void test5() { + RewriterRule rule1 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0.0\n*(a, 0.0)\n=>\n0.0", ctx); + RewriterRule rule2 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0.0\n/(0.0, a)\n=>\n0.0", ctx); + RewriterRule rule3 = RewriterUtils.parseRule("FLOAT:a,b\nLITERAL_FLOAT:0.0\n/(*(a, 0.0), b)\n=>\n0.0", ctx); + RewriterRuleCreator rc = new RewriterRuleCreator(ctx); + rc.registerRule(rule3, rule3.getStmt1().getCost(ctx), rule3.getStmt2().getCost(ctx), false, canonicalConverter); + rc.registerRule(rule2, rule2.getStmt1().getCost(ctx), rule2.getStmt2().getCost(ctx), false, canonicalConverter); + rc.registerRule(rule1, rule1.getStmt1().getCost(ctx), rule1.getStmt2().getCost(ctx), false, canonicalConverter); + + LOG.info(rc.getRuleSet().serialize()); + } + + @Test + public void test6() { + RewriterStatement from = RewriterUtils.parse("%*%(const(colVec(A),0.0),log_nz(B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement to = RewriterUtils.parse("%*%(colVec(A),const(B,0.0))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + /*LOG.info(canonicalForm1.getChild(1, 1, 0)); + LOG.info(canonicalForm1.getChild(1, 1, 0).getNCol()); + LOG.info(canonicalForm1.getChild(1, 1, 0).getNRow()); + LOG.info(canonicalForm2.getChild(1, 1, 0)); + LOG.info(canonicalForm2.getChild(1, 1, 0).getNCol()); + LOG.info(canonicalForm2.getChild(1, 1, 0).getNRow());*/ + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1); + if (!canonicalForm1.match(mCtx)) { + LOG.info(mCtx.getFirstMismatch()._1); + LOG.info(mCtx.getFirstMismatch()._2); + assert false; + } + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + } + + @Test + public void testTypeInvariantRuleRegistration() { + RewriterRule rule1 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0\n*(a,0)\n=>\na", ctx); + RewriterRule rule2 = RewriterUtils.parseRule("INT:a\nLITERAL_INT:0\n*(a,0)\n=>\na", ctx); + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + ruleCreator.registerRule(rule1, canonicalConverter, ctx); + + assert !ruleCreator.registerRule(rule2, canonicalConverter, ctx); + } + + @Test + public void testRuleElimination() { + String rs1 = + "MATRIX:tmp34827,tmp40318\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+(%*%(tmp34827,tmp40318),0.0)\n" + + "=>\n" + + "%*%(tmp34827,tmp40318)"; + String rs2 = + "MATRIX:tmp34827,tmp40318\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+(tmp34827,0.0)\n" + + "=>\n" + + "tmp34827"; + + RewriterRule rule1 = RewriterUtils.parseRule(rs1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(rs2, ctx); + + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + ruleCreator.registerRule(rule1, canonicalConverter, ctx); + + assert ruleCreator.registerRule(rule2, canonicalConverter, ctx); + LOG.info(ruleCreator.getRuleSet().getRules()); + assert ruleCreator.getRuleSet().getRules().size() == 1; + } + + @Test + public void testExpansiveRule() { + String rs1 = + "MATRIX:A,B\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+*(A,0.0,B)\n" + + "=>\n" + + "+*(A,0.0,!=(B,B))"; + + RewriterRule rule1 = RewriterUtils.parseRule(rs1, ctx); + + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + assert !ruleCreator.registerRule(rule1, canonicalConverter, ctx); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java new file mode 100644 index 00000000000..d6ae07120f2 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +public class RuleSerializationTest { + protected static final Log LOG = LogFactory.getLog(RuleSerializationTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA"; + String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2)); + String serialized = ruleSet.serialize(); + + LOG.info(serialized); + + RewriterRuleSet newRuleSet = RewriterRuleSet.deserialize(serialized, ctx); + String newSerialized = newRuleSet.serialize(); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test2() { + RewriterStatement from = RewriterUtils.parse("t(t(U))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("U", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + from = rule.getStmt1(); + to = rule.getStmt2(); + + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(to, ctx); + long maxCost = RewriterCostEstimator.estimateCost(from, ctx, assertionRef); + Tuple2, Boolean> result = RewriterCostEstimator.determineSingleReferenceRequirement(from, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + + assert result._1.size() == 1 && result._2; + + rule.setAllowedMultiReferences(result._1, result._2); + + String serialized = rule.toParsableString(ctx); + + LOG.info("::RULE"); + LOG.info(serialized); + LOG.info(""); + + RewriterRule newRule = RewriterUtils.parseRule(serialized, ctx); + String newSerialized = newRule.toParsableString(ctx); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test3() { + RewriterStatement from = RewriterUtils.parse("sum(t(U))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("sum(U)", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + from = rule.getStmt1(); + to = rule.getStmt2(); + + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(to, ctx); + long maxCost = RewriterCostEstimator.estimateCost(from, ctx, assertionRef); + Tuple2, Boolean> result = RewriterCostEstimator.determineSingleReferenceRequirement(from, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + + assert result._1.size() == 1 && result._2; + + rule.setAllowedMultiReferences(result._1, result._2); + + String serialized = rule.toParsableString(ctx); + + LOG.info("::RULE"); + LOG.info(serialized); + LOG.info(""); + + RewriterRule newRule = RewriterUtils.parseRule(serialized, ctx); + String newSerialized = newRule.toParsableString(ctx); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test4() { + String ruleStr1 = "MATRIX:W1_rand,tmp29911\n" + + "FLOAT:tmp65095\n" + + "\n" + + "*(tmp65095,%*%(W1_rand,t(tmp29911)))\n" + + "=>\n" + + "{\n" + + "t(%*%(*(tmp65095,tmp29911),t(W1_rand)))\n" + + "%*%(*(tmp65095,W1_rand),t(tmp29911))\n" + + "*(tmp65095,t(%*%(tmp29911,t(W1_rand))))\n" + + "}"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + LOG.info(rule1.toString()); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java new file mode 100644 index 00000000000..63af60ea230 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.estimators.RewriterSparsityEstimator; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple3; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +public class SparsityEstimationTest { + protected static final Log LOG = LogFactory.getLog(SparsityEstimationTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("+*(A, 0.0, B)", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + LOG.info(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx)); + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+*(A, a, B)", ctx, "MATRIX:A,B", "FLOAT:a"); + LOG.info(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("%*%(A, -(B, A))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterAssertionUtils.buildImplicitAssertions(stmt, stmt.getAssertions(ctx), ctx); + + Map estimates = RewriterSparsityEstimator.estimateAllNNZ(stmt, ctx); + + estimates.forEach((k, v) -> { + stmt.getAssertions(ctx).update(v); + LOG.info("K: " + k.toParsableString(ctx)); + LOG.info("NNZ: " + v.toParsableString(ctx)); + }); + + LOG.info("Rollup: " + RewriterSparsityEstimator.rollupSparsities(estimates.get(stmt), estimates, ctx).toParsableString(ctx)); + + Map nnzs = new HashMap<>(); + nnzs.put(stmt.getChild(0), 3000L); + nnzs.put(stmt.getChild(1, 0), 50000L); + + MutableObject assertionRef = new MutableObject<>(); + RewriterStatement costFunction = RewriterCostEstimator.getRawCostFunction(stmt, ctx, assertionRef, false); + costFunction = RewriterSparsityEstimator.rollupSparsities(costFunction, estimates, ctx); + + LOG.info(costFunction.toParsableString(ctx)); + + LOG.info("Dense cost: " + RewriterCostEstimator.estimateCost(stmt, ctx)); + LOG.info("Sparse cost: " + RewriterCostEstimator.computeCostFunction(costFunction, RewriterCostEstimator.DEFAULT_COST_FN, (el, tpl) -> nnzs.get(el.getChild(0)), assertionRef.getValue(), ctx)); + } + + @Test + public void test4() { + RewriterStatement from = RewriterUtils.parse("+(*(A, B), *(A, C))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("*(A, +(B, C))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + + RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, true, 5, false); + } + + @Test + public void test5() { + RewriterStatement from = RewriterUtils.parse("t(%*%(t(A), B))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("%*%(t(B), A)", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + //rule.getStmt2().unsafePutMeta("_assertions", rule.getStmt1().getAssertions(ctx)); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 5, false); + LOG.info(costs); + LOG.info("Does sparsity have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, true, 0)); + } + + @Test + public void test6() { + RewriterStatement from = RewriterUtils.parse("t(+(A, B))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("+(t(A), t(B))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 5, false); + LOG.info(costs); + LOG.info("Does sparsity have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, true, 0)); + LOG.info("Does anything have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 0)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java new file mode 100644 index 00000000000..de672c09ae4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.function.Function; + +public class SubtreeGeneratorTest { + protected static final Log LOG = LogFactory.getLog(SubtreeGeneratorTest.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("+(1, a)", ctx, "LITERAL_INT:1", "FLOAT:a"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 2; + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+(+(1, b), a)", ctx, "LITERAL_INT:1", "FLOAT:a,b"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 3; + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("-(+(1.0,A),B)", ctx, "LITERAL_FLOAT:1.0", "MATRIX:A,B"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 3; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java new file mode 100644 index 00000000000..0826a81cf51 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; + +public class TestRuleSet { + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("sum(%*%(A, t(B)))") + .toParsedStatement("sum(*(A, B))") + .build(); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement stmt = RewriterUtils.parse("sum(%*%(colVec(A), t(colVec(B))))", ctx, "MATRIX:A,B"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt); + + assert ar != null; + + stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false); + } + + @Test + public void test2() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("as.matrix(sum(colVec(A)))") + .toParsedStatement("rowSums(rowVec(A))") + .build(); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement stmt = RewriterUtils.parse("as.matrix(sum(t(rowVec(A))))", ctx, "MATRIX:A,B"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt); + + assert ar != null; + + stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false); + } +} diff --git a/src/test/resources/rewriterframework/expressions.db b/src/test/resources/rewriterframework/expressions.db new file mode 100644 index 00000000000..8b5397f8d4a --- /dev/null +++ b/src/test/resources/rewriterframework/expressions.db @@ -0,0 +1,18610 @@ + +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:1.0 +*(/(1.0,nrow(target)),-(prediction,target)) +::STMT +MATRIX:parsertemp75086 +LITERAL_FLOAT:32.0 +*(parsertemp75086,32.0) +::STMT +LITERAL_FLOAT:1.0 +cast.MATRIX(1.0) +::STMT +MATRIX:y_corr,parsertemp171089,parsertemp171084,parsertemp171095 +FLOAT:float98,float133,float340 +LITERAL_FLOAT:-1.0,1.0,2.0 +*(+(*(sqrt(parsertemp171084),-1.0),/(+(float340,parsertemp171089),+(float98,parsertemp171095))),-(1.0,*(2.0,>(y_corr,float133)))) +::STMT +MATRIX:parsertemp109934 +LITERAL_FLOAT:42.0 +*(parsertemp109934,42.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int210,parsertemp31048,parsertemp31047,int867,int429,parsertemp31053,parsertemp31052,int196 +LITERAL_FLOAT:2.0 +/(^(+(/(posSampleVariances,int429),/(negSampleVariances,int210)),2.0),+(/(^(posSampleVariances,int196),*(parsertemp31047,parsertemp31048)),/(^(negSampleVariances,int867),*(parsertemp31052,parsertemp31053)))) +::STMT +MATRIX:X +FLOAT:int40 +LITERAL_FLOAT:1764.0 +sqrt(/(colSums(^(X,int40)),1764.0)) +::STMT +MATRIX:id +diag(diag(==(id,t(id)))) +::STMT +MATRIX:scale_X,z,beta +*(cast.FLOAT(diag(scale_X)),+(cast.FLOAT(beta),cast.FLOAT(z))) +::STMT +MATRIX:X +FLOAT:int459 +LITERAL_FLOAT:1.0,1.0E-6 +/(*(1.0E-6,sum(^(X,int459))),1.0) +::STMT +MATRIX:parsertemp18128,X,parsertemp18133 +FLOAT:int389 +LITERAL_FLOAT:0.0 +rowSums(*(>(%*%(X,parsertemp18128),0.0),t(^(int389,parsertemp18133)))) +::STMT +MATRIX:hubs +FLOAT:parsertemp30953 +LITERAL_FLOAT:2.0 +sum(^(-(/(hubs,parsertemp30953),hubs),2.0)) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0 ++(*(index,2.0),2.0) +::STMT +MATRIX:R,dssp,dsep +FLOAT:4_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),+(R,dssp)),4_eAvg),1.0) +::STMT +MATRIX:r_LS,parsertemp170556,p_LS,parsertemp170552 +FLOAT:norm_r2_LS,lambda_LS ++(r_LS,*(/(norm_r2_LS,sum(parsertemp170556)),+(%*%(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +MATRIX:X,RMSE +/(RMSE,-(max(X),min(X))) +::STMT +MATRIX:parsertemp472412,fP +FLOAT:max_values,parsertemp472284 +t(<=(parsertemp472412,/(^(parsertemp472284,max_values),ncol(fP)))) +::STMT +MATRIX:ts +FLOAT:q +cast.FLOAT(+(-(q,%*%(ts,ts)),%*%(ts,ts))) +::STMT +MATRIX:Y +FLOAT:x,X +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),cast.FLOAT(Y)) +::STMT +MATRIX:X +LITERAL_FLOAT:200.0,2.0 +^(/(t(colSums(X)),200.0),2.0) +::STMT +MATRIX:R +FLOAT:int37,int162 +INT:int981,parsertemp503363 +t(+(R,diag(rand(parsertemp503363,int981,int162,int37)))) +::STMT +MATRIX:y +FLOAT:beta,n +LITERAL_FLOAT:2.0 +/(sum(^(-(beta,y),2.0)),n) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(g,2.0),1.0) +::STMT +MATRIX:sv,s,w,X,Y,out +FLOAT:step_sz +-(%*%(t(X),*(*(sv,out),Y)),+(w,*(step_sz,s))) +::STMT +MATRIX:parsertemp10744,parsertemp10743,W,H,parsertemp10739 +%*%(W,%*%(*(H,/(parsertemp10739,parsertemp10743)),t(*(H,parsertemp10744)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,100.0 +*(*(-(i,1.0),100.0),100.0) +::STMT +MATRIX:Y_counts,Y +/(colSums(Y),sum(Y_counts)) +::STMT +MATRIX:minD,D +/(<=(D,minD),rowSums(<=(D,minD))) +::STMT +MATRIX:parsertemp472317,parsertemp472315,ig +t(rev(*(&(parsertemp472315,parsertemp472317),ig))) +::STMT +FLOAT:factor_up,parsertemp195892 +LITERAL_FLOAT:2.0 +-(*(2.0,factor_up),parsertemp195892) +::STMT +MATRIX:dY,W,Y,sumW +LITERAL_FLOAT:300.0,0.9 +-(*(0.9,dY),*(300.0,-(*(Y,sumW),%*%(W,Y)))) +::STMT +FLOAT:o_init,o +LITERAL_FLOAT:-1.0,2.0 +*(-(*(2.0,o_init),*(2.0,o)),-1.0) +::STMT +MATRIX:parsertemp265709,parsertemp265718 +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(%*%(colSums(parsertemp265718),rowSums(parsertemp265709)))) +::STMT +MATRIX:parsertemp555766,parsertemp555762,target +FLOAT:int381,int17 +sum(-(*(*(target,int17),parsertemp555762),*(-(int381,target),parsertemp555766))) +::STMT +MATRIX:ssX_V,X,P_1K +rowSums(*(P_1K,%*%(X,ssX_V))) +::STMT +LITERAL_FLOAT:8000.0 +8000.0 +::STMT +MATRIX:p,q,lambda,parsertemp116061,parsertemp116062,scale_X,shift_X ++(+(*(scale_X,%*%(parsertemp116061,parsertemp116062)),*(cast.FLOAT(q),shift_X)),*(lambda,p)) +::STMT +MATRIX:ss_avg_res_Y,ss_avg_tot_Y +LITERAL_FLOAT:1.0 +-(1.0,/(ss_avg_res_Y,ss_avg_tot_Y)) +::STMT +MATRIX:Xd,Xu +LITERAL_FLOAT:1.0 +/(1.0,-(Xu,Xd)) +::STMT +MATRIX:Y_counts,parsertemp560521,ent2_vec +sqrt(sum(*(Y_counts,-(ent2_vec,parsertemp560521)))) +::STMT +MATRIX:X,H,parsertemp16755 +LITERAL_FLOAT:0.0,2.0 +*(>(%*%(X,t(H)),0.0),^(2.0,cast.FLOAT(parsertemp16755))) +::STMT +MATRIX:cdf_min_distances +FLOAT:float467,float609 +INT:int767,num_runs +colSums(<(cdf_min_distances,*(rand(int767,num_runs,float609,float467),cdf_min_distances))) +::STMT +MATRIX:WM,Y +/(sum(*(Y,WM)),sum(WM)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.0 +*(scale_lambda,0.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +*(linear_terms,2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,80656.0 ++(*(-(i,1.0),80656.0),1.0) +::STMT +MATRIX:P +LITERAL_FLOAT:4.0 +*(P,4.0) +::STMT +MATRIX:fdom,X,parsertemp1688 ++(X,-(t(parsertemp1688),fdom)) +::STMT +MATRIX:sample_maps,X +LITERAL_FLOAT:2.0 +rowSums(^(%*%(sample_maps,X),2.0)) +::STMT +MATRIX:p,lambda,X +*(p,+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:Ileft,_funvar2706,_funvar2707 +FLOAT:numI +-(cast.FLOAT(_funvar2706),*(/(rowSums(Ileft),numI),_funvar2707)) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int871 +LITERAL_FLOAT:149.0,150.0 +/(/(-(colSums(parsertemp31029),*(int871,parsertemp31031)),149.0),150.0) +::STMT +MATRIX:parsertemp130418 +LITERAL_FLOAT:1.0,4.0 ++(*(max(parsertemp130418),4.0),1.0) +::STMT +MATRIX:X +FLOAT:s +LITERAL_FLOAT:0.0 +-(+(nrow(X),0.0),s) +::STMT +MATRIX:parsertemp283570,tpr,fpr,parsertemp283568 +LITERAL_FLOAT:2.0 ++(cast.FLOAT(*(tpr,fpr)),sum(/(*(parsertemp283568,parsertemp283570),2.0))) +::STMT +MATRIX:xs +FLOAT:256_x +LITERAL_FLOAT:1000.0 +-(1000.0,sum(>=(xs,256_x))) +::STMT +MATRIX:parsertemp72182 +LITERAL_FLOAT:8.0 +*(parsertemp72182,8.0) +::STMT +FLOAT:num_centroids +LITERAL_FLOAT:3.0 +*(3.0,num_centroids) +::STMT +MATRIX:scale_X,X,parsertemp274503,parsertemp274506,P_1K +%*%(diag(scale_X),%*%(t(X),-(*(P_1K,parsertemp274503),*(P_1K,parsertemp274506)))) +::STMT +MATRIX:X +FLOAT:n +LITERAL_FLOAT:-1.0 +*(/(t(colSums(X)),n),-1.0) +::STMT +MATRIX:parsertemp42202,F +FLOAT:parsertemp42203,W,int416,meanX +t(*(/(F,-(W,int416)),-(+(parsertemp42202,parsertemp42203),meanX))) +::STMT +LITERAL_FLOAT:6.0,2001.0 +*(6.0,2001.0) +::STMT +MATRIX:parsertemp410987,parsertemp410989,parsertemp410978,W,H,parsertemp410980 +sum(%*%(/(*(W,parsertemp410987),t(parsertemp410989)),/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,int701 +LITERAL_FLOAT:1.0 +^(linear_terms,-(/(-(int701,var_power),link_power),1.0)) +::STMT +MATRIX:parsertemp149339,parsertemp149335 +FLOAT:int257,obj,parsertemp149332 +LITERAL_FLOAT:0.5 +-(obj,+(+(*(parsertemp149332,int257),sum(parsertemp149335)),*(0.5,sum(parsertemp149339)))) +::STMT +MATRIX:parsertemp107030 +LITERAL_FLOAT:7.0 +*(parsertemp107030,7.0) +::STMT +MATRIX:y_batch,parsertemp459782,parsertemp459784 +FLOAT:loss ++(loss,/(sum(*(parsertemp459782,parsertemp459784)),nrow(y_batch))) +::STMT +MATRIX:parsertemp73634 +LITERAL_FLOAT:16.0 +*(parsertemp73634,16.0) +::STMT +MATRIX:P,Y +LITERAL_FLOAT:1.0 +/(P,+(-(ncol(Y),1.0),1.0)) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:1.0,40.0 +-(/(/(se,ss),/(sum(e),40.0)),1.0) +::STMT +FLOAT:parsertemp254715,parsertemp254694,2123_sq_root_d,pp_CG,float162 ++(float162,*(parsertemp254715,/(-(parsertemp254694,2123_sq_root_d),pp_CG))) +::STMT +MATRIX:_sbcvar78,parsertemp22266 +FLOAT:int513 +LITERAL_FLOAT:2.0,10000.0 +/(^(-(_sbcvar78,/(parsertemp22266,int513)),2.0),/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:linear_terms +FLOAT:int750,var_power,link_power +LITERAL_FLOAT:2.0 +^(linear_terms,-(/(-(int750,var_power),link_power),2.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(^(linear_terms,2.0),-(1.0,var_power)) +::STMT +MATRIX:tmp +FLOAT:norm_r2_LS +/(cast.FLOAT(%*%(t(tmp),tmp)),norm_r2_LS) +::STMT +MATRIX:parsertemp556355 +LITERAL_FLOAT:0.125 +*(parsertemp556355,0.125) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1920.0 +/(1920.0,num_records) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +rowSums(*(^(mu,2.0),^(prec_chol,2.0))) +::STMT +LITERAL_FLOAT:100.0 +100.0 +::STMT +LITERAL_FLOAT:105.0 +105.0 +::STMT +LITERAL_FLOAT:81.0 +81.0 +::STMT +LITERAL_FLOAT:80.0 +80.0 +::STMT +LITERAL_FLOAT:127.0 +127.0 +::STMT +LITERAL_FLOAT:120.0 +120.0 +::STMT +MATRIX:parsertemp409212,ctab +LITERAL_FLOAT:0.45 +>(/(parsertemp409212,rowSums(ctab)),0.45) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(colSums(!=(X,0.0))) +::STMT +MATRIX:negSampleMeans,negSamples +LITERAL_FLOAT:2.0,1500.0 +-(colSums(^(negSamples,2.0)),*(1500.0,^(negSampleMeans,2.0))) +::STMT +MATRIX:totalE,parsertemp31933,X2,parsertemp31935 +t(%*%(t(totalE),==(%*%(X2,parsertemp31935),t(parsertemp31933)))) +::STMT +LITERAL_FLOAT:16.0 +16.0 +::STMT +MATRIX:p,V +%*%(t(V),%*%(V,p)) +::STMT +FLOAT:mu +LITERAL_FLOAT:0.999 +-(0.999,mu) +::STMT +LITERAL_FLOAT:15.0 +15.0 +::STMT +FLOAT:int302,int418 +LITERAL_FLOAT:1.0 ++(+(+(+(int302,int418),1.0),1.0),1.0) +::STMT +MATRIX:subspace_idx,parsertemp73653 +LITERAL_FLOAT:16.0,1.0 +<(-(subspace_idx,*(parsertemp73653,16.0)),1.0) +::STMT +MATRIX:samples_vs_runs_map,centroids,X_samples +LITERAL_FLOAT:2.0 +*(2.0,rowSums(*(X_samples,%*%(samples_vs_runs_map,centroids)))) +::STMT +LITERAL_FLOAT:33.0 +33.0 +::STMT +LITERAL_FLOAT:32.0 +32.0 +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +*(parsertemp43626,-1.0) +::STMT +MATRIX:rowSums_X_sq +FLOAT:D +LITERAL_FLOAT:0.5 +/(*(0.5,sqrt(D)),max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:scale_X,shift_X +FLOAT:r +LITERAL_FLOAT:2.0 +sum(^(+(*(scale_X,r),*(r,shift_X)),2.0)) +::STMT +LITERAL_FLOAT:31.0 +31.0 +::STMT +LITERAL_FLOAT:30.0 +30.0 +::STMT +LITERAL_FLOAT:50.0 +50.0 +::STMT +MATRIX:parsertemp500607,parsertemp500610 +FLOAT:tau +*(tau,sum(abs(*(parsertemp500607,parsertemp500610)))) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:0.0 +exp(*(-(0.0,y),+(o,os))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +/(^(linear_terms,2.0),2.0) +::STMT +MATRIX:p_LS +FLOAT:norm_r2_LS,parsertemp170552,lambda_LS +*(/(norm_r2_LS,*(cast.FLOAT(p_LS),+(parsertemp170552,lambda_LS))),+(*(cast.FLOAT(parsertemp170552),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS)))) +::STMT +MATRIX:b4,2362_2360_Y,W4 +t(+(%*%(W4,t(2362_2360_Y)),b4)) +::STMT +MATRIX:g_new,s,g_old +FLOAT:int686,int503 +*(/(sum(^(g_new,int503)),sum(^(g_old,int686))),s) +::STMT +LITERAL_FLOAT:42.0 +42.0 +::STMT +MATRIX:means,variances +FLOAT:beta +t(-(means,*(beta,variances))) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:float270,parsertemp31268,int751,W +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int751),CVars)),*(-(sum(WM),1.0),/(*(parsertemp31268,W),-(W,float270)))) +::STMT +LITERAL_FLOAT:45.0 +45.0 +::STMT +MATRIX:parsertemp439367,mean,parsertemp439305,weight,parsertemp439306,avgMean +FLOAT:int994 +LITERAL_FLOAT:1.0E-6 ++(+(-(/(parsertemp439367,parsertemp439306),*(int994,avgMean)),/(*(mean,parsertemp439305),t(weight))),1.0E-6) +::STMT +MATRIX:U,X,parsertemp382851 +FLOAT:int910 +t(%*%(t(U),*(!=(X,int910),-(parsertemp382851,X)))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(rowSums(^(X,2.0)),t(^(prec_chol,2.0))) +::STMT +MATRIX:s,w +FLOAT:lambda,step_sz +*(lambda,+(w,*(step_sz,s))) +::STMT +LITERAL_FLOAT:1000.0 +1000.0 +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:2.0 +^(-(%*%(U,t(V)),X),2.0) +::STMT +MATRIX:S,parsertemp42207 +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42207,/(t(S),2.0)),/(1.0,2.0)) +::STMT +MATRIX:parsertemp10744,V,W,H,parsertemp10748 +FLOAT:Eps +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10748)),Eps)) +::STMT +MATRIX:ss +LITERAL_FLOAT:0.050000000000000044,1.0,40.0 +*(0.050000000000000044,-(/(40.0,ss),1.0)) +::STMT +MATRIX:W,H,X,parsertemp410997 +-(sum(%*%(W,H)),sum(*(X,parsertemp410997))) +::STMT +MATRIX:mean,parsertemp437225,X,parsertemp437631,weight,parsertemp437222 ++(/(-(%*%(parsertemp437222,X),%*%(parsertemp437225,mean)),sum(weight)),diag(parsertemp437631)) +::STMT +MATRIX:Q3,X,IQR +LITERAL_FLOAT:1.5 +>(X,+(Q3,*(1.5,IQR))) +::STMT +MATRIX:Q1,X,IQR +LITERAL_FLOAT:1.5 +<(X,-(Q1,*(1.5,IQR))) +::STMT +LITERAL_FLOAT:0.0 +INT:int502,int777 +t(rand(int502,int777,0.0,0.0)) +::STMT +LITERAL_FLOAT:0.5000000001 +0.5000000001 +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +colSums(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum)))) +::STMT +LITERAL_FLOAT:3136.0 +3136.0 +::STMT +MATRIX:d,parsertemp410052,d_r_rev +*(d,t(colSums(*(parsertemp410052,d_r_rev)))) +::STMT +MATRIX:subspace_variance,parsertemp72203 +FLOAT:int677 +LITERAL_FLOAT:1.0 +%*%(t(subspace_variance),diag(/(1.0,<(parsertemp72203,int677)))) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:1.0,10000.0 +/(*(parsertemp31330,10000.0),-(10000.0,1.0)) +::STMT +MATRIX:ubScores +FLOAT:minsc +LITERAL_FLOAT:0.0 +&(>(ubScores,minsc),>(ubScores,0.0)) +::STMT +MATRIX:parsertemp31105,parsertemp31107 +LITERAL_FLOAT:7.996E9,1999.0,2.0 +/(^(/(-(parsertemp31105,parsertemp31107),1999.0),2.0),7.996E9) +::STMT +LITERAL_FLOAT:254.0 +254.0 +::STMT +LITERAL_FLOAT:255.0 +255.0 +::STMT +LITERAL_FLOAT:300.0 +300.0 +::STMT +MATRIX:p_LS,tmp +FLOAT:norm_r2_LS +/(norm_r2_LS,*(cast.FLOAT(p_LS),cast.FLOAT(tmp))) +::STMT +MATRIX:valueCount,Y +/(t(valueCount),nrow(Y)) +::STMT +MATRIX:selCols2 +sum(!(selCols2)) +::STMT +MATRIX:lambda,B,Grad +LITERAL_FLOAT:2.0 +^(+(Grad,*(lambda,B)),2.0) +::STMT +MATRIX:R,dsep,dssm +/(+(R,dsep),-(R,dssm)) +::STMT +MATRIX:2940_mask,2939_out +LITERAL_FLOAT:0.35 +/(*(2939_out,2940_mask),0.35) +::STMT +MATRIX:r,alpha,Hd +LITERAL_FLOAT:2.0 +^(-(r,*(cast.FLOAT(alpha),Hd)),2.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:INF,int762,int239 +!=(+(*(>=(Hdiff,int762),betamax),*(<(Hdiff,int239),beta)),INF) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,W2,W3 +LITERAL_FLOAT:0.0 +%*%(*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3))),t(W2)) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.0873148795050037 +*(0.0873148795050037,W4_rand) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,50.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(50.0,cast.FLOAT(w)))) +::STMT +MATRIX:parsertemp460644 +LITERAL_FLOAT:0.0625 +*(parsertemp460644,0.0625) +::STMT +MATRIX:r,w +FLOAT:tau +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(r,r))),*(tau,sum(abs(w)))) +::STMT +LITERAL_FLOAT:500.0 +500.0 +::STMT +MATRIX:parsertemp31112,parsertemp31114 +LITERAL_FLOAT:1499.0,2.0,3.37275E9 +/(^(/(-(parsertemp31112,parsertemp31114),1499.0),2.0),3.37275E9) +::STMT +MATRIX:S,parsertemp42207 +LITERAL_FLOAT:2.0,0.5 ++(-(parsertemp42207,/(t(S),2.0)),0.5) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int695,int909,int977,int948 +sum(*(*(>(out,int948),-(int695,parsertemp2798)),*(>(out,int909),-(int977,parsertemp2798)))) +::STMT +MATRIX:parsertemp389760,permut +LITERAL_FLOAT:1.0 +%*%(t(permut),/(-(exp(parsertemp389760),1.0),+(exp(parsertemp389760),1.0))) +::STMT +MATRIX:parsertemp477715,Y,K +FLOAT:X +LITERAL_FLOAT:1.0 +*(-(*(cast.FLOAT(K),-(X,X)),-(cast.FLOAT(Y),cast.FLOAT(Y))),-(1.0,/(cast.FLOAT(parsertemp477715),-(X,X)))) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0 +==(t(parsertemp222703),0.0) +::STMT +MATRIX:d,parsertemp43998 +FLOAT:int973 +cast.FLOAT(%*%(t(d),+(d,*(int973,parsertemp43998)))) +::STMT +MATRIX:q,r +FLOAT:p,a,norm_r2 +%*%(t(+(r,*(a,q))),+(r,*(/(norm_r2,p),+(q,q)))) +::STMT +MATRIX:m_err +cast.FLOAT(rowSums(colSums(m_err))) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.000010000100001 +sqrt(*(m2X,1.000010000100001)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170148,int960,q_CG,int952,z,parsertemp170170,pq_CG +*(+(+(*(parsertemp170170,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(-(*(z,int952),sqrt(parsertemp170148)),sum(^(p_CG,int960)))) +::STMT +MATRIX:sts,d,parsertemp44021,parsertemp44023 +FLOAT:delta2 +sqrt(+(*(%*%(parsertemp44021,d),%*%(parsertemp44021,d)),*(%*%(parsertemp44023,d),-(delta2,sts)))) +::STMT +FLOAT:offset_x +LITERAL_FLOAT:1.0 +-(1.0,round(offset_x)) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamax,Hneg,Hpos,beta +FLOAT:INF,logU +LITERAL_FLOAT:0.0 +*(>=(-(+(parsertemp220853,parsertemp220854),logU),0.0),!=(+(*(Hpos,betamax),*(Hneg,beta)),INF)) +::STMT +MATRIX:y_prob,ones_ctg +LITERAL_FLOAT:1.0 +%*%(y_prob,-(1.0,diag(ones_ctg))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(+(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +LITERAL_FLOAT:-0.001 +-0.001 +::STMT +LITERAL_FLOAT:0.001 +0.001 +::STMT +MATRIX:f,I +*(sum(I),max(f)) +::STMT +MATRIX:parsertemp379668 +FLOAT:int826 +LITERAL_FLOAT:1.0,-1.0 +*(sum(-(>=(parsertemp379668,int826),1.0)),-1.0) +::STMT +FLOAT:int713,int28 +LITERAL_FLOAT:0.0 +INT:parsertemp557199,int576 +==(diag(rand(parsertemp557199,int576,int713,int28)),0.0) +::STMT +MATRIX:parsertemp149283,parsertemp149281 +FLOAT:delta2,s2 +LITERAL_FLOAT:2.0 +sqrt(+(^(sum(parsertemp149281),2.0),*(sum(parsertemp149283),-(delta2,s2)))) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015 +cast.FLOAT(%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:b,X +exp(%*%(X,b)) +::STMT +FLOAT:parsertemp41020,m2,int106 +LITERAL_FLOAT:2003.0 +/(sqrt(*(/(int106,parsertemp41020),m2)),sqrt(2003.0)) +::STMT +MATRIX:parsertemp497802,Y +LITERAL_FLOAT:0.0 +*(Y,!=(parsertemp497802,0.0)) +::STMT +MATRIX:p,lambda,scale_X,shift_X +FLOAT:q,norm_r2 +*(/(norm_r2,sum(*(p,q))),+(+(*(scale_X,q),*(q,shift_X)),*(lambda,p))) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 ++(*(sample_block_size,3.0),1.0) +::STMT +MATRIX:2697_b,parsertemp459149,parsertemp459147 +rowSums(exp(-(+(parsertemp459147,2697_b),parsertemp459149))) +::STMT +MATRIX:output_values,initial_prediction +FLOAT:learning_rate ++(initial_prediction,*(learning_rate,sum(output_values))) +::STMT +FLOAT:m2,float276,int815 +LITERAL_FLOAT:2000.0 +sqrt(*(/(2000.0,-(int815,float276)),m2)) +::STMT +MATRIX:probs,out3,y_batch,184_scores,parsertemp146933 +FLOAT:float988,int950,183_N,int776 +LITERAL_FLOAT:1.0 +*(*(*(/(int950,183_N),-(int776,y_batch)),/(1.0,+(probs,float988))),/(exp(-(out3,parsertemp146933)),rowSums(exp(184_scores)))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +*(>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),beta) +::STMT +LITERAL_FLOAT:0.002 +0.002 +::STMT +MATRIX:parsertemp382680,parsertemp382677 +FLOAT:parsertemp382674 +LITERAL_FLOAT:0.5,5.0E-7 ++(*(0.5,parsertemp382674),*(5.0E-7,+(sum(parsertemp382677),sum(parsertemp382680)))) +::STMT +MATRIX:p_LS,X +FLOAT:lambda_LS ++(%*%(%*%(t(X),X),p_LS),*(lambda_LS,p_LS)) +::STMT +LITERAL_FLOAT:8001.0 +8001.0 +::STMT +MATRIX:parsertemp396419,W4_rand +FLOAT:int485,int992 +LITERAL_FLOAT:0.08681986202598489 +%*%(*(0.08681986202598489,W4_rand),t(/(-(parsertemp396419,int992),+(parsertemp396419,int485)))) +::STMT +MATRIX:Y_prob,parsertemp171377,Y,parsertemp171380 +FLOAT:int900 +LITERAL_FLOAT:3.141592653589793 +/(*(rowSums(Y),-(*(Y,Y_prob),*(Y,Y_prob))),*(*(*(parsertemp171377,Y_prob),Y_prob),*(+(int900,parsertemp171380),3.141592653589793))) +::STMT +MATRIX:parsertemp220853,parsertemp220854 +FLOAT:logU +LITERAL_FLOAT:0.0,2.0 +*(2.0,>=(-(+(parsertemp220853,parsertemp220854),logU),0.0)) +::STMT +FLOAT:parsertemp500918,offset_x +-(parsertemp500918,round(offset_x)) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +/(-(1.0,var_power),link_power) +::STMT +FLOAT:index +LITERAL_FLOAT:1.0,2.0 ++(+(*(index,2.0),2.0),1.0) +::STMT +MATRIX:Yhat_prime,H3_prime,E,W4 +colSums(*(H3_prime,%*%(*(E,Yhat_prime),W4))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:750.0 ++(rowSums(classFeatureCounts),750.0) +::STMT +MATRIX:LT,parsertemp149320,parsertemp150469 +rowSums(exp(-(LT,%*%(parsertemp149320,parsertemp150469)))) +::STMT +MATRIX:X,parsertemp429911 +FLOAT:int813,int704 +LITERAL_FLOAT:300.0,2.0 +-(t(colSums(^(X,int813))),*(300.0,^(/(parsertemp429911,int704),2.0))) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:k,parsertemp176418 +>(X_adapted,+(sqrt(parsertemp176418),*(k,y_hat))) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:parsertemp176421,k +<(X_adapted,-(sqrt(parsertemp176421),*(k,y_hat))) +::STMT +FLOAT:int630,i_iter,interval,i_process_item +LITERAL_FLOAT:1.0 +-(i_process_item,+(*(-(i_iter,int630),interval),1.0)) +::STMT +MATRIX:termination_bitmap,final_wcss_successful +LITERAL_FLOAT:1.0,10.0 +*(+(*(10.0,max(final_wcss_successful)),10.0),-(1.0,termination_bitmap)) +::STMT +MATRIX:sig +FLOAT:q,mu +LITERAL_FLOAT:4.0 +/(-(q,*(4.0,*(mu,mu))),*(4.0,*(cast.FLOAT(sig),cast.FLOAT(sig)))) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int169,int697 +LITERAL_FLOAT:1499.0,1500.0 +/(-(colSums(^(negSamples,int169)),*(1500.0,^(negSampleMeans,int697))),1499.0) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0,2.0 +^(/(t(colSums(X)),300.0),2.0) +::STMT +FLOAT:log_l_change +LITERAL_FLOAT:2.0 +*(2.0,abs(log_l_change)) +::STMT +MATRIX:parsertemp132003,parsertemp132023,leftIdx +%*%(parsertemp132023,%*%(t(parsertemp132003),leftIdx)) +::STMT +MATRIX:d,X,logisticD +FLOAT:C +*(C,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:parsertemp222700,X,parsertemp222696,parsertemp222693 +LITERAL_FLOAT:-2.0 +<=(+(*(-2.0,%*%(X,parsertemp222693)),t(rowSums(parsertemp222696))),parsertemp222700) +::STMT +MATRIX:X +FLOAT:int617 +LITERAL_FLOAT:0.0 +!=(t(colSums(!=(X,int617))),0.0) +::STMT +MATRIX:ss +FLOAT:alpha +LITERAL_FLOAT:1.0,20.0 +*(-(1.0,alpha),-(/(20.0,ss),1.0)) +::STMT +MATRIX:means,Y +colSums(-(Y,means)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-2.0 +*(-2.0,link_power) +::STMT +FLOAT:pow_two +LITERAL_FLOAT:2.0 +*(2.0,pow_two) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285739,parsertemp285737,pp_CG +LITERAL_FLOAT:-1.0 +/(-(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285737,parsertemp285739))),pp_CG) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2,eps +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:linear_terms +FLOAT:link_power,int370 +LITERAL_FLOAT:0.0,1.0 +-(^(+(linear_terms,==(linear_terms,int370)),/(1.0,link_power)),==(linear_terms,0.0)) +::STMT +MATRIX:w,X,y +LITERAL_FLOAT:-1.0 +exp(*(*(y,-1.0),%*%(X,w))) +::STMT +LITERAL_FLOAT:2.0,1500.0 +^(1500.0,2.0) +::STMT +MATRIX:parsertemp132494,rightHist,outBucket +%*%(==(outBucket,t(parsertemp132494)),rightHist) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +abs(==(parsertemp174552,0.0)) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0,0.5 +-(1.0,*(2.0,>(y_corr,0.5))) +::STMT +FLOAT:ytest,yhat,int56,parsertemp454076,int163 +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(^(-(ytest,yhat),2.0),-(^(ytest,int56),*(int163,parsertemp454076)))) +::STMT +MATRIX:Q1,IQR +FLOAT:k +-(Q1,*(k,IQR)) +::STMT +MATRIX:xs +FLOAT:256_x +LITERAL_FLOAT:1.0,1000.0 ++(-(1000.0,sum(>=(xs,256_x))),1.0) +::STMT +MATRIX:w,parsertemp2794 +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(+(w,parsertemp2794),2.0))) +::STMT +MATRIX:linear_terms,Y +FLOAT:int668 +LITERAL_FLOAT:0.0,1.0 ++(*(linear_terms,-(1.0,==(Y,int668))),==(Y,0.0)) +::STMT +MATRIX:parsertemp410080,d_r_rev,parsertemp410079,parsertemp410090 +LITERAL_FLOAT:-1.0 ++(*(cast.FLOAT(%*%(parsertemp410079,parsertemp410080)),-1.0),sum(*(d_r_rev,parsertemp410090))) +::STMT +MATRIX:parsertemp132003,parsertemp132023,leftIdx +LITERAL_FLOAT:0.0 +>(%*%(parsertemp132023,%*%(t(parsertemp132003),leftIdx)),0.0) +::STMT +MATRIX:parsertemp410987,parsertemp410979,W,parsertemp410981 +/(*(W,parsertemp410987),t(rowSums(/(parsertemp410979,parsertemp410981)))) +::STMT +LITERAL_FLOAT:1.0,4.0 ++(4.0,1.0) +::STMT +MATRIX:D,ZERODIAG,parsertemp220891 +FLOAT:int374,int42 +LITERAL_FLOAT:1.0 +/(*(/(1.0,+(D,int42)),ZERODIAG),sum(*(/(int374,parsertemp220891),ZERODIAG))) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939,outr2 +LITERAL_FLOAT:2.0 +^(%*%(t(outr2),-(*(183_dpred,184_probs),*(184_probs,parsertemp146939))),2.0) +::STMT +MATRIX:q,r +FLOAT:alpha +sum(*(+(r,*(alpha,q)),+(r,*(alpha,q)))) +::STMT +MATRIX:vb1,parsertemp460691 +FLOAT:lr,mu +-(*(mu,vb1),*(lr,rowSums(parsertemp460691))) +::STMT +FLOAT:obj,obj_new,gs +-(-(obj_new,obj),gs) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:0.5,4460.0 ++(0.5,/(parsertemp76118,4460.0)) +::STMT +MATRIX:r,parsertemp44050 +sqrt(sum(*(-(r,parsertemp44050),-(r,parsertemp44050)))) +::STMT +FLOAT:padh,Hin +LITERAL_FLOAT:2.0 ++(Hin,*(2.0,padh)) +::STMT +FLOAT:numRows +LITERAL_FLOAT:0.05 +*(0.05,numRows) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +*(*(grad,-1.0),*(grad,-1.0)) +::STMT +MATRIX:xs +LITERAL_FLOAT:10.0,4.5 +-(10.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp555766,parsertemp555762,target +LITERAL_FLOAT:-1.0,1.0 +-(*(*(target,-1.0),parsertemp555762),*(-(1.0,target),parsertemp555766)) +::STMT +FLOAT:191_beta2,191_t,191_lr +LITERAL_FLOAT:1.0 +*(191_lr,sqrt(-(1.0,^(191_beta2,191_t)))) +::STMT +MATRIX:w,X,y +sum(*(-(%*%(X,w),y),-(%*%(X,w),y))) +::STMT +LITERAL_FLOAT:0.08720414403938946 +0.08720414403938946 +::STMT +MATRIX:simplex +-(rowSums(simplex),simplex) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:1.0E-7 +diag(*(scale_lambda,1.0E-7)) +::STMT +MATRIX:g +FLOAT:lambda +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(g,lambda),2.0))) +::STMT +MATRIX:X,y +FLOAT:int442 +LITERAL_FLOAT:0.0 +INT:m,int706 +-(%*%(X,rand(m,int706,0.0,int442)),y) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:0.5,2358.0 ++(0.5,/(parsertemp77570,2358.0)) +::STMT +MATRIX:p,q,r,lambda +FLOAT:norm_r2 ++(r,*(/(norm_r2,cast.FLOAT(p)),+(q,*(lambda,p)))) +::STMT +MATRIX:feature +LITERAL_FLOAT:1.0 +-(1.0,min(feature)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:2.0 +*(^(n_risk_stratum,2.0),*(n_risk,n_event_stratum)) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +*(/(2.0,-(check_max,check_min)),Y) +::STMT +MATRIX:U,X,parsertemp382669 +LITERAL_FLOAT:0.0,2.0 +*(!=(X,0.0),^(-(%*%(U,parsertemp382669),X),2.0)) +::STMT +FLOAT:idx +LITERAL_FLOAT:256.0 +-(256.0,idx) +::STMT +MATRIX:paramLens,parsertemp387457 +rev(/(parsertemp387457,rev(paramLens))) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +-(*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(%*%(p_CG,z))),*(cast.FLOAT(%*%(p_CG,p_CG)),-(cast.FLOAT(z),trust_delta_sq))) +::STMT +MATRIX:X,H,parsertemp18133 +LITERAL_FLOAT:0.0,2.0 +*(>(%*%(X,t(H)),0.0),t(^(2.0,parsertemp18133))) +::STMT +MATRIX:parsertemp429918,parsertemp429916,parsertemp429914 +FLOAT:int453,int941 +LITERAL_FLOAT:0.0,1.0,299.0 +*(/(-(t(parsertemp429914),*(int453,parsertemp429916)),299.0),-(1.0,<=(/(parsertemp429918,int941),0.0))) +::STMT +FLOAT:idx +LITERAL_FLOAT:253.0 +-(253.0,idx) +::STMT +MATRIX:parsertemp175075,parsertemp175079,X,R1 +-(R1,/(exp(-(X,parsertemp175075)),rowSums(exp(parsertemp175079)))) +::STMT +FLOAT:522_strideh,522_padh,522_Hin,int470 +LITERAL_FLOAT:1.0 +/(-(+(522_Hin,*(int470,522_padh)),1.0),522_strideh) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:n_components,parsertemp506195 +/(rand(parsertemp506195,n_components,0.0,1.0),rowSums(rand(parsertemp506195,n_components,0.0,1.0))) +::STMT +FLOAT:covXY +covXY +::STMT +MATRIX:is_row_in_samples,parsertemp76114 +LITERAL_FLOAT:13381.0 +-(13381.0,*(is_row_in_samples,parsertemp76114)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +<(abs(-(output,output1)),0.1) +::STMT +MATRIX:prec,X,mu +LITERAL_FLOAT:2.0 +rowSums(^(-(%*%(X,prec),%*%(mu,prec)),2.0)) +::STMT +LITERAL_FLOAT:1.0,100.0 +-(100.0,1.0) +::STMT +MATRIX:parsertemp222310 +FLOAT:parsertemp222313 +LITERAL_FLOAT:0.5 +round(+(/(parsertemp222310,parsertemp222313),0.5)) +::STMT +MATRIX:resp,X,parsertemp437188 +FLOAT:float168 +LITERAL_FLOAT:2.0 +^(/(%*%(t(resp),X),t(+(parsertemp437188,float168))),2.0) +::STMT +MATRIX:y_residual,ytest +LITERAL_FLOAT:2.0 +*($1:nrow(ytest),^(/(sum(y_residual),$1),2.0)) +::STMT +LITERAL_FLOAT:5.0E-4 +5.0E-4 +::STMT +MATRIX:316_unnorm_probs,parsertemp175059 +LITERAL_FLOAT:1.0E-6 +<(abs(-(/(316_unnorm_probs,parsertemp175059),/(316_unnorm_probs,parsertemp175059))),1.0E-6) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(-(0.0,t(colSums(X)))) +::STMT +MATRIX:y_train,prediction +FLOAT:float477 +/(sum(==(y_train,>(prediction,float477))),nrow(y_train)) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +-(^(cast.FLOAT(z),2.0),trust_delta_sq) +::STMT +MATRIX:e_r_rev_agg,d_r_rev,X_agg +t(colSums(/(*(X_agg,d_r_rev),e_r_rev_agg))) +::STMT +MATRIX:parsertemp222327,is_row_in_samples +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 +-(+(*(sample_block_size,num_samples),1.0),*(is_row_in_samples,parsertemp222327)) +::STMT +FLOAT:m2Y,sigmaX +LITERAL_FLOAT:1.0002795638803466 +*(sigmaX,sqrt(*(m2Y,1.0002795638803466))) +::STMT +MATRIX:X,permut +FLOAT:n +/(colSums(%*%(permut,X)),n) +::STMT +LITERAL_FLOAT:1.0E-10 +1.0E-10 +::STMT +MATRIX:output_values +LITERAL_FLOAT:0.3 +*(0.3,sum(output_values)) +::STMT +LITERAL_FLOAT:1.0,-1.0 +*(1.0,-1.0) +::STMT +MATRIX:Q,V,X,P_1K +%*%(t(X),-(*(P_1K,%*%(X,V)),*(P_1K,rowSums(Q)))) +::STMT +MATRIX:prec +diag(t(prec)) +::STMT +LITERAL_FLOAT:1.0,5.0 ++(5.0,1.0) +::STMT +LITERAL_FLOAT:0.0 +cast.MATRIX(0.0) +::STMT +MATRIX:parsertemp382680,col_nonzeros,parsertemp382677,row_nonzeros +FLOAT:reg +LITERAL_FLOAT:0.5 +*(*(0.5,reg),+(sum(*(parsertemp382677,row_nonzeros)),sum(*(parsertemp382680,col_nonzeros)))) +::STMT +MATRIX:d_r,parsertemp409781 +%*%(t(rev(d_r)),parsertemp409781) +::STMT +MATRIX:B,X,y +FLOAT:intercept +-(y,+(%*%(X,B),intercept)) +::STMT +MATRIX:A,scale_X,shift_X +t(+(%*%(diag(scale_X),A),%*%(shift_X,A))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,64.0 ++(-(64.0,idx),1.0) +::STMT +MATRIX:g_new,parsertemp2824,s,parsertemp2826 ++(*(/(sum(parsertemp2824),sum(parsertemp2826)),s),g_new) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +^(+(linear_terms,==(linear_terms,0.0)),/(1.0,link_power)) +::STMT +MATRIX:parsertemp171600,g_Y,lambda,scale_X,shift_X,gXY,beta ++(+(%*%(diag(scale_X),%*%(parsertemp171600,g_Y)),%*%(shift_X,gXY)),*(lambda,beta)) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:delta2 +*(%*%(t(d),d),-(delta2,%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:0.0 +INT:int1,int961 +exp(rand(int1,int961,0.0,0.0)) +::STMT +MATRIX:V,X,P_1K +*(P_1K,rowSums(*(P_1K,%*%(X,V)))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.6546536707079771 +*(0.6546536707079771,W2_rand) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +*(linear_terms,-(2.0,var_power)) +::STMT +MATRIX:X2 +LITERAL_FLOAT:4.0 +<(t(colSums(X2)),4.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 +-(+(i,1.0),1.0) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 ++(approx_sample_size,round(*(10.0,sqrt(approx_sample_size)))) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +-(4.0,nrow(B)) +::STMT +FLOAT:dist +t(cast.MATRIX(dist)) +::STMT +MATRIX:num_std +t(sqrt(num_std)) +::STMT +MATRIX:var_X_cols,tmp +FLOAT:int300,int338,int958,N +LITERAL_FLOAT:0.0,1.0 ++(*(/(tmp,-(N,int338)),-(1.0,<=(var_X_cols,int958))),<=(/(tmp,-(N,int300)),0.0)) +::STMT +LITERAL_FLOAT:1.0E-12 +1.0E-12 +::STMT +FLOAT:float824,int237,float466,int581 +LITERAL_FLOAT:1.0,3.0,6.0,2000.0 +/(*(*(6.0,2000.0),-(2000.0,1.0)),*(*(-(int237,float466),+(int581,float824)),+(2000.0,3.0))) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 ++(*(Y_prob,-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:means,Y_counts,Y,parsertemp560602 +-(-(Y,means),%*%(Y_counts,/(colSums(parsertemp560602),sum(Y_counts)))) +::STMT +MATRIX:parsertemp382947 +FLOAT:reg,parsertemp382956,loss_init,parsertemp382953,float925 +LITERAL_FLOAT:0.5 +-(loss_init,+(*(0.5,sum(parsertemp382947)),*(*(float925,reg),+(parsertemp382953,parsertemp382956)))) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0,4.0 ++(*(index,4.0),2.0) +::STMT +MATRIX:R +FLOAT:i8 +LITERAL_FLOAT:24.0 +-(nrow(R),*(24.0,i8)) +::STMT +MATRIX:parsertemp436114 +FLOAT:int359,int471 +INT:2663_2662_n_col,int558 +*(cast.FLOAT(parsertemp436114),rand(int558,2663_2662_n_col,int359,int471)) +::STMT +FLOAT:parsertemp83 +abs(-(cast.MATRIX(parsertemp83),parsertemp83)) +::STMT +MATRIX:parsertemp31112,parsertemp31114 +FLOAT:int597,int905 +LITERAL_FLOAT:1.0,2.0,1500.0 +/(^(/(-(parsertemp31112,parsertemp31114),-(int905,int597)),2.0),*(^(1500.0,2.0),-(1500.0,1.0))) +::STMT +MATRIX:Ileft,Iright +FLOAT:min_leaf +&(>=(rowSums(Ileft),min_leaf),>=(rowSums(Iright),min_leaf)) +::STMT +MATRIX:codebook +FLOAT:j +LITERAL_FLOAT:1.0 +*(-(j,1.0),ncol(codebook)) +::STMT +MATRIX:parsertemp429916,parsertemp429914 +FLOAT:int441 +LITERAL_FLOAT:0.0,299.0 +<=(/(-(t(parsertemp429914),*(int441,parsertemp429916)),299.0),0.0) +::STMT +MATRIX:subspace_idx,parsertemp107049 +LITERAL_FLOAT:1.0,7.0 +<(-(subspace_idx,*(parsertemp107049,7.0)),1.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08146881698903526 +*(0.08146881698903526,W1_rand) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0 +cast.FLOAT(==(R,0.0)) +::STMT +MATRIX:parsertemp10743,V,parsertemp10742,H,parsertemp10739,parsertemp10738 +FLOAT:Eps +%*%(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps))),t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:P,Y,dP +sum(&(<=(P,dP),!(Y))) +::STMT +MATRIX:distances,ksmall +FLOAT:int819,int751 +LITERAL_FLOAT:0.0 +INT:parsertemp557199,int480 +*(<=(distances,ksmall),==(diag(rand(parsertemp557199,int480,int819,int751)),0.0)) +::STMT +FLOAT:2690_Hin,parsertemp459058 +LITERAL_FLOAT:1.0,2.0 ++(/(-(+(2690_Hin,parsertemp459058),2.0),2.0),1.0) +::STMT +FLOAT:252_Y,float605,int241,252_X,252_K,float60 +LITERAL_FLOAT:1.0 ++(*(-(*(252_K,252_X),-(252_Y,252_Y)),-(1.0,/(float605,252_X))),*(+(*(int241,252_X),-(252_Y,252_Y)),/(-(float60,252_X),-(252_X,252_X)))) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int496,int812 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int496,parsertemp2798),0.0),-(1.0,*(Y,Xw))),*(>(-(int812,parsertemp2798),0.0),-(1.0,*(Y,Xw)))) +::STMT +MATRIX:2364_2359_Y_prime,W2,2364_2358_Y,parsertemp389612 +FLOAT:int492 +LITERAL_FLOAT:1.0 +t(*(-(1.0,^(2364_2358_Y,int492)),%*%(*(2364_2359_Y_prime,parsertemp389612),W2))) +::STMT +MATRIX:s +FLOAT:n +LITERAL_FLOAT:1.0 +-(*(/(1.0,s),n),1.0) +::STMT +MATRIX:y_corr +FLOAT:int922 +LITERAL_FLOAT:1.0 +*(*(y_corr,-(1.0,<=(y_corr,int922))),-(1.0,>=(y_corr,1.0))) +::STMT +FLOAT:429_C +LITERAL_FLOAT:1.0 +*(*(429_C,1.0),1.0) +::STMT +MATRIX:parsertemp220853,parsertemp220854,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +*(>=(-(+(parsertemp220853,parsertemp220854),logU),0.0),beta) +::STMT +MATRIX:Y,2212_tp +/(2212_tp,sum(Y)) +::STMT +FLOAT:int489,lratio_t,N +LITERAL_FLOAT:1.0 +-(1.0,exp(/(*(lratio_t,int489),N))) +::STMT +MATRIX:parsertemp116096,X2 +LITERAL_FLOAT:0.0,32.0 +|(<(t(colSums(X2)),32.0),==(t(%*%(parsertemp116096,X2)),0.0)) +::STMT +MATRIX:H2_prime,2365_delta3,H1_prime,W2,W3 +*(H1_prime,%*%(*(H2_prime,%*%(2365_delta3,W3)),W2)) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +FLOAT:C +*(+(wnew,*(C,%*%(parsertemp44107,parsertemp44109))),+(wnew,*(C,%*%(parsertemp44107,parsertemp44109)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,0.5 +*(-(0.5,Y),==(rowSums(Y),0.0)) +::STMT +MATRIX:s,d,parsertemp44021 +FLOAT:delta2 +*(cast.FLOAT(%*%(t(d),d)),-(delta2,cast.FLOAT(%*%(parsertemp44021,s)))) +::STMT +LITERAL_FLOAT:1.0,100.0,0.8 ++(*(100.0,0.8),1.0) +::STMT +MATRIX:tmp,X,parsertemp393475,parsertemp393466 +LITERAL_FLOAT:1.0E-17 +t(/(-(%*%(tmp,X),parsertemp393466),+(sqrt(parsertemp393475),1.0E-17))) +::STMT +MATRIX:parsertemp129018 +LITERAL_FLOAT:1.0,2.0 ++(*(max(parsertemp129018),2.0),1.0) +::STMT +MATRIX:surv,se_surv,parsertemp538736 +FLOAT:parsertemp538734 +^(surv,exp(/(*(parsertemp538734,se_surv),parsertemp538736))) +::STMT +FLOAT:i,k +LITERAL_FLOAT:4.0 ++(+(i,k),4.0) +::STMT +MATRIX:p,V +FLOAT:eps ++(%*%(t(V),%*%(V,p)),*(eps,p)) +::STMT +MATRIX:parsertemp552345,tab,catTotal +LITERAL_FLOAT:-1.0 +sum(*(*(/(tab,catTotal),-1.0),parsertemp552345)) +::STMT +MATRIX:X2 +LITERAL_FLOAT:32.0 +<(t(colSums(X2)),32.0) +::STMT +MATRIX:m_iter_err_sum,m_err +t(+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:int723 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),%*%(t(d),+(d,*(int723,parsertemp43998)))) +::STMT +MATRIX:parsertemp122063,parsertemp122058 +FLOAT:eAvg,alpha,n +LITERAL_FLOAT:1.0 +-(*(alpha,-(/(parsertemp122058,eAvg),1.0)),*(-(1.0,alpha),-(*(parsertemp122063,n),1.0))) +::STMT +MATRIX:m_err_mean +LITERAL_FLOAT:-0.001 +-(-0.001,cast.FLOAT(m_err_mean)) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0,0.0 +-(0.0,/(t(colSums(X)),300.0)) +::STMT +MATRIX:WM +FLOAT:m2X,W,float201 +sqrt(*(m2X,/(sum(WM),-(W,float201)))) +::STMT +LITERAL_FLOAT:1.0,3.0 ++(3.0,1.0) +::STMT +MATRIX:V,W,parsertemp10741,H +FLOAT:Eps +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),Eps))) +::STMT +MATRIX:parsertemp410118,g0_1,g_2 +cast.FLOAT(%*%(t(+(g0_1,g_2)),+(g0_1,t(parsertemp410118)))) +::STMT +MATRIX:E,F,parsertemp12849 +FLOAT:q,int210 +sqrt(/(sum(/(parsertemp12849,E)),*(sum(F),-(q,int210)))) +::STMT +MATRIX:log_prob,X +LITERAL_FLOAT:1.8378770664093453 ++(*(ncol(X),1.8378770664093453),log_prob) +::STMT +MATRIX:X,parsertemp16893,parsertemp16892 +/(%*%(X,t(X)),%*%(sqrt(rowSums(parsertemp16892)),t(sqrt(parsertemp16893)))) +::STMT +MATRIX:s,w +%*%(t(+(w,s)),+(w,s)) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,200.0 +-(0.0,/(t(colSums(X)),200.0)) +::STMT +MATRIX:parsertemp443530,mean,parsertemp443532,X,weight +FLOAT:float416 +-(%*%(t(X),X),%*%(*(t(mean),+(parsertemp443530,float416)),/(%*%(parsertemp443532,X),t(weight)))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +/(^(linear_terms,2.0),-(2.0,var_power)) +::STMT +MATRIX:parsertemp170101 +FLOAT:parsertemp170114,r_CG,g_reg,z,277_sq_root_d,parsertemp170093,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170114,z),sum(parsertemp170101)),/(-(parsertemp170093,277_sq_root_d),pp_CG))) +::STMT +MATRIX:Y +FLOAT:maxv,minv ++(sum(==(Y,minv)),sum(==(Y,maxv))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:960.0 +/(960.0,num_records) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +/(sum(*(-(r,parsertemp44050),-(r,parsertemp44050))),norm_r2) +::STMT +MATRIX:X,permut +colSums(%*%(permut,X)) +::STMT +FLOAT:batch_size,parsertemp145942 +LITERAL_FLOAT:1.0 +-(+(+(parsertemp145942,1.0),batch_size),1.0) +::STMT +MATRIX:lambda,V,shift_X,parsertemp274512,HV +*(V,+(+(%*%(parsertemp274512,HV),%*%(shift_X,HV)),*(lambda,V))) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:2.0 +^(/(%*%(I,y2),sum(I)),2.0) +::STMT +MATRIX:H3_prime,delta4,W4 +t(colSums(*(H3_prime,%*%(delta4,W4)))) +::STMT +MATRIX:tmp,parsertemp260786,X,Y,parsertemp260785,out +%*%(t(-(%*%(parsertemp260785,parsertemp260786),tmp)),-(%*%(t(X),*(out,Y)),tmp)) +::STMT +MATRIX:Y,missing_mask_Y +LITERAL_FLOAT:1.0 +*(missing_mask_Y,+(max(Y),1.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,0.231641888 ++(1.0,*(abs(finite_linear_terms),0.231641888)) +::STMT +MATRIX:ytest,yhat +FLOAT:int780,mean_y_test +LITERAL_FLOAT:1.0,2.0 +/(^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0),-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int780)))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +LITERAL_FLOAT:2.0 +*(pp_CG,-(^(cast.FLOAT(z),2.0),trust_delta_sq)) +::STMT +MATRIX:parsertemp147188 +FLOAT:D +LITERAL_FLOAT:2.0 +*(parsertemp147188,sqrt(/(2.0,D))) +::STMT +MATRIX:X +FLOAT:int111 +LITERAL_FLOAT:1.0E-6 +/(*(1.0E-6,sum(^(X,int111))),ncol(X)) +::STMT +LITERAL_FLOAT:1.4142135623730951 +1.4142135623730951 +::STMT +MATRIX:sq_sums,mu +FLOAT:window_size +-(/(sq_sums,window_size),*(mu,mu)) +::STMT +MATRIX:663_img +t(rev(t(663_img))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08692913816996169 +*(0.08692913816996169,W1_rand) +::STMT +MATRIX:classes +LITERAL_FLOAT:1.0,0.7 +*(cast.FLOAT(classes),-(1.0,0.7)) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),max(round(parsertemp2832)))) +::STMT +FLOAT:i +LITERAL_FLOAT:18.0 ++(18.0,i) +::STMT +MATRIX:V +FLOAT:std_dev,int435,mu +*(<(V,-(mu,*(int435,std_dev))),V) +::STMT +MATRIX:V +FLOAT:std_dev,mu,int91 +*(>(V,+(mu,*(int91,std_dev))),V) +::STMT +MATRIX:d,X,logisticD +%*%(t(X),*(logisticD,%*%(X,d))) +::STMT +MATRIX:parsertemp477917,b +FLOAT:int929 +LITERAL_FLOAT:2.0 +sum(^(%*%(*(parsertemp477917,int929),b),2.0)) +::STMT +MATRIX:subspace_idx,parsertemp72201 +LITERAL_FLOAT:1.0,8.0 +<(-(subspace_idx,*(parsertemp72201,8.0)),1.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +/(Y,+(rowSums(Y),==(rowSums(Y),0.0))) +::STMT +MATRIX:w,X,y +FLOAT:int253 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(*(y,int253),%*%(X,w)))) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +LITERAL_FLOAT:0.0 +-(0.0,+(r_LS,*(/(norm_r2_LS,p_LS),+(parsertemp170552,lambda_LS)))) +::STMT +MATRIX:parsertemp552530,Y +LITERAL_FLOAT:0.0 +INT:parsertemp552529,idx +==(+(rand(parsertemp552529,idx,0.0,0.0),t(parsertemp552530)),Y) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0 +-(1.0,*(2.0,y_corr)) +::STMT +MATRIX:linear_terms +FLOAT:int594 +LITERAL_FLOAT:1.0,2.0 ++(1.0,-(*(2.0,>=(linear_terms,int594)),1.0)) +::STMT +MATRIX:shift_X,w,parsertemp170066,X +*(cast.FLOAT(shift_X),cast.FLOAT(%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp437548,pred,parsertemp437666 +==(*(parsertemp437666,t(parsertemp437548)),pred) +::STMT +MATRIX:means,parsertemp389215 +FLOAT:int11 +LITERAL_FLOAT:1057.0,1058.0 +/(*(-(parsertemp389215,^(means,int11)),1058.0),1057.0) +::STMT +MATRIX:U,V_sum +rowSums(/(*(U,U),sum(V_sum))) +::STMT +FLOAT:padh,strideh,int428,Hin,Hf +/(-(+(Hin,*(int428,padh)),Hf),strideh) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int876 +LITERAL_FLOAT:1499.0,2.0 +^(/(-(colSums(parsertemp31111),*(int876,parsertemp31113)),1499.0),2.0) +::STMT +MATRIX:parsertemp16859,X +FLOAT:int570 +LITERAL_FLOAT:1.0E-6 ++(sqrt(rowSums(^(X,int570))),*(<(sqrt(parsertemp16859),1.0E-6),1.0E-6)) +::STMT +FLOAT:new_log_l,log_l,neg_log_l_change_predicted +LITERAL_FLOAT:-1.0 +/(+(*(new_log_l,-1.0),log_l),neg_log_l_change_predicted) +::STMT +FLOAT:i2 +LITERAL_FLOAT:24.0,1.0 ++(*(24.0,i2),1.0) +::STMT +MATRIX:grad +sqrt(sum(*(grad,grad))) +::STMT +FLOAT:res_eee +LITERAL_FLOAT:2.0,0.3 +round(-(/(res_eee,2.0),0.3)) +::STMT +MATRIX:parsertemp285531,z,parsertemp285533 +FLOAT:pp,sq_root_d,zq,parsertemp285544,parsertemp285526 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(z,parsertemp285533))),*(+(+(parsertemp285544,zq),sum(parsertemp285531)),/(-(parsertemp285526,sq_root_d),pp))) +::STMT +MATRIX:parsertemp382919,parsertemp382916,S,col_nonzeros +FLOAT:reg +*(S,+(t(%*%(parsertemp382916,parsertemp382919)),*(*(reg,S),col_nonzeros))) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0,480.0 +/(sum(^(-(vectors,pq_result),2.0)),480.0) +::STMT +MATRIX:X,ScaleFactor +FLOAT:N +t(/(colSums(/(X,ScaleFactor)),N)) +::STMT +MATRIX:border,parsertemp386448,parsertemp386459,withinEps +LITERAL_FLOAT:0.0 +t(*(>(*(parsertemp386448,withinEps),0.0),==(-(border,parsertemp386459),0.0))) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:800.0,2.0 +/(sum(^(-(vectors,pq_result),2.0)),800.0) +::STMT +MATRIX:p,lambda,parsertemp456801,parsertemp456800 +cast.FLOAT(%*%(t(p),+(%*%(parsertemp456800,parsertemp456801),*(lambda,p)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604,w +FLOAT:int146,int367 +*(-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int146)),w),-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int367)),w)) +::STMT +LITERAL_FLOAT:0.21483446221182986 +0.21483446221182986 +::STMT +MATRIX:P,X,Y,parsertemp148868 +FLOAT:float9 +LITERAL_FLOAT:0.0,2.0 +^(+(%*%(t(X),-(P,Y)),*(*(parsertemp148868,float9),0.0)),2.0) +::STMT +MATRIX:parsertemp467675,Y,Xw +FLOAT:int437 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int437,parsertemp467675),0.0),-(1.0,*(Y,Xw))),Y) +::STMT +MATRIX:simplex +/(-(rowSums(simplex),simplex),nrow(simplex)) +::STMT +MATRIX:d_r,parsertemp409781 +*(rev(d_r),parsertemp409781) +::STMT +FLOAT:W +LITERAL_FLOAT:1.0 +/(W,-(W,1.0)) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 +*(/(Y_prob,rowSums(Y_prob)),-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp409788,parsertemp409797 +LITERAL_FLOAT:-1.0,2.0 +^(+(*(t(parsertemp409788),-1.0),t(colSums(parsertemp409797))),2.0) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int649 +LITERAL_FLOAT:1499.0,1500.0 +/(/(-(colSums(parsertemp31111),*(int649,parsertemp31113)),1499.0),1500.0) +::STMT +LITERAL_FLOAT:1.0E-17 +1.0E-17 +::STMT +MATRIX:scale_lambda,parsertemp150455 +LITERAL_FLOAT:0.0,1.0E-5 +*(*(%*%(scale_lambda,parsertemp150455),1.0E-5),0.0) +::STMT +FLOAT:e,decay +LITERAL_FLOAT:1.0 ++(1.0,*(decay,e)) +::STMT +MATRIX:A +/(*(cast.FLOAT(A),cast.FLOAT(A)),*(cast.FLOAT(A),cast.FLOAT(A))) +::STMT +MATRIX:p,V +FLOAT:eps +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:parsertemp43621,X,y +FLOAT:float787 +LITERAL_FLOAT:1.0 +%*%(t(X),*(-(/(float787,parsertemp43621),1.0),y)) +::STMT +MATRIX:g_new,g_old +LITERAL_FLOAT:2.0 +/(sum(^(g_new,2.0)),sum(^(g_old,2.0))) +::STMT +MATRIX:_sbcvar415,parsertemp116129 +FLOAT:eAvg,parsertemp116127 +LITERAL_FLOAT:0.050000000000000044,1.0,0.95 +-(*(0.95,-(/(parsertemp116129,eAvg),1.0)),*(0.050000000000000044,-(/(parsertemp116127,_sbcvar415),1.0))) +::STMT +MATRIX:w_X,X +FLOAT:int159 +cast.FLOAT(%*%(t(-(int159,w_X)),t(colSums(X)))) +::STMT +MATRIX:prec,X,mu +LITERAL_FLOAT:2.0 +^(-(%*%(X,prec),%*%(mu,prec)),2.0) +::STMT +FLOAT:i,Hin,Win +LITERAL_FLOAT:1.0 +*(*(-(i,1.0),Hin),Win) +::STMT +MATRIX:missing_val_maps +LITERAL_FLOAT:3.0 +-(3.0,nrow(missing_val_maps)) +::STMT +MATRIX:out +FLOAT:dd,step_sz,wd +*(-(+(wd,*(step_sz,dd)),sum(out)),-(+(wd,*(step_sz,dd)),sum(out))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.07261134713572442 +*(0.07261134713572442,W1_rand) +::STMT +MATRIX:g +FLOAT:float990 +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,*(float990,g)),2.0)) +::STMT +MATRIX:cm,FD +LITERAL_FLOAT:1.0 ++(FD,==(cm,1.0)) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +LITERAL_FLOAT:2.0 +-(parsertemp22485,*(2.0,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:0.0,2.0 +/(^(sum(residual_matrix),2.0),+(nrow(residual_matrix),0.0)) +::STMT +MATRIX:lambda,parsertemp285716,scale_X,p_CG,shift_X,parsertemp285714,temp_CG ++(+(*(lambda,p_CG),%*%(diag(scale_X),%*%(parsertemp285714,parsertemp285716))),%*%(shift_X,temp_CG)) +::STMT +MATRIX:parsertemp389212,parsertemp389215 +FLOAT:int362 +LITERAL_FLOAT:2.0,1058.0 +*(-(parsertemp389215,^(/(parsertemp389212,int362),2.0)),1058.0) +::STMT +MATRIX:Xm,parsertemp265706,Z,parsertemp265702 +FLOAT:ss +sum(+(%*%(t(Z),%*%(Xm,parsertemp265702)),*(parsertemp265706,ss))) +::STMT +FLOAT:delta +LITERAL_FLOAT:4.0 +*(4.0,delta) +::STMT +MATRIX:parsertemp42207,parsertemp42208,_sbcvar330,438_Ranks +FLOAT:parsertemp42222,meanY,meanX +LITERAL_FLOAT:0.5 +*(t(*(/(_sbcvar330,parsertemp42222),-(438_Ranks,meanX))),-(+(-(parsertemp42207,parsertemp42208),0.5),meanY)) +::STMT +MATRIX:z,parsertemp285752 +FLOAT:2234_sq_root_d,parsertemp285742,pp_CG,parsertemp285757 +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285752))),*(parsertemp285757,/(+(parsertemp285742,2234_sq_root_d),pp_CG))) +::STMT +FLOAT:batch_size,parsertemp145942 +LITERAL_FLOAT:1.0 ++(+(parsertemp145942,1.0),batch_size) +::STMT +FLOAT:m2X,W,float178 +sqrt(*(m2X,/(W,-(W,float178)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),p) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0005 +/(sqrt(*(1.0005,m2)),mu) +::STMT +MATRIX:Y_counts,Y,parsertemp560599 +FLOAT:parsertemp560600 +LITERAL_FLOAT:2.0 +^(-(Y,%*%(Y_counts,/(parsertemp560599,parsertemp560600))),2.0) +::STMT +MATRIX:Xd,parsertemp2775 +FLOAT:int805 +LITERAL_FLOAT:0.0 +*(*(Xd,>(-(int805,parsertemp2775),0.0)),Xd) +::STMT +MATRIX:parsertemp500663 +LITERAL_FLOAT:-1.0E30 +*(-1.0E30,parsertemp500663) +::STMT +MATRIX:parsertemp477829,2814_Y +FLOAT:2814_X,inp_x +*(+(*(cast.FLOAT(parsertemp477829),-(2814_X,2814_X)),-(cast.FLOAT(2814_Y),cast.FLOAT(2814_Y))),/(-(inp_x,cast.FLOAT(2814_X)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X)))) +::STMT +MATRIX:Xtest_dists +FLOAT:eps +LITERAL_FLOAT:0.0 +*(<=(Xtest_dists,eps),<(0.0,Xtest_dists)) +::STMT +MATRIX:parsertemp410250,event +FLOAT:parsertemp410251 +/(-(max(^(parsertemp410250,parsertemp410251)),min(^(parsertemp410250,parsertemp410251))),sum(event)) +::STMT +MATRIX:275_X,275_curr_X +FLOAT:275_value +&(==(275_X,275_curr_X),<(275_X,275_value)) +::STMT +MATRIX:r_CG,g_reg,z +cast.FLOAT(%*%(t(z),+(r_CG,g_reg))) +::STMT +MATRIX:X +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 +LITERAL_FLOAT:2.0 +/(^(cast.FLOAT(X),2.0),+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag))) +::STMT +FLOAT:k,n +LITERAL_FLOAT:2.0,4.0 +-(+(-(n,4.0),2.0),k) +::STMT +MATRIX:X +FLOAT:x +-(x,X) +::STMT +MATRIX:Hdiff,beta,betamin +FLOAT:int455,int899 ++(beta,+(*(<(Hdiff,int899),betamin),*(>=(Hdiff,int455),beta))) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int866,int847 ++(beta,+(*(>=(Hdiff,int847),betamax),*(<(Hdiff,int866),beta))) +::STMT +MATRIX:z +LITERAL_FLOAT:2.0 +sqrt(^(cast.FLOAT(z),2.0)) +::STMT +MATRIX:X,H +LITERAL_FLOAT:0.0 +>(%*%(X,t(H)),0.0) +::STMT +MATRIX:Bx +diag(Bx) +::STMT +MATRIX:parsertemp31189,parsertemp31194,parsertemp31196,parsertemp31187 +LITERAL_FLOAT:1499.0,6999.0,1500.0,7000.0 ++(/(/(-(parsertemp31187,parsertemp31189),6999.0),7000.0),/(/(-(parsertemp31194,parsertemp31196),1499.0),1500.0)) +::STMT +MATRIX:parsertemp170244,parsertemp170240,parsertemp170238 +FLOAT:float847,float32,float42 +LITERAL_FLOAT:1.0,-0.284496736 +*(/(1.0,+(1.0,*(parsertemp170238,float847))),+(-0.284496736,*(/(float32,parsertemp170240),+(float42,parsertemp170244)))) +::STMT +FLOAT:2690_Hin +LITERAL_FLOAT:0.0,2.0 +-(+(2690_Hin,*(2.0,0.0)),2.0) +::STMT +MATRIX:A,B,X +<=(%*%(X,A),B) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2001.0 +*(*(-(2001.0,2.0),+(2001.0,1.0)),+(2001.0,3.0)) +::STMT +MATRIX:P,I,X2 +*(t(%*%(X2,P)),I) +::STMT +LITERAL_FLOAT:0.06835859270246632 +0.06835859270246632 +::STMT +MATRIX:d,parsertemp410053 +sum(*(d,t(colSums(parsertemp410053)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(^(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:2883_ctab +LITERAL_FLOAT:0.0,1.0 +==(rowSums(!=(2883_ctab,0.0)),1.0) +::STMT +MATRIX:M2,X +-(nrow(X),nrow(M2)) +::STMT +MATRIX:parsertemp403496,W3_rand +FLOAT:int454,int938 +LITERAL_FLOAT:0.1651445647689541 +%*%(*(0.1651445647689541,W3_rand),t(/(-(parsertemp403496,int454),+(parsertemp403496,int938)))) +::STMT +MATRIX:w,parsertemp2794 +FLOAT:lambda +LITERAL_FLOAT:2.0 +*(/(lambda,2.0),sum(*(+(w,parsertemp2794),+(w,parsertemp2794)))) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int700 +LITERAL_FLOAT:1.0,2.0,150.0 +^(/(-(colSums(parsertemp31029),*(int700,parsertemp31031)),-(150.0,1.0)),2.0) +::STMT +MATRIX:p,z +*(sum(*(p,z)),sum(*(p,z))) +::STMT +MATRIX:X +LITERAL_FLOAT:-2.0 +*(-2.0,%*%(X,t(X))) +::STMT +MATRIX:parsertemp31189,parsertemp31194,parsertemp31196,parsertemp31187 +FLOAT:int893,int871,int192,int39 +LITERAL_FLOAT:1500.0,7000.0 ++(/(/(-(parsertemp31187,parsertemp31189),-(int893,int39)),7000.0),/(/(-(parsertemp31194,parsertemp31196),-(int192,int871)),1500.0)) +::STMT +MATRIX:scale_X,shift_X +LITERAL_FLOAT:2.0 +*(*(2.0,scale_X),shift_X) +::STMT +MATRIX:COMPONENTS,id +-(==(id,cast.FLOAT(id)),cast.FLOAT(diag(diag(COMPONENTS)))) +::STMT +MATRIX:252_X +LITERAL_FLOAT:4.5 +/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:252_Y,252_X,252_K +-(*(cast.FLOAT(252_K),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))) +::STMT +MATRIX:gs +FLOAT:alpha2Scalar +LITERAL_FLOAT:-0.5 +/(*(-0.5,cast.FLOAT(gs)),alpha2Scalar) +::STMT +MATRIX:parsertemp146940,184_dtemp +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(colSums(-(184_dtemp,parsertemp146940)),2.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2000.0 +*(/(2000.0,-(2000.0,1.0)),m2) +::STMT +MATRIX:parsertemp387404,K_inv,Ks,Kss +-(cast.FLOAT(Kss),cast.FLOAT(%*%(%*%(parsertemp387404,K_inv),Ks))) +::STMT +MATRIX:parsertemp131907,parsertemp131918,cumLeftHist,parsertemp132092,leftHist,outBucket ++(%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),-(cumLeftHist,leftHist)),parsertemp131918) +::STMT +MATRIX:e,X2 +LITERAL_FLOAT:0.0 +>(t(%*%(t(e),X2)),0.0) +::STMT +MATRIX:PRED,GT +/(sum(==(PRED,GT)),length(==(PRED,GT))) +::STMT +MATRIX:U,V,X +-(X,%*%(U,t(V))) +::STMT +FLOAT:m2X,float180,int20 +LITERAL_FLOAT:100000.0 +sqrt(*(m2X,/(100000.0,-(int20,float180)))) +::STMT +MATRIX:p,A +*(p,%*%(t(A),%*%(A,p))) +::STMT +MATRIX:V_nonzero,row_nonzeros,lambda_I ++(%*%(t(V_nonzero),V_nonzero),*(cast.FLOAT(row_nonzeros),lambda_I)) +::STMT +MATRIX:C,Xm,parsertemp265706,parsertemp265704,Z,parsertemp265701 +FLOAT:ss +/(%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))),sum(+(%*%(parsertemp265704,Z),*(parsertemp265706,ss)))) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1.0,1000.0 +/(*(parsertemp13703,1000.0),-(1000.0,1.0)) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +*(-(sum(WM),1.0),/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0))) +::STMT +MATRIX:Xm,tmp,parsertemp265702 +t(/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(tmp))) +::STMT +MATRIX:scale_X,shift_X,X,parsertemp271403 +FLOAT:int126,int545 +LITERAL_FLOAT:2.0 ++(+(%*%(^(X,int126),^(scale_X,int545)),%*%(X,*(parsertemp271403,shift_X))),sum(^(shift_X,2.0))) +::STMT +FLOAT:parsertemp271435 +LITERAL_FLOAT:1500.0 +*(1500.0,parsertemp271435) +::STMT +FLOAT:Hin +LITERAL_FLOAT:184.0 +*(+(Hin,184.0),+(Hin,184.0)) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +LITERAL_FLOAT:2.0 +%*%(t(d),+(d,*(2.0,%*%(parsertemp43996,parsertemp43997)))) +::STMT +MATRIX:K_inv,Ks,Kss +-(Kss,%*%(%*%(t(Ks),K_inv),Ks)) +::STMT +MATRIX:parsertemp220900,parsertemp220899,Y +LITERAL_FLOAT:300.0,0.0 ++(Y,-(0.0,*(300.0,-(parsertemp220899,parsertemp220900)))) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +-(sum(WM),1.0) +::STMT +FLOAT:res_eee +LITERAL_FLOAT:2.0,0.3 +-(/(res_eee,2.0),0.3) +::STMT +MATRIX:parsertemp24102 +FLOAT:num_bins +LITERAL_FLOAT:1.0 +*(>(+(round(parsertemp24102),1.0),num_bins),num_bins) +::STMT +MATRIX:W +FLOAT:m2 +*(m2,sum(round(W))) +::STMT +MATRIX:2903_mask,dout,2902_W +FLOAT:2903_p +*(/(2903_mask,2903_p),%*%(dout,t(2902_W))) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG ++(r_CG,*(alpha_CG,cast.FLOAT(q_CG))) +::STMT +MATRIX:_sbcvar95,_sbcvar97 +FLOAT:221_my +LITERAL_FLOAT:0.0,2.0 +^(+(%*%(_sbcvar95,_sbcvar97),-(0.0,221_my)),2.0) +::STMT +MATRIX:parsertemp395002,W4_rand,parsertemp395005 +LITERAL_FLOAT:0.08692913816996169 +t(%*%(*(0.08692913816996169,W4_rand),t(/(parsertemp395002,parsertemp395005)))) +::STMT +MATRIX:X,Y,K +-(*(K,-(X,X)),-(Y,Y)) +::STMT +MATRIX:Xd,out +FLOAT:dd,parsertemp467655,wd +/(*(-(+(wd,parsertemp467655),sum(out)),-(+(wd,parsertemp467655),sum(out))),+(dd,sum(Xd))) +::STMT +FLOAT:i +LITERAL_FLOAT:27.0 ++(27.0,i) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +%*%(-(0.0,t(X)),y) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +<(leaf_ids,+(+(boundary_left,step_size),step_size)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,11.0 ++(-(11.0,idx),1.0) +::STMT +LITERAL_FLOAT:1.0,2.0 ++(2.0,1.0) +::STMT +MATRIX:p_gaps_vector +FLOAT:number_nans +/(number_nans,sum(p_gaps_vector)) +::STMT +FLOAT:g,h +/(*(g,g),h) +::STMT +MATRIX:var_X_cols,parsertemp414376,parsertemp414378 +FLOAT:int672 +LITERAL_FLOAT:0.0,1.0,199.0 ++(*(/(-(parsertemp414376,parsertemp414378),199.0),-(1.0,<=(var_X_cols,int672))),<=(/(-(parsertemp414376,parsertemp414378),199.0),0.0)) +::STMT +LITERAL_FLOAT:1.0,6.0,2001.0 +*(*(6.0,2001.0),-(2001.0,1.0)) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:-1.0 +exp(*(*(D,-1.0),beta)) +::STMT +MATRIX:d,exp_Xb,X +rev(*(X,*(%*%(X,d),exp_Xb))) +::STMT +MATRIX:K_inv,parsertemp387408,Ks,Kss +cast.FLOAT(-(Kss,%*%(%*%(parsertemp387408,K_inv),Ks))) +::STMT +MATRIX:present_domain_vals_mat,parsertemp27485 +FLOAT:my +-(%*%(present_domain_vals_mat,parsertemp27485),my) +::STMT +MATRIX:p_CG,z +LITERAL_FLOAT:-1.0 +*(cast.FLOAT(%*%(t(p_CG),z)),-1.0) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +*(2.0,sum(*(parsertemp43626,-1.0))) +::STMT +MATRIX:X_batch,W1_grad +FLOAT:step +*(/(step,nrow(X_batch)),W1_grad) +::STMT +MATRIX:_sbcvar1156,_sbcvar1155 +FLOAT:num_records +LITERAL_FLOAT:1.0 ++(*(_sbcvar1155,_sbcvar1156),*(+(num_records,1.0),-(1.0,_sbcvar1156))) +::STMT +MATRIX:e_r_rev_agg,select,d_r_rev,X_rev_agg +/(*(%*%(select,X_rev_agg),d_r_rev),e_r_rev_agg) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),750.0)) +::STMT +MATRIX:parsertemp220853,Ws,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +<(-(+(parsertemp220853,*(beta,Ws)),logU),0.0) +::STMT +MATRIX:P,D,ZERODIAG +LITERAL_FLOAT:1.0E-12 +/(rowSums(*(*(P,ZERODIAG),D)),+(rowSums(*(P,ZERODIAG)),1.0E-12)) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:50.0,0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(50.0,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:p,G +FLOAT:alpha +*(alpha,%*%(G,p)) +::STMT +MATRIX:q,z +FLOAT:pp,pq,parsertemp285524 +LITERAL_FLOAT:0.5 ++(*(*(0.5,/(parsertemp285524,pp)),pq),sum(*(z,q))) +::STMT +MATRIX:p,q,lambda +sum(*(p,+(q,*(lambda,p)))) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +*(+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))),+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta)))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0) +::STMT +MATRIX:parsertemp16859,77_Y_row_norm,parsertemp16868,X,Y,parsertemp16861 +FLOAT:float904 +/(%*%(X,t(Y)),%*%(+(sqrt(parsertemp16859),*(parsertemp16861,float904)),t(+(77_Y_row_norm,parsertemp16868)))) +::STMT +MATRIX:X_adapted,parsertemp176506 +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +<(X_adapted,-(sqrt(parsertemp176418),*(3.0,+(parsertemp176506,intercept)))) +::STMT +MATRIX:X_adapted,parsertemp176506 +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +>(X_adapted,+(sqrt(parsertemp176418),*(3.0,+(parsertemp176506,intercept)))) +::STMT +MATRIX:parsertemp171600,g_Y,lambda,parsertemp171602,beta +LITERAL_FLOAT:2.0 +^(+(*(cast.FLOAT(parsertemp171602),%*%(parsertemp171600,g_Y)),*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +*(pp_CG,-(cast.FLOAT(%*%(z,z)),trust_delta_sq)) +::STMT +LITERAL_FLOAT:-0.36651292058166435 +-0.36651292058166435 +::STMT +MATRIX:F +LITERAL_FLOAT:2.0 +/(rowSums(F),2.0) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int255 +LITERAL_FLOAT:1999.0,2000.0 +/(/(-(colSums(parsertemp31104),*(int255,parsertemp31106)),1999.0),2000.0) +::STMT +FLOAT:parsertemp5,m2X,parsertemp9,m2Y,covXY +/(covXY,*(sqrt(*(m2X,parsertemp5)),sqrt(*(m2Y,parsertemp9)))) +::STMT +MATRIX:diff_nominal +FLOAT:num_std_median +LITERAL_FLOAT:0.0 +*(!=(diff_nominal,0.0),num_std_median) +::STMT +MATRIX:W1_rand,X,parsertemp399148,parsertemp399138 +FLOAT:float154 +LITERAL_FLOAT:0.08692913816996169 +%*%(*(0.08692913816996169,W1_rand),t(/(-(X,parsertemp399138),+(parsertemp399148,float154)))) +::STMT +MATRIX:maskd1,out1,185_dX,parsertemp146947,W2 +FLOAT:p +LITERAL_FLOAT:0.0 +*(>(out1,0.0),*(/(maskd1,p),%*%(*(parsertemp146947,185_dX),t(W2)))) +::STMT +MATRIX:X +FLOAT:n +LITERAL_FLOAT:2.0 +^(/(t(colSums(X)),n),2.0) +::STMT +MATRIX:LT,Y,parsertemp149320 +sum(*(Y,-(LT,parsertemp149320))) +::STMT +MATRIX:V,X +LITERAL_FLOAT:0.0 +*(V,t(!=(X,0.0))) +::STMT +MATRIX:X,K +LITERAL_FLOAT:-1.0 +*(*(K,-1.0),-(X,X)) +::STMT +MATRIX:W +FLOAT:m4 +LITERAL_FLOAT:1.0,2.0 +*(*(^(sum(W),2.0),+(sum(W),1.0)),m4) +::STMT +MATRIX:parsertemp32006,simplex +LITERAL_FLOAT:2.0 +-(*(2.0,/(-(parsertemp32006,simplex),nrow(simplex))),simplex) +::STMT +MATRIX:resp,mean,X,weight,diff +/(%*%(t(*(diff,resp)),-(X,mean)),cast.FLOAT(weight)) +::STMT +MATRIX:X +/(t(colSums(X)),nrow(X)) +::STMT +MATRIX:parsertemp31023,parsertemp31025,parsertemp31030,parsertemp31032 +LITERAL_FLOAT:149.0,150.0,99.0,100.0 ++(/(/(-(parsertemp31023,parsertemp31025),99.0),100.0),/(/(-(parsertemp31030,parsertemp31032),149.0),150.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,32.0 +&(>=(R,32.0),>(R,0.0)) +::STMT +MATRIX:parsertemp31763,parsertemp31756 +FLOAT:minSup +LITERAL_FLOAT:0.0 +sum(&(>=(t(parsertemp31756),minSup),>(t(parsertemp31763),0.0))) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:1.0,2.0 +^(/(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),1.0),2.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(%*%(t(w),s))) +::STMT +LITERAL_FLOAT:0.2656844656620286 +0.2656844656620286 +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40225,removedE +FLOAT:level +-(+(R,rowSums(*(parsertemp40216,parsertemp40225))),rowSums(*(==(parsertemp40219,level),t(removedE)))) +::STMT +MATRIX:Y_val,parsertemp459795 +FLOAT:int459 +LITERAL_FLOAT:50.0 +/(sum(*(-(int459,Y_val),parsertemp459795)),50.0) +::STMT +MATRIX:majority +LITERAL_FLOAT:0.0,1.0,2.0 +INT:int589,parsertemp282730 +*(>(rand(parsertemp282730,int589,1.0,2.0),0.0),majority) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:delta2 +*(%*%(t(d),d),-(delta2,%*%(t(s),-(s,parsertemp44016)))) +::STMT +MATRIX:parsertemp460691 +FLOAT:lr +*(lr,rowSums(parsertemp460691)) +::STMT +MATRIX:parsertemp171269,Y,linear_terms +FLOAT:int153,int429 +LITERAL_FLOAT:0.0 +-(/(+(Y,==(Y,int429)),+(*(linear_terms,parsertemp171269),==(Y,int153))),==(Y,0.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 +*(==(+(1.0E7,exp(finite_linear_terms)),1.0E7),exp(finite_linear_terms)) +::STMT +MATRIX:CI_l +LITERAL_FLOAT:0.5 +t(<=(CI_l,0.5)) +::STMT +MATRIX:m_iter_err_sum,m_err +-(t(+(colSums(m_err),m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:F +colSums(F) +::STMT +MATRIX:ot2 +FLOAT:int521,Nt +LITERAL_FLOAT:100.0 +/(*(sum(>(ot2,int521)),100.0),Nt) +::STMT +MATRIX:eVals,eVecs +LITERAL_FLOAT:-1.0 +%*%(eVecs,diag(^(eVals,-1.0))) +::STMT +MATRIX:R,3_ss,dsep +FLOAT:3_eAvg +/(/(+(R,dsep),3_ss),3_eAvg) +::STMT +MATRIX:b,X +rev(*(X,exp(%*%(X,b)))) +::STMT +MATRIX:X +FLOAT:p +-(nrow(X),p) +::STMT +MATRIX:indexWithInGroups,parsertemp129475,groupIndex,selectedMatrix ++(-(*(groupIndex,max(parsertemp129475)),max(parsertemp129475)),rowSums(*(indexWithInGroups,selectedMatrix))) +::STMT +MATRIX:Y +-(Y,/(sum(Y),nrow(Y))) +::STMT +LITERAL_FLOAT:5.0,2003.0 ++(2003.0,5.0) +::STMT +FLOAT:i,s_cols +LITERAL_FLOAT:1.0 +*(-(i,1.0),s_cols) +::STMT +MATRIX:parsertemp271438,parsertemp271437 +LITERAL_FLOAT:2.0 +sqrt(sum(^(%*%(parsertemp271437,parsertemp271438),2.0))) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:2.0 +^(2.0,max_depth) +::STMT +LITERAL_FLOAT:1.0,100000.0 +/(100000.0,-(100000.0,1.0)) +::STMT +MATRIX:parsertemp251811 +FLOAT:f +LITERAL_FLOAT:0.0 +==(<=(parsertemp251811,f),0.0) +::STMT +LITERAL_FLOAT:44.75488800120049 +44.75488800120049 +::STMT +MATRIX:H2_prime,H1_prime,W2,parsertemp389612 +t(*(H1_prime,%*%(*(H2_prime,parsertemp389612),W2))) +::STMT +LITERAL_FLOAT:0.001308 +0.001308 +::STMT +MATRIX:img_in +FLOAT:w +LITERAL_FLOAT:2.0 +/(-(ncol(img_in),w),2.0) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,outd1 +FLOAT:int988 +LITERAL_FLOAT:2.0 +^(%*%(t(outd1),*(>(out2,int988),%*%(184_dscores,parsertemp146942))),2.0) +::STMT +MATRIX:z_LS +FLOAT:norm_r2_LS,p_LS ++(z_LS,*(/(norm_r2_LS,*(p_LS,p_LS)),cast.FLOAT(p_LS))) +::STMT +MATRIX:y_val,preds +t(-(y_val,preds)) +::STMT +MATRIX:parsertemp2832 +max(round(parsertemp2832)) +::STMT +MATRIX:parsertemp131906,parsertemp132092,rightHist,outBucket +%*%(==(outBucket,%*%(parsertemp132092,t(parsertemp131906))),rightHist) +::STMT +MATRIX:b_cumulant,Y,natural_parameters +-(*(Y,natural_parameters),b_cumulant) +::STMT +FLOAT:norm_r2,norm_r2_initial +/(norm_r2,norm_r2_initial) +::STMT +MATRIX:leaf_ids,out +FLOAT:boundary_left,step_size ++(out,&(>=(leaf_ids,boundary_left),<(leaf_ids,+(boundary_left,step_size)))) +::STMT +MATRIX:B,X,y +FLOAT:intercept +LITERAL_FLOAT:2.0 +^(-(y,+(%*%(X,B),intercept)),2.0) +::STMT +MATRIX:mean +LITERAL_FLOAT:2.0 +*(2.0,^(mean,2.0)) +::STMT +FLOAT:sv,rad,delta2,s2 +/(-(delta2,s2),+(sv,rad)) +::STMT +MATRIX:classes,X +FLOAT:split ++(-(nrow(X),split),nrow(classes)) +::STMT +MATRIX:parsertemp553014,M2,parsertemp553121,parsertemp553122,missing,parsertemp553008 +LITERAL_FLOAT:2.0 +-(+(%*%(rowSums(parsertemp553008),parsertemp553121),t(%*%(parsertemp553014,parsertemp553122))),*(2.0,%*%(M2,t(missing)))) +::STMT +LITERAL_FLOAT:1.6583123951777 +1.6583123951777 +::STMT +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +*(n,^(/(sum_y_test,n),2.0)) +::STMT +FLOAT:a,b,rad +LITERAL_FLOAT:-1.0,2.0 +/(*(-(b,rad),-1.0),*(2.0,a)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(linear_terms,-(1.0,var_power)) +::STMT +MATRIX:r,s,grad +-(%*%(t(s),grad),%*%(t(s),r)) +::STMT +MATRIX:parsertemp43631,parsertemp43633 +LITERAL_FLOAT:0.0,2.0 +^(+(0.0,*(2.0,%*%(parsertemp43631,parsertemp43633))),2.0) +::STMT +MATRIX:minD,D +colSums(/(<=(D,minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +%*%(+(rowSums(classFeatureCounts),*(500.0,1.0)),ones) +::STMT +MATRIX:b4,parsertemp389330,parsertemp389333,W4 ++(%*%(W4,t(/(parsertemp389330,parsertemp389333))),b4) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +>=(rowSums(M),2.0) +::STMT +MATRIX:F +t(colSums(F)) +::STMT +MATRIX:parsertemp146957,188_dX +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(colSums(*(parsertemp146957,188_dX)),2.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +-(^(sum(round(W)),2.0),1.0) +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,beta +FLOAT:int40,INF +LITERAL_FLOAT:2.0 +*(*(*(2.0,>=(Hdiff,int40)),==(+(parsertemp220863,parsertemp220864),INF)),beta) +::STMT +MATRIX:parsertemp42200,parsertemp42201,_sbcvar330 +FLOAT:meanX +LITERAL_FLOAT:1.0,0.5 +*(/(_sbcvar330,-(sum(_sbcvar330),1.0)),-(+(-(parsertemp42200,parsertemp42201),0.5),meanX)) +::STMT +FLOAT:nFeats +LITERAL_FLOAT:6.283185307179586 +^(6.283185307179586,nFeats) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z,pp_CG +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(pp_CG,-(*(z,z),trust_delta_sq))) +::STMT +MATRIX:45_CVars,45_CFreqs +FLOAT:int43 +LITERAL_FLOAT:1000.0 +/(sum(*(-(45_CFreqs,int43),45_CVars)),-(1000.0,nrow(45_CFreqs))) +::STMT +MATRIX:parsertemp555613,X,Xc,parsertemp555606,parsertemp555615 +LITERAL_FLOAT:1.0 +/(/(%*%(t(Xc),-(X,parsertemp555606)),-(nrow(X),1.0)),%*%(t(sqrt(parsertemp555613)),sqrt(parsertemp555615))) +::STMT +MATRIX:Bxu,Bxd +LITERAL_FLOAT:2.0 +*(2.0,+(Bxd,Bxu)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 ++(0.0,*(lambda,beta)) +::STMT +FLOAT:parsertemp557360,parsertemp557354,parsertemp557356,parsertemp557358,prob_true,prob_false ++(/(*(prob_true,parsertemp557354),parsertemp557356),/(*(prob_false,parsertemp557358),parsertemp557360)) +::STMT +FLOAT:parsertemp557360,parsertemp557354,parsertemp557356,weight,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:-1.0 +*(*(-1.0,weight),+(/(*(prob_true,parsertemp557354),parsertemp557356),/(*(prob_false,parsertemp557358),parsertemp557360))) +::STMT +FLOAT:F1 +LITERAL_FLOAT:2.0 +*(*(F1,2.0),2.0) +::STMT +FLOAT:p,P,Q +LITERAL_FLOAT:1.0 ++(+(+(1.0,p),P),Q) +::STMT +MATRIX:scale_X,X +LITERAL_FLOAT:2.0 +%*%(^(X,2.0),^(scale_X,2.0)) +::STMT +MATRIX:ts +FLOAT:q +-(q,%*%(ts,ts)) +::STMT +FLOAT:s +LITERAL_FLOAT:1.0,4.0 +/(4.0,+(s,1.0)) +::STMT +MATRIX:parsertemp410978,W,H +/(*(H,t(parsertemp410978)),t(colSums(W))) +::STMT +MATRIX:classes +LITERAL_FLOAT:0.30000000000000004 +*(cast.FLOAT(classes),0.30000000000000004) +::STMT +MATRIX:g_reg,p_CG +FLOAT:q_CG,z,int78,pq_CG,pp_CG,parsertemp170107,parsertemp170091 +*(+(+(*(parsertemp170107,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(+(*(z,int78),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44004 +%*%(t(+(s,*(parsertemp44004,d))),+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:fP +FLOAT:max_values +^(ncol(fP),max_values) +::STMT +MATRIX:e,X,tS +FLOAT:l +t(%*%(t(e),==(%*%(X,tS),l))) +::STMT +MATRIX:parsertemp22683,id +-(==(id,t(id)),diag(diag(==(id,parsertemp22683)))) +::STMT +MATRIX:g +FLOAT:lambda,parsertemp169913 +LITERAL_FLOAT:2.0 +*(sum(^(+(g,lambda),2.0)),parsertemp169913) +::STMT +MATRIX:dout,out +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(out,2.0)),dout) +::STMT +MATRIX:V,W,H,parsertemp10749 +LITERAL_FLOAT:1.0E-8 +*(W,/(%*%(V,t(H)),+(%*%(W,parsertemp10749),1.0E-8))) +::STMT +LITERAL_FLOAT:2.0,2003.0 +^(2003.0,2.0) +::STMT +MATRIX:out2,parsertemp146942,184_dscores +FLOAT:int386,beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(*(>(out2,int386),%*%(184_dscores,parsertemp146942)))) +::STMT +MATRIX:X,Centering,ScaleFactor +t(/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:p +LITERAL_FLOAT:1.0E-8 +*(1.0E-8,p) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2702_X +LITERAL_FLOAT:0.0,0.5 +*(>(2702_X,0.0),*(/(2701_mask,0.5),%*%(-(2699_dtemp,parsertemp459178),t(2700_W)))) +::STMT +MATRIX:parsertemp171268,Y,linear_terms,parsertemp171271,vec1 +FLOAT:link_power,int612 +/(-(-(/(parsertemp171268,parsertemp171271),==(Y,int612)),*(*(Y,vec1),linear_terms)),link_power) +::STMT +MATRIX:lambda,shift_X,gXY,parsertemp171602,beta +LITERAL_FLOAT:2.0 +^(+(+(%*%(parsertemp171602,gXY),%*%(shift_X,gXY)),*(lambda,beta)),2.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08333333333333333 +*(0.08333333333333333,W1_rand) +::STMT +MATRIX:P,Z,ZERODIAG,parsertemp220891 +FLOAT:int793 +-(P,/(*(/(int793,parsertemp220891),ZERODIAG),sum(*(Z,ZERODIAG)))) +::STMT +MATRIX:R,S,parsertemp40218 +FLOAT:level +-(R,rowSums(==(%*%(S,parsertemp40218),level))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +sum(>=(rowSums(abs(A)),1.0)) +::STMT +FLOAT:float320,parsertemp169813 +LITERAL_FLOAT:2.302585092994046,4.0 +*(2.302585092994046,-(4.0,round(-(parsertemp169813,float320)))) +::STMT +MATRIX:parsertemp393584,W4_rand,parsertemp393587 +LITERAL_FLOAT:0.08709382882250233 +t(%*%(*(0.08709382882250233,W4_rand),t(/(parsertemp393584,parsertemp393587)))) +::STMT +MATRIX:parsertemp414374,avg_X_cols +FLOAT:int635 +LITERAL_FLOAT:200.0,199.0 +/(-(t(colSums(parsertemp414374)),*(200.0,^(avg_X_cols,int635))),199.0) +::STMT +MATRIX:parsertemp10743,V,H,parsertemp10739 +%*%(V,t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:0.0 +*(-(0.0,D),beta) +::STMT +MATRIX:X_batch,2365_delta2,H1_prime,W2 +%*%(t(*(H1_prime,%*%(2365_delta2,W2))),X_batch) +::STMT +MATRIX:parsertemp409789,parsertemp409798,g0_2,g0_1 +FLOAT:int16 +cast.FLOAT(%*%(t(+(g0_1,g0_2)),+(-(int16,parsertemp409789),t(parsertemp409798)))) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0)) +::STMT +FLOAT:arch_coef,var_coef,a0 +LITERAL_FLOAT:1.0 +/(a0,-(-(1.0,arch_coef),var_coef)) +::STMT +MATRIX:parsertemp220988,parsertemp220989,dY,Y +LITERAL_FLOAT:300.0,0.9 ++(Y,-(*(0.9,dY),*(300.0,-(parsertemp220988,parsertemp220989)))) +::STMT +MATRIX:p,q,r,parsertemp1947 +FLOAT:norm_r2,alpha +LITERAL_FLOAT:-1.0 ++(*(+(r,*(alpha,q)),-1.0),*(/(sum(parsertemp1947),norm_r2),p)) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.9 +*(0.9,upd_W1) +::STMT +LITERAL_FLOAT:1.0E-30 +1.0E-30 +::STMT +MATRIX:p,q,parsertemp503394,Z +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp503394,q))),%*%(Z,p)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:2.0 ++(2.0,n_group_cols) +::STMT +MATRIX:P,Phi,Theta +%*%(%*%(P,Theta),t(Phi)) +::STMT +MATRIX:2697_out,2697_b,parsertemp459149,parsertemp459147 +/(exp(-(+(parsertemp459147,2697_b),parsertemp459149)),rowSums(exp(-(2697_out,parsertemp459149)))) +::STMT +MATRIX:out +FLOAT:dd,step_sz,wd +-(+(wd,*(step_sz,dd)),sum(out)) +::STMT +MATRIX:A,scale_X,shift_X,parsertemp115874,X +t(+(%*%(diag(scale_X),%*%(parsertemp115874,X)),%*%(shift_X,A))) +::STMT +MATRIX:d +cast.FLOAT(%*%(t(d),d)) +::STMT +MATRIX:tmp +FLOAT:norm_r2_LS +/(*(cast.FLOAT(tmp),cast.FLOAT(tmp)),norm_r2_LS) +::STMT +MATRIX:r,s,grad +-(cast.FLOAT(%*%(t(s),grad)),cast.FLOAT(%*%(t(s),r))) +::STMT +FLOAT:o_init,N +LITERAL_FLOAT:-2.0 +exp(/(*(-2.0,o_init),N)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.1092173494617922 +*(0.1092173494617922,W2_rand) +::STMT +LITERAL_FLOAT:1.0,-1.0E30 +INT:int11,M +*(-1.0E30,rand(M,int11,1.0,1.0)) +::STMT +MATRIX:the_exp,linear_terms,Y +FLOAT:int787 +*(*(exp(*(the_exp,int787)),exp(linear_terms)),rowSums(Y)) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:8.660254037844387 +/(8.660254037844387,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:parsertemp500607,w,parsertemp500610 +sum(*(-(*(parsertemp500607,parsertemp500610),w),-(*(parsertemp500607,parsertemp500610),w))) +::STMT +MATRIX:Xtest_dists +FLOAT:int953,eps +LITERAL_FLOAT:1.0 +>=(rowSums(*(<=(Xtest_dists,eps),<(int953,Xtest_dists))),1.0) +::STMT +MATRIX:ZtZ,C,Xm,parsertemp265709,Z,parsertemp265701 +%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(%*%(parsertemp265709,Z),sum(ZtZ)))) +::STMT +MATRIX:s,w +*(w,s) +::STMT +MATRIX:linear_terms +exp(linear_terms) +::STMT +MATRIX:269_Row_norm,parsertemp34343,X_block +LITERAL_FLOAT:0.3 +>(/(%*%(X_block,t(X_block)),%*%(sqrt(parsertemp34343),t(269_Row_norm))),0.3) +::STMT +FLOAT:int874,int128,width,parsertemp387147 +LITERAL_FLOAT:-1.0,2.0 +exp(/(*(-1.0,^(parsertemp387147,int128)),*(2.0,^(width,int874)))) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +sum(==(parsertemp174552,0.0)) +::STMT +FLOAT:s +LITERAL_FLOAT:1.0,5.0 +/(5.0,+(s,1.0)) +::STMT +MATRIX:parsertemp436668,X,parsertemp436672,bc_matrix +LITERAL_FLOAT:2.0 +-(*(bc_matrix,t(rowSums(parsertemp436668))),*(2.0,%*%(X,t(parsertemp436672)))) +::STMT +MATRIX:resp,X,weight +LITERAL_FLOAT:2.0 +/(%*%(t(resp),^(X,2.0)),t(weight)) +::STMT +MATRIX:B,C,D,E,parsertemp462474 +%*%(==(%*%(<=(parsertemp462474,B),C),D),E) +::STMT +MATRIX:X,permut +FLOAT:n +LITERAL_FLOAT:2.0 +/(colSums(^(%*%(permut,X),2.0)),n) +::STMT +MATRIX:parsertemp411208,parsertemp411210,parsertemp411199,X,parsertemp411201,parsertemp411217 +-(sum(%*%(/(parsertemp411208,parsertemp411210),/(parsertemp411199,parsertemp411201))),sum(*(X,parsertemp411217))) +::STMT +MATRIX:C,parsertemp174574 +FLOAT:numRows +/(sum(==(parsertemp174574,C)),numRows) +::STMT +LITERAL_FLOAT:1.0,2003.0 ++(2003.0,1.0) +::STMT +MATRIX:X_orig +FLOAT:parsertemp164950 ++(ncol(X_orig),parsertemp164950) +::STMT +MATRIX:parsertemp196005 +FLOAT:parsertemp191170,Wf +LITERAL_FLOAT:2.0 +*(parsertemp196005,sqrt(/(2.0,*(parsertemp191170,Wf)))) +::STMT +MATRIX:tmp,X +FLOAT:x +*(cast.FLOAT(tmp),/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:W1_rand,parsertemp396312,X,parsertemp396302 +FLOAT:float297 +LITERAL_FLOAT:0.07808688094430302 +%*%(*(0.07808688094430302,W1_rand),t(/(-(X,parsertemp396302),+(parsertemp396312,float297)))) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +*(cast.FLOAT(lambda),^(cast.FLOAT(newbeta),2.0)) +::STMT +MATRIX:parsertemp393583,W4_rand +FLOAT:int268,int639 +LITERAL_FLOAT:0.08709382882250233 +%*%(*(0.08709382882250233,W4_rand),t(/(-(parsertemp393583,int639),+(parsertemp393583,int268)))) +::STMT +MATRIX:Nc +==(Nc,max(Nc)) +::STMT +MATRIX:parsertemp31030,parsertemp31032 +LITERAL_FLOAT:149.0,2.0,3352500.0 +/(^(/(-(parsertemp31030,parsertemp31032),149.0),2.0),3352500.0) +::STMT +MATRIX:parsertemp472404 +FLOAT:max_features,n +<=(parsertemp472404,/(^(n,max_features),n)) +::STMT +MATRIX:77_Y_row_norm,parsertemp16864 +FLOAT:float693 +LITERAL_FLOAT:1.0E-6 +t(+(sqrt(rowSums(parsertemp16864)),*(<(77_Y_row_norm,float693),1.0E-6))) +::STMT +MATRIX:g_reg,q_CG,p_CG,z +FLOAT:float720,277_tau_1,pq_CG ++(+(*(*(float720,277_tau_1),pq_CG),*(cast.FLOAT(z),cast.FLOAT(q_CG))),sum(*(g_reg,p_CG))) +::STMT +MATRIX:X +X +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z +sqrt(-(*(cast.FLOAT(p_CG),cast.FLOAT(p_CG)),*(cast.FLOAT(p_CG),-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp389580,parsertemp389560,2365_delta3,W3 +FLOAT:int629 +LITERAL_FLOAT:1.0 +%*%(t(*(-(int629,parsertemp389580),%*%(2365_delta3,W3))),/(-(exp(parsertemp389560),1.0),+(exp(parsertemp389560),1.0))) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int652,float185 +LITERAL_FLOAT:0.6666666666666666 +min(^(/(-(int652,parsertemp410245),*(float185,parsertemp410248)),0.6666666666666666)) +::STMT +MATRIX:Q,R,parsertemp500360,parsertemp500308,parsertemp500359,parsertemp500300 +LITERAL_FLOAT:2.0 +-(+(%*%(rowSums(parsertemp500300),parsertemp500359),%*%(parsertemp500360,t(parsertemp500308))),*(2.0,%*%(R,t(Q)))) +::STMT +MATRIX:y_train,prediction +LITERAL_FLOAT:0.5 +==(y_train,>(prediction,0.5)) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0,0.5 +*(0.5,*(cast.FLOAT(lambda),^(cast.FLOAT(newbeta),2.0))) +::STMT +MATRIX:R +FLOAT:minSup +>=(R,minSup) +::STMT +MATRIX:w,ssX_p_CG,X +cast.FLOAT(%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +FLOAT:ratio +LITERAL_FLOAT:15.0 +*(15.0,ratio) +::STMT +MATRIX:G,minDist +FLOAT:int625 +LITERAL_FLOAT:-1.0 +^(+(G,*(!=(G,int625),minDist)),-1.0) +::STMT +LITERAL_FLOAT:3.0,2003.0 ++(2003.0,3.0) +::STMT +MATRIX:w,g +FLOAT:alpha,tau +-(abs(-(w,/(g,alpha))),/(tau,alpha)) +::STMT +MATRIX:parsertemp72334 +FLOAT:rows +cast.FLOAT(/(colSums(rowSums(parsertemp72334)),rows)) +::STMT +FLOAT:new_log_l,saturated_log_l +LITERAL_FLOAT:2.0 +*(2.0,-(saturated_log_l,new_log_l)) +::STMT +MATRIX:n_risk,n_event +/(n_event,*(n_risk,-(n_risk,n_event))) +::STMT +MATRIX:parsertemp283570,tpr,fpr,parsertemp283568 +LITERAL_FLOAT:2.0 ++(*(cast.FLOAT(tpr),cast.FLOAT(fpr)),sum(/(*(parsertemp283568,parsertemp283570),2.0))) +::STMT +MATRIX:X +LITERAL_FLOAT:-2.0,2.0 ++(*(-2.0,%*%(X,t(X))),rowSums(^(X,2.0))) +::STMT +MATRIX:prec,X,mu +*(-(%*%(X,prec),%*%(mu,prec)),-(%*%(X,prec),%*%(mu,prec))) +::STMT +MATRIX:mean,X,weight +-(%*%(t(X),X),%*%(*(t(mean),weight),mean)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 ++(*(3.0,-(i,1.0)),1.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,sum(X)) +::STMT +FLOAT:parsertemp85,int24,wt,parsertemp90 +LITERAL_FLOAT:1.0,2.0,4.0 +*(*(4.0,-(^(wt,int24),1.0)),^(sqrt(/(parsertemp85,parsertemp90)),2.0)) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:-1.0 +*(+(g,*(lambda,beta)),-1.0) +::STMT +MATRIX:ubScores,fSizes +FLOAT:minsc +LITERAL_FLOAT:0.0 +&(fSizes,&(>(ubScores,minsc),>(ubScores,0.0))) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2001.0 +*(4.0,-(^(2001.0,2.0),1.0)) +::STMT +LITERAL_FLOAT:-1.0 +INT:int571,n +diag(rand(n,int571,-1.0,-1.0)) +::STMT +LITERAL_FLOAT:1.0 +INT:int269,n +diag(rand(n,int269,1.0,1.0)) +::STMT +MATRIX:parsertemp71758,is_row_in_samples +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 +-(+(*(sample_block_size,3.0),1.0),*(is_row_in_samples,parsertemp71758)) +::STMT +MATRIX:scale_X,w,parsertemp170066,X +*(cast.FLOAT(diag(scale_X)),cast.FLOAT(%*%(t(X),*(w,parsertemp170066)))) +::STMT +FLOAT:s +LITERAL_FLOAT:3.0 +^(3.0,s) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379572,parsertemp379570,parsertemp379563 +FLOAT:i_process_item +LITERAL_FLOAT:1.0 +sqrt(/(+(-(parsertemp379570,parsertemp379572),+(parsertemp379563,m_iter_err_sum_squared)),-(i_process_item,1.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 ++(*(3.0,-(i,1.0)),3.0) +::STMT +MATRIX:X,Y +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),Y) +::STMT +MATRIX:w,parsertemp500601 +FLOAT:alpha,tau +LITERAL_FLOAT:0.0 +>(-(abs(-(w,parsertemp500601)),/(tau,alpha)),0.0) +::STMT +MATRIX:parsertemp131967 +*(ncol(parsertemp131967),nrow(parsertemp131967)) +::STMT +MATRIX:parsertemp265718,parsertemp265715 +FLOAT:Xm +LITERAL_FLOAT:2.0,4000.0 +/(-(+(Xm,trace(parsertemp265715)),*(2.0,cast.FLOAT(parsertemp265718))),4000.0) +::STMT +MATRIX:m_iter_err_sum_squared,m_err +LITERAL_FLOAT:2.0 ++(colSums(^(m_err,2.0)),m_iter_err_sum_squared) +::STMT +MATRIX:obj,gs,parsertemp44066,parsertemp44078 +FLOAT:parsertemp44082 +LITERAL_FLOAT:-0.5 +cast.FLOAT(/(-(obj,+(parsertemp44078,parsertemp44082)),*(-0.5,-(gs,parsertemp44066)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:parsertemp552530 +LITERAL_FLOAT:0.0 +INT:parsertemp552529,idx ++(rand(parsertemp552529,idx,0.0,0.0),t(parsertemp552530)) +::STMT +MATRIX:Q,ssX_V,X,parsertemp150463,P_1K +-(*(P_1K,%*%(X,ssX_V)),*(P_1K,%*%(rowSums(Q),parsertemp150463))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0 ++(*(-(i,1.0),61.0),61.0) +::STMT +MATRIX:prob,test_Y +FLOAT:threshold +LITERAL_FLOAT:0.0 +*(test_Y,==(>(prob,threshold),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0,34.0 ++(*(-(i,1.0),61.0),34.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +!=(rowSums(!=(X,0.0)),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,2.0,3.0 ++(*(3.0,-(i,1.0)),2.0) +::STMT +MATRIX:ss +LITERAL_FLOAT:0.050000000000000044,1.0,20.0 +*(0.050000000000000044,-(/(20.0,ss),1.0)) +::STMT +MATRIX:subspace_idx,parsertemp109953 +LITERAL_FLOAT:1.0,42.0 +<(-(subspace_idx,*(parsertemp109953,42.0)),1.0) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),min(round(parsertemp2832)))) +::STMT +MATRIX:parsertemp24100 +FLOAT:bin_width +LITERAL_FLOAT:1.0,0.5 ++(round(-(/(parsertemp24100,bin_width),0.5)),1.0) +::STMT +MATRIX:p,Z +cast.FLOAT(%*%(t(p),%*%(Z,p))) +::STMT +MATRIX:By2,By1 +LITERAL_FLOAT:3.0 +*(3.0,+(By1,By2)) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0,100000.0 +*(m2X,/(100000.0,-(100000.0,1.0))) +::STMT +MATRIX:C,parsertemp11014 +LITERAL_FLOAT:1000.0,100.0 +*(/(sum(==(parsertemp11014,C)),1000.0),100.0) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),max(round(parsertemp2832)))) +::STMT +LITERAL_FLOAT:3.5355339059327378 +3.5355339059327378 +::STMT +FLOAT:dist ++(cast.MATRIX(dist),t(cast.MATRIX(dist))) +::STMT +MATRIX:parsertemp409054,ctab +FLOAT:threshold +>(/(parsertemp409054,rowSums(ctab)),threshold) +::STMT +MATRIX:_sbcvar95,_sbcvar97 +FLOAT:221_my +LITERAL_FLOAT:0.0 ++(%*%(_sbcvar95,_sbcvar97),-(0.0,221_my)) +::STMT +LITERAL_FLOAT:0.0 +INT:int373,int452,int579,int618 +%*%(t(rand(int373,int618,0.0,0.0)),rand(int452,int579,0.0,0.0)) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +&(!(<(leaf_ids,+(boundary_left,step_size))),<(leaf_ids,+(+(boundary_left,step_size),step_size))) +::STMT +MATRIX:parsertemp163357 +LITERAL_FLOAT:1.0 +t(/(1.0,parsertemp163357)) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0 +/(1.0,t(ss)) +::STMT +MATRIX:parsertemp149867,Y +FLOAT:int506 +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp149867,int506)),nrow(Y)),100.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.06835859270246632 +*(0.06835859270246632,W1_rand) +::STMT +MATRIX:221_CFreqs,_sbcvar95,_sbcvar98 +FLOAT:int359 +LITERAL_FLOAT:1000.0 +/(sum(*(+(221_CFreqs,int359),%*%(_sbcvar95,_sbcvar98))),-(1000.0,nrow(_sbcvar95))) +::STMT +MATRIX:sv,Y,Xd,out +sum(*(*(*(out,sv),Y),Xd)) +::STMT +MATRIX:w,X,y +t(-(%*%(X,w),y)) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.0,0.16 +==(<(abs(-(output1,dataset)),0.16),0.0) +::STMT +MATRIX:w,out +FLOAT:lambda +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,sum(*(out,out))),*(/(lambda,2.0),sum(*(w,w)))) +::STMT +MATRIX:parsertemp447299 +LITERAL_FLOAT:1.0 +t(-(parsertemp447299,1.0)) +::STMT +MATRIX:w,X,y +%*%(t(X),-(%*%(X,w),y)) +::STMT +MATRIX:parsertemp170251,lt_pos_neg +FLOAT:int508 +LITERAL_FLOAT:2.0,0.5 +*(-(0.5,lt_pos_neg),exp(/(-(int508,parsertemp170251),2.0))) +::STMT +MATRIX:g_new,g_old +/(sum(*(g_new,g_new)),sum(*(g_old,g_old))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,1.0 +==(<(X,1.0),0.0) +::STMT +MATRIX:sv,Xd +FLOAT:dd ++(dd,sum(*(*(Xd,sv),Xd))) +::STMT +MATRIX:parsertemp115729,parsertemp115724 +FLOAT:eAvg,n2 +LITERAL_FLOAT:0.050000000000000044,1.0,0.95 +-(*(0.95,-(/(parsertemp115724,eAvg),1.0)),*(0.050000000000000044,-(*(parsertemp115729,n2),1.0))) +::STMT +MATRIX:id +diag(==(id,t(id))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +FLOAT:int772,parsertemp222668,int577 +min(+(*(parsertemp222665,termination_bitmap),*(+(parsertemp222668,int772),-(int577,termination_bitmap)))) +::STMT +MATRIX:E,X +LITERAL_FLOAT:-1.0 +*(t(colSums(*(X,E))),-1.0) +::STMT +MATRIX:parsertemp150393 +LITERAL_FLOAT:0.0,0.1 +sum(==(<(abs(parsertemp150393),0.1),0.0)) +::STMT +MATRIX:means,parsertemp560511 +LITERAL_FLOAT:2.0 +^(rowSums(*(means,parsertemp560511)),2.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(exp(finite_linear_terms),2.0)) +::STMT +MATRIX:scale_lambda,parsertemp150455 +LITERAL_FLOAT:1.0E-5 +*(%*%(scale_lambda,parsertemp150455),1.0E-5) +::STMT +MATRIX:X +FLOAT:index +LITERAL_FLOAT:1.0,2.0 ++(*(index,-(ncol(X),1.0)),2.0) +::STMT +MATRIX:p_gaps_vector +LITERAL_FLOAT:0.0 +t(>(p_gaps_vector,0.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +-(ncol(X),2.0) +::STMT +MATRIX:F +FLOAT:q +LITERAL_FLOAT:1.0 +*(sum(F),-(q,1.0)) +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:2.0 +colSums(^(-(X,Centering),2.0)) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379560,m_err_mean,m_iter_err_sum,m_err +FLOAT:int71,int123,i_process_item,int826 ++(-(*(^(m_err_mean,int123),i_process_item),*(*(int826,m_err_mean),+(parsertemp379560,m_iter_err_sum))),+(colSums(^(m_err,int71)),m_iter_err_sum_squared)) +::STMT +MATRIX:d,alpha +*(cast.FLOAT(alpha),d) +::STMT +MATRIX:prevTK2,totalE,X2 +*(==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2))),totalE) +::STMT +FLOAT:parsertemp166531 +LITERAL_FLOAT:10.0 +*(10.0,parsertemp166531) +::STMT +MATRIX:parsertemp170136 +FLOAT:278_sq_root_d,parsertemp170150,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(+(parsertemp170150,278_sq_root_d),sum(parsertemp170136))),pq_CG) +::STMT +MATRIX:FXY +LITERAL_FLOAT:1.0 +-(ncol(FXY),1.0) +::STMT +MATRIX:G,authorities +/(%*%(t(G),%*%(G,authorities)),max(%*%(t(G),%*%(G,authorities)))) +::STMT +MATRIX:shift_X,ssX_newbeta,z,beta ++(ssX_newbeta,%*%(t(shift_X),+(beta,z))) +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar98 +LITERAL_FLOAT:-1.0 +*(+(%*%(_sbcvar95,_sbcvar96),-1.0),%*%(_sbcvar95,_sbcvar98)) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(%*%(t(V),y),-1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:64.0 +-(64.0,idx) +::STMT +FLOAT:e,mu,epochs +LITERAL_FLOAT:0.999,1.0 +/(-(0.999,mu),-(+(1.0,epochs),e)) +::STMT +LITERAL_FLOAT:1.0E-6 +INT:int362,int452 +rand(int452,int362,1.0E-6,1.0E-6) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +abs(/(parsertemp22485,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +INT:int411,int621 +rand(int411,int621,*(2.0,*(-1.0,sum(parsertemp43626))),*(2.0,*(-1.0,sum(parsertemp43626)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(^(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +FLOAT:int1,parsertemp86,int43,parsertemp87,int280,wt +sqrt(/(*(*(int280,wt),-(wt,int1)),*(*(parsertemp86,parsertemp87),+(wt,int43)))) +::STMT +MATRIX:classes +FLOAT:split ++(split,nrow(classes)) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),*(%*%(t(V),y),-1.0)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:3.0 ++(3.0,n_group_cols) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:2.29128784747792 +/(2.29128784747792,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:means,Y_counts,Y,parsertemp560603 +FLOAT:parsertemp560604 +LITERAL_FLOAT:2.0 +^(-(-(Y,means),%*%(Y_counts,/(parsertemp560603,parsertemp560604))),2.0) +::STMT +MATRIX:2883_ctab +FLOAT:int703 +LITERAL_FLOAT:1.0 +sum(==(rowSums(!=(2883_ctab,int703)),1.0)) +::STMT +MATRIX:g_new,parsertemp468777,tmp,g_old +/(cast.FLOAT(%*%(t(g_new),-(parsertemp468777,tmp))),cast.FLOAT(%*%(t(g_old),g_old))) +::STMT +FLOAT:norm_r2,norm_r2_initial +sqrt(/(norm_r2,norm_r2_initial)) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +-(parsertemp185166,min(Y)) +::STMT +MATRIX:X,parsertemp386475 +FLOAT:int965 +sqrt(+(+(*(int965,parsertemp386475),X),t(X))) +::STMT +MATRIX:2701_mask,doutd3 +LITERAL_FLOAT:0.5 +*(/(2701_mask,0.5),doutd3) +::STMT +MATRIX:svUpBnd,R,svLowBnd +*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd))) +::STMT +MATRIX:P12,map +LITERAL_FLOAT:0.0 +rowSums(!=(%*%(map,P12),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,51.0,64.0 ++(*(-(i,1.0),64.0),51.0) +::STMT +MATRIX:parsertemp1531,y +FLOAT:int824 +LITERAL_FLOAT:2.0 +sum(^(%*%(-(int824,parsertemp1531),y),2.0)) +::STMT +FLOAT:K +LITERAL_FLOAT:11.0 +*(11.0,K) +::STMT +FLOAT:C,K +LITERAL_FLOAT:1.0,2.0 +*(*(C,+(C,1.0)),^(K,2.0)) +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:2.0,0.5 +*(0.5,rowSums(^(-(prediction,target),2.0))) +::STMT +MATRIX:os,y,o +FLOAT:int829 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(*(y,int829),+(o,os)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005 +sqrt(*(1.0005,m2)) +::STMT +MATRIX:lambda,scale_X,p_CG,w,parsertemp170066,X ++(*(lambda,p_CG),*(cast.FLOAT(diag(scale_X)),%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp382670,X +LITERAL_FLOAT:0.0,2.0 +sum(*(!=(X,0.0),^(-(parsertemp382670,X),2.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,60.0,64.0 ++(*(-(i,1.0),64.0),60.0) +::STMT +FLOAT:C,Hf,Wf +LITERAL_FLOAT:2.0 +/(2.0,*(*(C,Hf),Wf)) +::STMT +MATRIX:linear_terms,Y +FLOAT:parsertemp171226,link_power,parsertemp171223,int493 +/(*(^(linear_terms,-(parsertemp171226,int493)),-(Y,^(linear_terms,parsertemp171223))),link_power) +::STMT +FLOAT:int276,z,pp_CG,parsertemp170091 +LITERAL_FLOAT:0.5 +*(0.5,/(+(*(z,int276),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:r,c,F +LITERAL_FLOAT:2.0 +^(-(F,/(%*%(r,c),sum(F))),2.0) +::STMT +FLOAT:float658,float239,float677,float221 +LITERAL_FLOAT:2.0 +INT:int110,int752,int269,int936 ++(sum(^(rand(int752,int936,float221,float239),2.0)),sum(^(rand(int269,int110,float677,float658),2.0))) +::STMT +MATRIX:R,B,parsertemp503780 +%*%(t(+(R,diag(parsertemp503780))),B) +::STMT +LITERAL_FLOAT:1.0,20.0 ++(20.0,1.0) +::STMT +MATRIX:X,mu,parsertemp183827,ScaleFactor +FLOAT:int264,N +LITERAL_FLOAT:1.0 +-(/(%*%(t(X),/(X,ScaleFactor)),-(N,1.0)),*(/(N,-(N,int264)),%*%(t(mu),/(parsertemp183827,N)))) +::STMT +LITERAL_FLOAT:1.0,7000.0 +-(7000.0,1.0) +::STMT +MATRIX:knn_index +FLOAT:s +LITERAL_FLOAT:100.0 +*(/(s,100.0),ncol(knn_index)) +::STMT +FLOAT:p,P,Q,q,int89 ++(+(+(+(int89,p),P),Q),q) +::STMT +FLOAT:2344_s_err_vars,2344_s_err_mean +LITERAL_FLOAT:-1.0,0.001 +/(-(*(0.001,-1.0),2344_s_err_mean),2344_s_err_vars) +::STMT +MATRIX:Y +FLOAT:class +LITERAL_FLOAT:1.0,2.0 +-(*(2.0,==(Y,class)),1.0) +::STMT +FLOAT:int520,int776,parsertemp459331,Win +LITERAL_FLOAT:2.0,64.0 +/(2.0,*(*(64.0,/(parsertemp459331,int776)),/(/(Win,int520),2.0))) +::STMT +MATRIX:W1_rand,stds,parsertemp400568 +LITERAL_FLOAT:0.08333333333333333 +t(%*%(*(0.08333333333333333,W1_rand),t(/(parsertemp400568,stds)))) +::STMT +MATRIX:p_CG +FLOAT:int158,parsertemp254749,z,parsertemp254772,int517 +*(parsertemp254772,/(-(*(z,int158),sqrt(parsertemp254749)),sum(^(p_CG,int517)))) +::STMT +MATRIX:ytest +FLOAT:mean_y_test,int293 +LITERAL_FLOAT:0.0,1.0,2.0 +/(-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int293))),0.0) +::STMT +MATRIX:X2 +FLOAT:minSup +>=(t(colSums(X2)),minSup) +::STMT +MATRIX:B,S +LITERAL_FLOAT:2.0 +^(+(B,S),2.0) +::STMT +MATRIX:parsertemp31105,parsertemp31107 +FLOAT:int559,int592 +LITERAL_FLOAT:1.0,2.0,2000.0 +/(^(/(-(parsertemp31105,parsertemp31107),-(int559,int592)),2.0),*(^(2000.0,2.0),-(2000.0,1.0))) +::STMT +MATRIX:D,parsertemp570375,classMeans +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,2.0),%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans)))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +-(/(0.0,link_power),1.0) +::STMT +FLOAT:parsertemp496694,int349,parsertemp496686,n,a0 +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,*(2.0,n)),+(parsertemp496694,/(^(parsertemp496686,int349),a0))) +::STMT +MATRIX:yhat +FLOAT:ytest,int615 +LITERAL_FLOAT:1.0,2.0 +-(^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0),*(1.0,^(/(ytest,int615),2.0))) +::STMT +MATRIX:id +diag(==(id,cast.FLOAT(id))) +::STMT +MATRIX:parsertemp456742,X,y +LITERAL_FLOAT:0.0 +%*%(t(-(0.0,%*%(parsertemp456742,y))),%*%(t(X),y)) +::STMT +MATRIX:parsertemp410081,d_r_rev,parsertemp410090 +FLOAT:o +LITERAL_FLOAT:-1.0 +-(+(*(cast.FLOAT(parsertemp410081),-1.0),cast.FLOAT(%*%(d_r_rev,parsertemp410090))),o) +::STMT +MATRIX:parsertemp570396,classVars +*(diag(parsertemp570396),max(classVars)) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 +/(1.0,<(-(subspace_idx,*(parsertemp72201,subvector_size)),1.0)) +::STMT +MATRIX:252_X,252_K +*(cast.FLOAT(252_K),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +sum(*(is_natural_parameter_log_zero,abs(Y))) +::STMT +MATRIX:X_Train,X_Test ++(sum(X_Train),sum(X_Test)) +::STMT +MATRIX:G,authorities,hubs +LITERAL_FLOAT:2.0 +^(-(/(%*%(G,authorities),max(hubs)),hubs),2.0) +::STMT +FLOAT:parsertemp115827,sum_sq_y_test,n +LITERAL_FLOAT:1.0 +sqrt(/(-(sum_sq_y_test,*(n,parsertemp115827)),-(n,1.0))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:2.0 +-(/(2.0,link_power),2.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0,2.0 +-(/(0.0,link_power),2.0) +::STMT +MATRIX:images +LITERAL_FLOAT:2.0,255.0 +*(/(images,255.0),2.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:1.0 +*(1.0,cast.FLOAT(%*%(t(w),s))) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.000010000100001 +*(m2X,1.000010000100001) +::STMT +FLOAT:check_max,check_min +/(+(check_min,check_max),-(check_max,check_min)) +::STMT +MATRIX:_sbcvar14,_sbcvar13 +FLOAT:int143,parsertemp13703,int127 +LITERAL_FLOAT:999.0 +/(sum(*(-(_sbcvar13,int143),_sbcvar14)),*(999.0,/(*(parsertemp13703,int127),999.0))) +::STMT +FLOAT:wcss +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,wcss) +::STMT +MATRIX:parsertemp31763,parsertemp31756,parsertemp31758,maxsc +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(&(>=(t(parsertemp31756),minSup),>(t(parsertemp31763),0.0)),|(>(t(parsertemp31758),0.0),>(maxsc,0.0))) +::STMT +MATRIX:r,parsertemp44050 +sqrt(sum(*(-(r,parsertemp44050),-(r,parsertemp44050)))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1 ++(deviance_nodisp,0.1) +::STMT +MATRIX:y +FLOAT:n +LITERAL_FLOAT:2.0 +/(sum(^(y,2.0)),n) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int467,m +sum(abs(rand(m,int467,0.0,1.0))) +::STMT +MATRIX:parsertemp436667,precisions +LITERAL_FLOAT:1.0 +INT:parsertemp436666,int896 +*(rand(int896,parsertemp436666,1.0,1.0),t(rowSums(*(parsertemp436667,precisions)))) +::STMT +MATRIX:p,q,lambda +*(p,+(q,*(lambda,p))) +::STMT +MATRIX:svUpBnd,R,svLowBnd +diag(*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd)))) +::STMT +MATRIX:lambda,B,Grad +LITERAL_FLOAT:2.0 +sum(^(+(Grad,*(lambda,B)),2.0)) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08720414403938946 +*(0.08720414403938946,W4_rand) +::STMT +MATRIX:parsertemp415351,parsertemp415356 +FLOAT:parsertemp415362,parsertemp415358,n +LITERAL_FLOAT:1.0 +-(1.0,/(-(sum(parsertemp415356),*(n,parsertemp415358)),-(sum(parsertemp415351),*(n,parsertemp415362)))) +::STMT +MATRIX:y_residual,ytest +FLOAT:int275,avg_res,mean_y_test,int699,int768,int838 +/(-(sum(^(y_residual,int838)),*($1:nrow(ytest),^(avg_res,int275))),-(sum(^(ytest,int768)),*($1,^(mean_y_test,int699)))) +::STMT +MATRIX:grad +FLOAT:int211 +LITERAL_FLOAT:2.0 +sqrt(sum(^(-(int211,grad),2.0))) +::STMT +MATRIX:_sbcvar92,parsertemp27721,220_r,220_c,220_E +FLOAT:int757 +LITERAL_FLOAT:2.0,1.0E-4 +/(^(-(_sbcvar92,+(parsertemp27721,220_E)),2.0),+(*(==(220_E,int757),1.0E-4),/(%*%(220_r,220_c),sum(_sbcvar92)))) +::STMT +MATRIX:s +LITERAL_FLOAT:1.0 +/(1.0,s) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +/(-1.0,linear_terms) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)))) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +sum(*(-(r,*(parsertemp44049,Hd)),-(r,*(parsertemp44049,Hd)))) +::STMT +MATRIX:parsertemp222310 +FLOAT:parsertemp222313 +LITERAL_FLOAT:0.5 ++(/(parsertemp222310,parsertemp222313),0.5) +::STMT +MATRIX:parsertemp414372,X +FLOAT:int923,int309 +LITERAL_FLOAT:200.0,2.0 +-(t(colSums(^(X,int923))),*(200.0,^(/(parsertemp414372,int309),2.0))) +::STMT +FLOAT:k +LITERAL_FLOAT:1.0,4.0 +-(+(k,4.0),1.0) +::STMT +FLOAT:parsertemp477829,parsertemp477814,2814_K,2814_X,2814_Y,inp_x +LITERAL_FLOAT:1.0 ++(*(-(*(2814_K,2814_X),-(2814_Y,2814_Y)),-(1.0,/(parsertemp477814,2814_X))),*(+(*(parsertemp477829,2814_X),-(2814_Y,2814_Y)),/(-(inp_x,2814_X),-(2814_X,2814_X)))) +::STMT +FLOAT:output_values,log_odds,float34 +LITERAL_FLOAT:1.0,2.7182818284 ++(1.0,^(2.7182818284,+(log_odds,*(float34,output_values)))) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0 +*(24.0,run_index) +::STMT +MATRIX:p,parsertemp1934,parsertemp1935 +FLOAT:eps +cast.FLOAT(%*%(t(p),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:parsertemp43620,parsertemp43619 +FLOAT:float10 +LITERAL_FLOAT:1.0 +*(/(1.0,+(1.0,exp(parsertemp43619))),-(1.0,/(1.0,+(float10,parsertemp43620)))) +::STMT +MATRIX:X +FLOAT:N +/(colSums(X),N) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int627 +LITERAL_FLOAT:1.0,2.0,100.0 +^(/(-(colSums(parsertemp31022),*(int627,parsertemp31024)),-(100.0,1.0)),2.0) +::STMT +MATRIX:finite_linear_terms +FLOAT:int949 +LITERAL_FLOAT:0.0,2.0 +exp(/(-(0.0,^(finite_linear_terms,int949)),2.0)) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:-1.0 +*(*(y,-1.0),+(o,os)) +::STMT +MATRIX:g +LITERAL_FLOAT:2.0,0.01 +*(0.01,sum(^(g,2.0))) +::STMT +MATRIX:Y,parsertemp171319 +FLOAT:one_over_sqrt_two_pi,float696 +LITERAL_FLOAT:2.0 +*(*(exp(/(parsertemp171319,float696)),^(one_over_sqrt_two_pi,2.0)),rowSums(Y)) +::STMT +MATRIX:negSampleMeans +LITERAL_FLOAT:2.0,1500.0 +*(1500.0,^(negSampleMeans,2.0)) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046,0.5 +round(-(/(parsertemp169812,2.302585092994046),0.5)) +::STMT +MATRIX:P,X,Y +%*%(t(X),-(P,Y)) +::STMT +MATRIX:parsertemp285516 +FLOAT:pp,parsertemp285518,parsertemp285520 +LITERAL_FLOAT:-1.0 +/(+(*(sum(parsertemp285516),-1.0),sqrt(-(parsertemp285518,parsertemp285520))),pp) +::STMT +LITERAL_FLOAT:0.08692913816996169 +0.08692913816996169 +::STMT +MATRIX:WM,Y,CMeans +-(CMeans,/(sum(*(Y,WM)),sum(WM))) +::STMT +MATRIX:colSD,colMean +LITERAL_FLOAT:3.0 ++(colMean,*(3.0,colSD)) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2 +LITERAL_FLOAT:1.0E-8 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:subspace_idx,parsertemp109953 +LITERAL_FLOAT:42.0 +-(subspace_idx,*(parsertemp109953,42.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),-(1.0,var_power)),exp(linear_terms)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.001 +*(scale_lambda,0.001) +::STMT +MATRIX:X +FLOAT:a0 +LITERAL_FLOAT:2.0 +/(^(cast.FLOAT(X),2.0),a0) +::STMT +MATRIX:parsertemp13711,_sbcvar14 +FLOAT:parsertemp13704,float583 +LITERAL_FLOAT:1.0,999.0 +-(1.0,/(sum(*(parsertemp13711,_sbcvar14)),*(999.0,/(parsertemp13704,float583)))) +::STMT +MATRIX:P,X,Y +LITERAL_FLOAT:2.0 +sum(^(%*%(t(X),-(P,Y)),2.0)) +::STMT +MATRIX:Y_counts,vars +FLOAT:dispersion +/(*(dispersion,colSums(vars)),sum(Y_counts)) +::STMT +MATRIX:termination_bitmap,parsertemp222665,parsertemp222670 +FLOAT:parsertemp222669 +==(*(parsertemp222665,termination_bitmap),min(+(*(parsertemp222665,termination_bitmap),*(parsertemp222669,parsertemp222670)))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:int571,parsertemp31034,int594,int55,parsertemp31027,int98,int225,int812,int171,int19 +LITERAL_FLOAT:2.0 ++(/(^(/(parsertemp31026,parsertemp31027),2.0),*(^(int571,int98),-(int225,int171))),/(^(/(parsertemp31033,parsertemp31034),2.0),*(^(int19,int594),-(int55,int812)))) +::STMT +MATRIX:X_train +LITERAL_FLOAT:256.0 +/(nrow(X_train),256.0) +::STMT +MATRIX:r,scale_X,shift_X ++(*(scale_X,r),*(cast.FLOAT(r),shift_X)) +::STMT +MATRIX:y_hat,b,R +-(-(b,%*%(R,y_hat)),y_hat) +::STMT +MATRIX:b,H,parsertemp410187,parsertemp410189 +%*%(%*%(t(b),-(+(H,parsertemp410187),diag(parsertemp410189))),b) +::STMT +MATRIX:U,V +FLOAT:int540,int757 +LITERAL_FLOAT:5.0E-7 +*(5.0E-7,+(sum(^(U,int540)),sum(^(V,int757)))) +::STMT +MATRIX:subspace_idx,parsertemp75105 +LITERAL_FLOAT:32.0 +-(subspace_idx,*(parsertemp75105,32.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(*(t(colSums(X)),-1.0),-1.0) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int865,float295 +LITERAL_FLOAT:0.6666666666666666 +max(^(/(-(int865,parsertemp410245),*(float295,parsertemp410248)),0.6666666666666666)) +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936 +colSums(-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp)))) +::STMT +MATRIX:P +/(+(P,t(P)),sum(+(P,t(P)))) +::STMT +MATRIX:parsertemp265709,tmp,Z,XtZ +FLOAT:ZtZ_sum +*(tmp,%*%(t(/(XtZ,ZtZ_sum)),/(%*%(parsertemp265709,Z),sum(tmp)))) +::STMT +MATRIX:test_val +LITERAL_FLOAT:128.0 +/(nrow(test_val),128.0) +::STMT +MATRIX:s,w +%*%(t(+(w,s)),+(w,s)) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:INF,int983,int39 +==(+(*(>=(Hdiff,int983),betamax),*(<(Hdiff,int39),beta)),INF) +::STMT +MATRIX:subspace_idx,parsertemp73653 +LITERAL_FLOAT:16.0 +-(subspace_idx,*(parsertemp73653,16.0)) +::STMT +MATRIX:subspace_idx,parsertemp107049 +LITERAL_FLOAT:7.0 +-(subspace_idx,*(parsertemp107049,7.0)) +::STMT +LITERAL_FLOAT:1.8378770664093453 +1.8378770664093453 +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015,delta2 +-(delta2,%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +sum(*(parsertemp43626,-1.0)) +::STMT +MATRIX:r,d,parsertemp43999 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),cast.FLOAT(%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:left,tmp,right +==(%*%(tmp,left),%*%(tmp,right)) +::STMT +FLOAT:Z_logl,dispersion +/(Z_logl,sqrt(dispersion)) +::STMT +FLOAT:int81,ytest,int874,parsertemp454076 +LITERAL_FLOAT:0.0 +sqrt(/(-(^(ytest,int874),*(int81,parsertemp454076)),0.0)) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +LITERAL_FLOAT:2.0 +^(+(cast.FLOAT(r_LS),*(/(norm_r2_LS,p_LS),+(parsertemp170552,lambda_LS))),2.0) +::STMT +MATRIX:subspace_idx,parsertemp72201 +LITERAL_FLOAT:8.0 +-(subspace_idx,*(parsertemp72201,8.0)) +::STMT +MATRIX:w_X,z_LS,X +/(nrow(X),*(cast.FLOAT(w_X),cast.FLOAT(z_LS))) +::STMT +MATRIX:col,more_than_ub,parsertemp24107,parsertemp24102,parsertemp24103 +FLOAT:int331,num_bins +LITERAL_FLOAT:1.0 ++(+(*(-(parsertemp24107,more_than_ub),+(parsertemp24103,int331)),*(>(col,num_bins),num_bins)),<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:parsertemp171315,Y,parsertemp171307,parsertemp171319 +FLOAT:float945,float368,float541 +*(*(exp(/(parsertemp171319,float368)),*(/(float945,parsertemp171307),+(float541,parsertemp171315))),rowSums(Y)) +::STMT +FLOAT:int92,n +LITERAL_FLOAT:1.0,2.0,0.02 +*(-(+(-(n,int92),1.0),2.0),0.02) +::STMT +MATRIX:subspace_idx,parsertemp75105 +LITERAL_FLOAT:1.0,32.0 +<(-(subspace_idx,*(parsertemp75105,32.0)),1.0) +::STMT +MATRIX:Y_prob,Y +*(rowSums(Y),-(*(Y,Y_prob),*(Y,Y_prob))) +::STMT +MATRIX:neighbors +-(neighbors,diag(diag(neighbors))) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(t(colSums(X)),-1.0) +::STMT +MATRIX:X,Y,K +FLOAT:int87,x +*(+(*(*(K,int87),-(X,X)),-(Y,Y)),/(-(x,X),-(X,X))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 +t(+(colSums(resp),2.22E-16)) +::STMT +MATRIX:TopIxs,TopVals +LITERAL_FLOAT:0.0 +*(TopIxs,>(TopVals,0.0)) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170089,z,pp_CG +LITERAL_FLOAT:-1.0 +-(*(*(cast.FLOAT(z),sum(p_CG)),-1.0),sqrt(-(*(z,z),*(pp_CG,parsertemp170089)))) +::STMT +MATRIX:r,X,y +FLOAT:int400 +cast.FLOAT(%*%(t(-(int400,r)),%*%(t(X),y))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,cast.FLOAT(%*%(t(X),X))) +::STMT +MATRIX:d,exp_Xb,X +rev(*(%*%(X,d),exp_Xb)) +::STMT +MATRIX:R +FLOAT:i8 +LITERAL_FLOAT:24.0 +-(ncol(R),*(24.0,i8)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.21483446221182986 +*(0.21483446221182986,W2_rand) +::STMT +LITERAL_FLOAT:1.0E-12 +INT:int552,int757 +diag(rand(int552,int757,1.0E-12,1.0E-12)) +::STMT +MATRIX:C,parsertemp174574 +FLOAT:numRows +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp174574,C)),numRows),100.0) +::STMT +FLOAT:a0 +LITERAL_FLOAT:1.0E-5 ++(a0,1.0E-5) +::STMT +MATRIX:parsertemp149307,parsertemp149305 +FLOAT:parsertemp149336,obj,parsertemp149333,parsertemp149340,float839 +LITERAL_FLOAT:-0.5 +/(-(obj,+(+(parsertemp149333,parsertemp149336),*(float839,parsertemp149340))),*(-0.5,-(sum(parsertemp149305),sum(parsertemp149307)))) +::STMT +MATRIX:Y,Xd,parsertemp2775,out +FLOAT:int664,int14 +*(*(*(-(int14,parsertemp2775),>(out,int664)),Y),Xd) +::STMT +MATRIX:2903_mask,dout,2904_X,2902_W +FLOAT:2903_p +LITERAL_FLOAT:0.0 +*(>(2904_X,0.0),*(/(2903_mask,2903_p),%*%(dout,t(2902_W)))) +::STMT +MATRIX:PRED,GT +/(sum(*(PRED,GT)),sum(PRED)) +::STMT +FLOAT:AIC_best_orig +LITERAL_FLOAT:0.001 +abs(*(0.001,AIC_best_orig)) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +%*%(t(+(s,*(norm_r2,d))),+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:resp,parsertemp443532,X,weight +LITERAL_FLOAT:2.22E-16 +*(t(/(%*%(parsertemp443532,X),t(weight))),+(colSums(resp),2.22E-16)) +::STMT +FLOAT:x +LITERAL_FLOAT:-1.0 +exp(*(x,-1.0)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:3840.0 +/(3840.0,num_records) +::STMT +MATRIX:R,dssm +FLOAT:2_n +LITERAL_FLOAT:1.0 +-(/(2_n,-(R,dssm)),1.0) +::STMT +MATRIX:w,wnew +FLOAT:sigma,alpha +LITERAL_FLOAT:0.5 +*(*(*(0.5,sigma),alpha),sum(*(-(wnew,w),-(wnew,w)))) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:2.0 +>=(rowSums(|(<(length,qLow),>(length,qUp))),2.0) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +-(_sbcvar11,/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:minD,D +rowSums(<=(D,minD)) +::STMT +LITERAL_FLOAT:0.99 +0.99 +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605 +FLOAT:lambda +LITERAL_FLOAT:0.0 +abs(*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0))) +::STMT +MATRIX:parsertemp260769,w +FLOAT:reg +LITERAL_FLOAT:2.0 +*(/(reg,2.0),sum(*(+(w,parsertemp260769),+(w,parsertemp260769)))) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:0.7 +*(_sbcvar1708,0.7) +::STMT +MATRIX:WM +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(sum(WM),-(sum(WM),1.0))) +::STMT +MATRIX:tmp,g_old +/(cast.FLOAT(%*%(t(tmp),tmp)),cast.FLOAT(%*%(t(g_old),g_old))) +::STMT +MATRIX:parsertemp409789,parsertemp409798,parsertemp409788,parsertemp409797 +FLOAT:int843 +LITERAL_FLOAT:0.0 +%*%(t(+(-(int843,parsertemp409789),t(parsertemp409798))),+(-(0.0,t(parsertemp409788)),t(colSums(parsertemp409797)))) +::STMT +MATRIX:tmp_Xw,Y,parsertemp2775 +FLOAT:int711 +LITERAL_FLOAT:0.0,1.0 +*(*(-(1.0,*(Y,tmp_Xw)),>(-(int711,parsertemp2775),0.0)),Y) +::STMT +MATRIX:parsertemp389212 +LITERAL_FLOAT:2.0,1058.0 +^(/(parsertemp389212,1058.0),2.0) +::STMT +MATRIX:2903_mask,dout,X,2904_X,parsertemp555692 +FLOAT:2903_p +LITERAL_FLOAT:0.0 +%*%(t(X),*(>(2904_X,0.0),*(/(2903_mask,2903_p),%*%(dout,parsertemp555692)))) +::STMT +MATRIX:w,g +FLOAT:alpha +-(w,/(g,alpha)) +::STMT +MATRIX:P,lambda,X,Y,B_new +LITERAL_FLOAT:2.0 +^(+(%*%(t(X),-(P,Y)),*(lambda,B_new)),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:4.0 ++(4.0,i) +::STMT +MATRIX:cdf_min_distances,random_row +colSums(<(cdf_min_distances,*(random_row,cdf_min_distances))) +::STMT +FLOAT:deviance_nodisp,eps +LITERAL_FLOAT:0.1 +*(eps,+(deviance_nodisp,0.1)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 ++(exp(*(2.0,X)),1.0) +::STMT +MATRIX:colSD,X,colMean +LITERAL_FLOAT:3.0 +<(X,-(colMean,*(3.0,colSD))) +::STMT +MATRIX:colSD,X,colMean +LITERAL_FLOAT:3.0 +>(X,+(colMean,*(3.0,colSD))) +::STMT +MATRIX:p_LS,parsertemp170551,X +FLOAT:lambda_LS ++(*(cast.FLOAT(%*%(parsertemp170551,X)),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS))) +::STMT +MATRIX:parsertemp10744,parsertemp10746,V,W,H +LITERAL_FLOAT:1.0E-8 +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10746)),1.0E-8)) +::STMT +MATRIX:W +round(W) +::STMT +MATRIX:X +FLOAT:threshold +*(>(X,threshold),X) +::STMT +MATRIX:mu +FLOAT:window_size,q +-(q,*(window_size,cast.FLOAT(*(mu,mu)))) +::STMT +FLOAT:log_ten,parsertemp169814 +LITERAL_FLOAT:4.0 +exp(*(log_ten,-(4.0,round(parsertemp169814)))) +::STMT +MATRIX:parsertemp393570,W3_rand +FLOAT:int397,int924 +LITERAL_FLOAT:0.128920512778062 +%*%(*(0.128920512778062,W3_rand),t(/(-(parsertemp393570,int397),+(parsertemp393570,int924)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(750.0,1.0))) +::STMT +FLOAT:a,b,x +LITERAL_FLOAT:2.0 ++(*(a,^(x,2.0)),*(b,x)) +::STMT +MATRIX:Q,R,parsertemp500307 +FLOAT:int723 +LITERAL_FLOAT:2.0 +-(+(rowSums(^(R,int723)),t(rowSums(parsertemp500307))),*(2.0,%*%(R,t(Q)))) +::STMT +LITERAL_FLOAT:0.25 +0.25 +::STMT +MATRIX:parsertemp115857,X,avg_X_cols +FLOAT:int636 +LITERAL_FLOAT:1.0 +/(-(t(colSums(parsertemp115857)),*($1:nrow(X),^(avg_X_cols,int636))),-($1,1.0)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,-(run_index,1.0)),1.0) +::STMT +FLOAT:parsertemp181047,parsertemp181040 +LITERAL_FLOAT:1.0,8.0 +sqrt(*(8.0,-(1.0,/(parsertemp181040,parsertemp181047)))) +::STMT +MATRIX:g0_1,d_r_rev,parsertemp410116 ++(g0_1,t(colSums(*(parsertemp410116,d_r_rev)))) +::STMT +MATRIX:parsertemp411194,parsertemp411197,W,H,parsertemp411205,parsertemp411206 +%*%(/(*(W,%*%(parsertemp411205,parsertemp411206)),t(rowSums(H))),/(*(H,%*%(parsertemp411194,parsertemp411197)),t(colSums(W)))) +::STMT +MATRIX:X,y +FLOAT:int879,int649 +INT:int378,m +%*%(t(X),-(%*%(X,rand(m,int378,int649,int879)),y)) +::STMT +MATRIX:ss +FLOAT:130_n +LITERAL_FLOAT:1.0 +-(/(130_n,ss),1.0) +::STMT +MATRIX:ot,yt +LITERAL_FLOAT:0.0,100.0 +*(sum(>(*(yt,ot),0.0)),100.0) +::STMT +LITERAL_FLOAT:-0.5 +-0.5 +::STMT +LITERAL_FLOAT:0.5 +0.5 +::STMT +MATRIX:W +FLOAT:int197,int575,m3,var,wt +LITERAL_FLOAT:2.0,3.0 +/(*(^(sum(W),2.0),m3),*(*(-(wt,int197),-(wt,int575)),^(sqrt(var),3.0))) +::STMT +LITERAL_FLOAT:0.254829592 +0.254829592 +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +-(1.0,exp(linear_terms)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,4.0 ++(*(-(i,1.0),128.0),4.0) +::STMT +FLOAT:522_padh,522_Hin +LITERAL_FLOAT:1.0,2.0 +-(+(522_Hin,*(2.0,522_padh)),1.0) +::STMT +FLOAT:obj,objnew +abs(-(objnew,obj)) +::STMT +LITERAL_FLOAT:1.0,150.0 +-(150.0,1.0) +::STMT +LITERAL_FLOAT:0.75 +0.75 +::STMT +MATRIX:p,A,r,parsertemp477951 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp477951)),%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:parsertemp285848,X +LITERAL_FLOAT:0.0 +%*%(t(-(0.0,t(parsertemp285848))),t(colSums(X))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,3.0 ++(*(-(i,1.0),128.0),3.0) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0,2.0 +^(+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:p,r,Z +FLOAT:parsertemp31794,norm_r2 +*(+(r,*(/(norm_r2,parsertemp31794),%*%(Z,p))),+(r,*(/(norm_r2,parsertemp31794),%*%(Z,p)))) +::STMT +LITERAL_FLOAT:0.0625 +0.0625 +::STMT +LITERAL_FLOAT:1.0002795638803466 +1.0002795638803466 +::STMT +FLOAT:int263,2690_Hin,int538 +LITERAL_FLOAT:2.0 +/(-(+(2690_Hin,*(int538,int263)),2.0),2.0) +::STMT +MATRIX:A,CVars,CFreqs +FLOAT:int972 +/(sum(*(-(CFreqs,int972),CVars)),-(nrow(A),nrow(CFreqs))) +::STMT +MATRIX:linear_terms +FLOAT:link_power,int964,int879 +LITERAL_FLOAT:-2.0,1.0 +/(^(linear_terms,+(-2.0,/(int964,link_power))),-(1.0,^(linear_terms,/(int879,link_power)))) +::STMT +MATRIX:ss,parsertemp31463 +FLOAT:eAvg,alpha,n +LITERAL_FLOAT:1.0 +-(*(alpha,-(/(parsertemp31463,eAvg),1.0)),*(-(1.0,alpha),-(/(n,ss),1.0))) +::STMT +LITERAL_FLOAT:1.0,0.8 ++(1.0,0.8) +::STMT +LITERAL_FLOAT:0.125 +0.125 +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:2.0,1764.0 +/(colSums(^(-(X,Centering),2.0)),1764.0) +::STMT +MATRIX:intercept,X,beta +exp(+(%*%(X,beta),intercept)) +::STMT +MATRIX:A,present_domain_vals_mat,CFreqs,parsertemp27487 +FLOAT:int999 +/(sum(*(-(CFreqs,int999),%*%(present_domain_vals_mat,parsertemp27487))),-(nrow(A),nrow(present_domain_vals_mat))) +::STMT +FLOAT:int453,F1 +LITERAL_FLOAT:2.0 +*(*(*(*(F1,int453),2.0),2.0),2.0) +::STMT +MATRIX:R,parsertemp503780 +t(+(R,diag(parsertemp503780))) +::STMT +MATRIX:means,Y,vars +LITERAL_FLOAT:2.0 +/(^(-(Y,means),2.0),vars) +::STMT +MATRIX:X,parsertemp438796 +*(ncol(X),parsertemp438796) +::STMT +LITERAL_FLOAT:4.0 +4.0 +::STMT +FLOAT:n +LITERAL_FLOAT:2.0,4.0 ++(-(n,4.0),2.0) +::STMT +LITERAL_FLOAT:4.5 +4.5 +::STMT +FLOAT:start_x,i,s_cols +LITERAL_FLOAT:1.0 ++(*(-(i,1.0),s_cols),start_x) +::STMT +MATRIX:2014_cnI,parsertemp230385 +t(%*%(parsertemp230385,2014_cnI)) +::STMT +MATRIX:obj,objnew,gs +-(-(objnew,obj),gs) +::STMT +MATRIX:P,Y,dP +&(>(P,dP),Y) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +FLOAT:C +*(+(wnew,*(C,%*%(parsertemp44107,parsertemp44109))),+(wnew,*(C,%*%(parsertemp44107,parsertemp44109)))) +::STMT +MATRIX:G +sum(!=(rowSums(G),t(colSums(G)))) +::STMT +MATRIX:parsertemp171245,Y +LITERAL_FLOAT:1.0 +*(rowSums(Y),/(1.0,-(exp(parsertemp171245),1.0))) +::STMT +LITERAL_FLOAT:6.0 +6.0 +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +*(2.0,*(-1.0,sum(parsertemp43626))) +::STMT +LITERAL_FLOAT:5.0 +5.0 +::STMT +MATRIX:cumLeftHist,parsertemp131906,parsertemp132092,leftHist,outBucket +%*%(==(outBucket,%*%(parsertemp132092,t(parsertemp131906))),-(cumLeftHist,leftHist)) +::STMT +LITERAL_FLOAT:1.0E-9 +1.0E-9 +::STMT +LITERAL_FLOAT:2.515517 +2.515517 +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar97 +FLOAT:221_my,int469 +LITERAL_FLOAT:2.0 +*(%*%(_sbcvar95,_sbcvar96),^(+(%*%(_sbcvar95,_sbcvar97),-(int469,221_my)),2.0)) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,4.0 ++(-(n,4.0),1.0) +::STMT +LITERAL_FLOAT:8.0 +8.0 +::STMT +MATRIX:prec,X,mu +-(%*%(X,prec),%*%(mu,prec)) +::STMT +LITERAL_FLOAT:9.0 +9.0 +::STMT +LITERAL_FLOAT:7.0 +7.0 +::STMT +MATRIX:M +FLOAT:parsertemp178174 ++(max(M),parsertemp178174) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 ++(*(sample_rec_ids,<=(sample_rec_ids,num_records)),*(+(num_records,1.0),-(1.0,<=(sample_rec_ids,num_records)))) +::STMT +MATRIX:W,H +%*%(%*%(t(W),W),H) +::STMT +MATRIX:feature +LITERAL_FLOAT:1.0 ++(feature,-(1.0,min(feature))) +::STMT +MATRIX:p,ssX_p,shift_X ++(ssX_p,%*%(t(shift_X),p)) +::STMT +MATRIX:parsertemp27461,r,c,E,F +FLOAT:int686 +LITERAL_FLOAT:2.0,1.0E-4 +/(^(-(F,+(parsertemp27461,E)),2.0),+(*(==(E,int686),1.0E-4),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:Q3,IQR +LITERAL_FLOAT:2.0 ++(Q3,*(2.0,IQR)) +::STMT +LITERAL_FLOAT:10.0 +10.0 +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-2.0,1.0 +^(linear_terms,+(-2.0,/(1.0,link_power))) +::STMT +LITERAL_FLOAT:1.0 +1.0 +::STMT +LITERAL_FLOAT:-1.0 +-1.0 +::STMT +FLOAT:parsertemp2 +cast.MATRIX(parsertemp2) +::STMT +LITERAL_FLOAT:-Infinity +-Infinity +::STMT +LITERAL_FLOAT:Infinity +Infinity +::STMT +MATRIX:W,parsertemp411110,X,H +LITERAL_FLOAT:1.0E-8 +*(W,/(%*%(X,t(H)),+(%*%(W,parsertemp411110),1.0E-8))) +::STMT +MATRIX:parsertemp459193,vW3,parsertemp459200,2703_W +FLOAT:lr,mu,float473 +-(*(mu,vW3),*(lr,+(%*%(parsertemp459200,parsertemp459193),*(float473,2703_W)))) +::STMT +MATRIX:pred +LITERAL_FLOAT:1.0E-10 ++(pred,1.0E-10) +::STMT +FLOAT:factor_up,parsertemp195892 +LITERAL_FLOAT:1.0,2.0 +-(-(*(2.0,factor_up),parsertemp195892),1.0) +::STMT +LITERAL_FLOAT:NaN +NaN +::STMT +LITERAL_FLOAT:1.5 +1.5 +::STMT +MATRIX:P1,P2,S +LITERAL_FLOAT:0.0 +!=(+(%*%(P1,S),%*%(P2,S)),0.0) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,2.0 +/(*(parsertemp539203,-1.0),2.0) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0,1.0 ++(rowSums(==(t(parsertemp222703),0.0)),1.0) +::STMT +MATRIX:U,row_nonzeros +FLOAT:reg +*(*(reg,U),row_nonzeros) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2702_X,2703_W +FLOAT:int377,float760 +%*%(*(*(>(2702_X,int377),/(2701_mask,float760)),%*%(-(2699_dtemp,parsertemp459178),t(2700_W))),t(2703_W)) +::STMT +LITERAL_FLOAT:2.0 +2.0 +::STMT +LITERAL_FLOAT:0.0 +0.0 +::STMT +LITERAL_FLOAT:-0.0 +-0.0 +::STMT +LITERAL_FLOAT:-2.0 +-2.0 +::STMT +MATRIX:parsertemp220911,g,Y +FLOAT:float687 +LITERAL_FLOAT:0.0 +-(+(Y,-(0.0,*(float687,g))),parsertemp220911) +::STMT +MATRIX:E,F +LITERAL_FLOAT:0.001 +<(-(E,F),0.001) +::STMT +MATRIX:RDMean,parsertemp265748 +LITERAL_FLOAT:2.0 +t(-(parsertemp265748,^(RDMean,2.0))) +::STMT +LITERAL_FLOAT:3.0 +3.0 +::STMT +MATRIX:svUpBnd,R,svLowBnd +*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd))) +::STMT +MATRIX:sts,d,parsertemp44021,parsertemp44023 +FLOAT:delta2 +sqrt(+(*(%*%(parsertemp44021,d),%*%(parsertemp44021,d)),*(%*%(parsertemp44023,d),-(delta2,sts)))) +::STMT +MATRIX:t,parsertemp32834,parsertemp32843,X,parsertemp32837,parsertemp32827,parsertemp32824,parsertemp32846 +FLOAT:int882,x +LITERAL_FLOAT:1.0 +*(*(/(-(x,X),-(X,X)),-(1.0,/(parsertemp32824,parsertemp32827))),+(*(-(parsertemp32834,parsertemp32837),-(int882,t)),*(+(parsertemp32843,parsertemp32846),/(parsertemp32824,parsertemp32827)))) +::STMT +MATRIX:parsertemp145796,y +FLOAT:int717 +sum(rowSums(*(*(y,int717),parsertemp145796))) +::STMT +MATRIX:Y,vec1 +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(*(rowSums(Y),vec1),^(link_power,2.0)) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,D_r_rev +/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:X +FLOAT:int681,epsilon +<(sqrt(rowSums(^(X,int681))),epsilon) +::STMT +FLOAT:K +LITERAL_FLOAT:21.0 +*(21.0,K) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0 +colSums(>(*(*(neighbors,corePts),withinEps),0.0)) +::STMT +MATRIX:y_corr,parsertemp171002 +FLOAT:int375 +LITERAL_FLOAT:0.0,1.0 +-(parsertemp171002,/(==(y_corr,0.0),-(1.0,==(y_corr,int375)))) +::STMT +MATRIX:W +FLOAT:float615,m2,wt +/(sqrt(/(*(m2,wt),-(wt,float615))),sqrt(sum(round(W)))) +::STMT +MATRIX:X +FLOAT:index +LITERAL_FLOAT:1.0 +*(index,-(ncol(X),1.0)) +::STMT +MATRIX:parsertemp472326,parsertemp472314 +-(nrow(parsertemp472314),cast.FLOAT(parsertemp472326)) +::STMT +MATRIX:b,parsertemp410078,sb +LITERAL_FLOAT:-1.0 +*(cast.FLOAT(%*%(colSums(parsertemp410078),+(b,sb))),-1.0) +::STMT +MATRIX:parsertemp24102,parsertemp24103 +FLOAT:num_bins,int935 +LITERAL_FLOAT:1.0 +-(-(1.0,<(+(parsertemp24103,int935),1.0)),>(+(round(parsertemp24102),1.0),num_bins)) +::STMT +LITERAL_FLOAT:10.0,-8.0 +^(10.0,-8.0) +::STMT +MATRIX:2792_M2 +LITERAL_FLOAT:0.0 +|(!=(2792_M2,0.0),!=(2792_M2,0.0)) +::STMT +LITERAL_FLOAT:10.0,-10.0 +^(10.0,-10.0) +::STMT +LITERAL_FLOAT:-12.0,10.0 +^(10.0,-12.0) +::STMT +MATRIX:minD,D,parsertemp222603,parsertemp222600 +t(/(<=(+(parsertemp222600,parsertemp222603),minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0 +rowSums(==(t(parsertemp222703),0.0)) +::STMT +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0,5.0 +-(+(num_func_invoc,5.0),1.0) +::STMT +MATRIX:ss_res_Y,var_tot_Y +FLOAT:df_ss_res_Y +/(/(ss_res_Y,df_ss_res_Y),var_tot_Y) +::STMT +MATRIX:M +LITERAL_FLOAT:0.0,2.0 +&(>(rowSums(M),0.0),<(rowSums(M),2.0)) +::STMT +MATRIX:X,permut +FLOAT:n +-(%*%(permut,X),/(colSums(%*%(permut,X)),n)) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:my +LITERAL_FLOAT:2.0 +sum(*(CFreqs,^(-(CMeans,my),2.0))) +::STMT +MATRIX:B +LITERAL_FLOAT:8.0 +/(nrow(B),8.0) +::STMT +LITERAL_FLOAT:0.0873148795050037 +0.0873148795050037 +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),2.0),+(sum(W),1.0)),+(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp414371 +LITERAL_FLOAT:200.0,2.0 +*(200.0,^(/(t(parsertemp414371),200.0),2.0)) +::STMT +MATRIX:X +FLOAT:x +sum(>=(X,x)) +::STMT +MATRIX:border,parsertemp386448,parsertemp386459,parsertemp386449,parsertemp386460,withinEps +FLOAT:int478,int316 +LITERAL_FLOAT:0.0 ++(*(>(*(parsertemp386448,withinEps),0.0),==(-(border,parsertemp386459),0.0)),t(*(>(parsertemp386449,int478),==(parsertemp386460,int316)))) +::STMT +LITERAL_FLOAT:10.0,-30.0 +^(10.0,-30.0) +::STMT +LITERAL_FLOAT:10.0,30.0 +^(10.0,30.0) +::STMT +MATRIX:parsertemp191275,parsertemp191273 +FLOAT:397_C ++(parsertemp191273,*(397_C,t(parsertemp191275))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 ++(1.0,^(linear_terms,2.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0 ++(*(-(i,1.0),128.0),128.0) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +/(nrow(B),4.0) +::STMT +MATRIX:237_present_domain_vals_mat,parsertemp29514,237_CFreqs +FLOAT:int194 +LITERAL_FLOAT:10000.0 +/(sum(*(-(237_CFreqs,int194),%*%(237_present_domain_vals_mat,parsertemp29514))),-(10000.0,nrow(237_present_domain_vals_mat))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,84.0 ++(*(-(i,1.0),128.0),84.0) +::STMT +MATRIX:S,parsertemp175056 +rowSums(exp(-(S,parsertemp175056))) +::STMT +MATRIX:dout +LITERAL_FLOAT:0.01 +*(0.01,dout) +::STMT +MATRIX:parsertemp122291,parsertemp122288 +LITERAL_FLOAT:0.0,4.0 +sum(|(<(t(parsertemp122288),4.0),==(t(parsertemp122291),0.0))) +::STMT +MATRIX:B +LITERAL_FLOAT:2.0 +/(nrow(B),2.0) +::STMT +MATRIX:Bx,Yd,Yu +LITERAL_FLOAT:2.0 +/(-(Yu,Yd),^(Bx,2.0)) +::STMT +MATRIX:Q1,Q3,X,IQR +FLOAT:k +|(<(X,-(Q1,*(k,IQR))),>(X,+(Q3,*(k,IQR)))) +::STMT +LITERAL_FLOAT:0.08681986202598489 +0.08681986202598489 +::STMT +FLOAT:i,k +LITERAL_FLOAT:1.0 +cast.MATRIX(-(+(i,k),1.0)) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),min(round(parsertemp2832)))) +::STMT +MATRIX:w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(w),w)) +::STMT +LITERAL_FLOAT:1.0,10.0 ++(10.0,1.0) +::STMT +MATRIX:vW1,dW,parsertemp459256 +FLOAT:lr,mu,float518 +-(*(mu,vW1),*(lr,+(dW,*(float518,parsertemp459256)))) +::STMT +MATRIX:p_LS +FLOAT:norm_r2_LS,parsertemp170552,lambda_LS +*(/(norm_r2_LS,*(cast.FLOAT(p_LS),+(parsertemp170552,lambda_LS))),cast.FLOAT(p_LS)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:3.0,1.0005 +^(sqrt(*(1.0005,m2)),3.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +/(nrow(X),1.0) +::STMT +MATRIX:Q1,X,IQR +LITERAL_FLOAT:2.0 +<(X,-(Q1,*(2.0,IQR))) +::STMT +MATRIX:Q3,X,IQR +LITERAL_FLOAT:2.0 +>(X,+(Q3,*(2.0,IQR))) +::STMT +FLOAT:o_init +LITERAL_FLOAT:-2.0,50.0 +/(*(-2.0,o_init),50.0) +::STMT +FLOAT:m2 +LITERAL_FLOAT:4.0,1.0005 +^(sqrt(*(1.0005,m2)),4.0) +::STMT +FLOAT:std,float498,float46 +INT:int895,int207 +cast.MATRIX(*(cast.FLOAT(rand(int207,int895,float46,float498)),std)) +::STMT +FLOAT:parsertemp190484,parsertemp190485,FN,TN,FP +sqrt(*(*(*(parsertemp190484,parsertemp190485),+(TN,FP)),+(TN,FN))) +::STMT +MATRIX:parsertemp443530,resp,X +FLOAT:float889 +t(/(%*%(t(resp),X),t(+(parsertemp443530,float889)))) +::STMT +MATRIX:W,H +FLOAT:Eps ++(%*%(%*%(t(W),W),H),Eps) +::STMT +MATRIX:mean,parsertemp437236,parsertemp437235,X,weight,parsertemp437241 +FLOAT:int326 +LITERAL_FLOAT:2.0 ++(-(/(%*%(parsertemp437235,parsertemp437236),t(weight)),*(2.0,^(mean,int326))),/(*(mean,%*%(parsertemp437241,X)),t(weight))) +::STMT +MATRIX:s,d +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),d) +::STMT +MATRIX:parsertemp410987,parsertemp410978,W,H +%*%(/(*(W,parsertemp410987),t(rowSums(H))),/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +FLOAT:_sbcvar1799 +LITERAL_FLOAT:9.0 +-(9.0,_sbcvar1799) +::STMT +FLOAT:i +LITERAL_FLOAT:9.0 ++(i,9.0) +::STMT +MATRIX:parsertemp460644 +FLOAT:float790,2715_D +LITERAL_FLOAT:2.0 +/(*(parsertemp460644,sqrt(/(float790,2715_D))),sqrt(2.0)) +::STMT +LITERAL_FLOAT:9.999999999 +9.999999999 +::STMT +MATRIX:_sbcvar11,43_r,43_c,43_E +LITERAL_FLOAT:2.0,1000.0 +sum(/(^(-(_sbcvar11,43_E),2.0),/(%*%(43_r,43_c),1000.0))) +::STMT +MATRIX:Xd,Xw +FLOAT:step_sz ++(Xw,*(step_sz,Xd)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr +FLOAT:parsertemp171116 +*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr)) +::STMT +MATRIX:_sbcvar332,parsertemp42290 +FLOAT:float884,meanX +LITERAL_FLOAT:9999.0 +t(*(/(_sbcvar332,9999.0),-(+(parsertemp42290,float884),meanX))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:7.0 ++(KM_offset,7.0) +::STMT +MATRIX:R,3_ss,dsep +/(+(R,dsep),3_ss) +::STMT +MATRIX:Y,parsertemp2796,Xw +LITERAL_FLOAT:0.0,1.0 +*(>(-(1.0,*(Y,Xw)),0.0),-(1.0,*(Y,+(Xw,parsertemp2796)))) +::STMT +FLOAT:i +LITERAL_FLOAT:12.0 ++(i,12.0) +::STMT +FLOAT:i +LITERAL_FLOAT:192.0 ++(192.0,i) +::STMT +FLOAT:i +LITERAL_FLOAT:11.0 ++(i,11.0) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +*(-(%*%(X,*(parsertemp500607,parsertemp500610)),y),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:parsertemp436669,prec_chol,X,parsertemp436673 +FLOAT:int93,int32,int745 +LITERAL_FLOAT:2.0 +INT:parsertemp436666,int445 ++(-(*(rand(int445,parsertemp436666,int32,int93),t(parsertemp436669)),*(2.0,%*%(X,parsertemp436673))),%*%(^(X,2.0),t(^(prec_chol,int745)))) +::STMT +FLOAT:502_strideh,502_padh,502_Hin +LITERAL_FLOAT:1.0,2.0 +-(*(502_strideh,-(502_Hin,1.0)),*(2.0,502_padh)) +::STMT +FLOAT:k +LITERAL_FLOAT:4.0 ++(k,4.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +-(1.0,==(y_corr,0.0)) +::STMT +MATRIX:ss +FLOAT:130_n,130_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,130_alpha),-(/(130_n,ss),1.0)) +::STMT +MATRIX:parsertemp44080,obj,parsertemp44076,wnew +FLOAT:C +LITERAL_FLOAT:0.5 +-(obj,+(*(0.5,%*%(parsertemp44076,wnew)),*(C,sum(parsertemp44080)))) +::STMT +MATRIX:classCounts +FLOAT:numRows +/(classCounts,numRows) +::STMT +MATRIX:parsertemp500604,w,parsertemp500601 +FLOAT:alpha,tau +*(parsertemp500604,-(abs(-(w,parsertemp500601)),/(tau,alpha))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:6.0 ++(KM_offset,6.0) +::STMT +MATRIX:mW2,dW2 +FLOAT:193_lr,parsertemp147034,193_beta1,int779,193_t +LITERAL_FLOAT:1.0 +*(/(*(193_lr,sqrt(parsertemp147034)),-(1.0,^(193_beta1,193_t))),+(*(193_beta1,mW2),*(-(int779,193_beta1),dW2))) +::STMT +MATRIX:parsertemp40482,X2,l +/(nrow(X2),t(colSums(==(parsertemp40482,l)))) +::STMT +MATRIX:parsertemp429910 +LITERAL_FLOAT:300.0,2.0 +*(300.0,^(/(t(parsertemp429910),300.0),2.0)) +::STMT +FLOAT:w_i +LITERAL_FLOAT:5.0 ++(w_i,5.0) +::STMT +MATRIX:parsertemp171246,Y +FLOAT:int23 +LITERAL_FLOAT:1.0 +-(Y,*(Y,/(1.0,-(parsertemp171246,int23)))) +::STMT +FLOAT:run_index +LITERAL_FLOAT:48.0 +*(48.0,run_index) +::STMT +MATRIX:weightMatrix +FLOAT:threshold +LITERAL_FLOAT:0.0 +&(<(weightMatrix,threshold),>(weightMatrix,0.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int21,int125,int667,int938 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int21),/(negSampleVariances,int938)),2.0),+(/(^(posSampleVariances,int667),3.42951E11),/(^(negSampleVariances,int125),3.37275E9))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285794,parsertemp285796 +LITERAL_FLOAT:-1.0 +/(-(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285794,parsertemp285796))),cast.FLOAT(%*%(t(p_CG),p_CG))) +::STMT +FLOAT:norm_Grad_initial +LITERAL_FLOAT:1.0E-4 +*(1.0E-4,norm_Grad_initial) +::STMT +MATRIX:parsertemp498247,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:0.0,2.0 +^(/(-(0.0,-(parsertemp498247,m_iter_err_sum)),i_process_item),2.0) +::STMT +FLOAT:int200,parsertemp285740,p_CG,parsertemp285763,pp_CG +*(parsertemp285763,/(-(*(p_CG,int200),sqrt(parsertemp285740)),pp_CG)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(750.0,1.0))) +::STMT +MATRIX:P12,map +FLOAT:level +LITERAL_FLOAT:0.0 +==(rowSums(!=(%*%(map,P12),0.0)),level) +::STMT +MATRIX:W +FLOAT:m2,int91 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(*(3.0,^(m2,int91)),^(sum(W),2.0)),-(sum(round(W)),1.0)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.128920512778062 +*(0.128920512778062,W2_rand) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +^(max(X),2.0) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:parsertemp31191,parsertemp31198 +LITERAL_FLOAT:1500.0,7000.0 +sqrt(+(/(/(parsertemp31190,parsertemp31191),7000.0),/(/(parsertemp31197,parsertemp31198),1500.0))) +::STMT +LITERAL_FLOAT:0.007 +0.007 +::STMT +MATRIX:y_hat,b,R +*(-(-(b,%*%(R,y_hat)),y_hat),-(-(b,%*%(R,y_hat)),y_hat)) +::STMT +MATRIX:ytest +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +-(sum(^(ytest,2.0)),*(nrow(ytest),^(/(sum_y_test,n),2.0))) +::STMT +MATRIX:s,w +FLOAT:step_sz +*(+(w,*(step_sz,s)),+(w,*(step_sz,s))) +::STMT +MATRIX:dW,parsertemp459256 +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(dW,*(5.0E-4,parsertemp459256))) +::STMT +FLOAT:parsertemp40812,m2,int31,mu +/(sqrt(*(/(int31,parsertemp40812),m2)),mu) +::STMT +FLOAT:_sbcvar1783 +LITERAL_FLOAT:8.0 +-(8.0,_sbcvar1783) +::STMT +MATRIX:ss,se +FLOAT:parsertemp122358,int182 +LITERAL_FLOAT:1.0,0.95 +*(0.95,-(/(/(se,ss),/(parsertemp122358,int182)),1.0)) +::STMT +MATRIX:select,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +LITERAL_FLOAT:2.0 +/(*(X_exp_Xb_rev_agg,%*%(select,Xd_exp_Xb_rev_agg)),^(D_r_rev,2.0)) +::STMT +MATRIX:Y_counts,parsertemp560508,parsertemp560522,ent1_vec +/(-(sum(rowSums(parsertemp560508)),sum(*(Y_counts,ent1_vec))),sqrt(sum(*(Y_counts,parsertemp560522)))) +::STMT +LITERAL_FLOAT:1.0,2000.0 +-(2000.0,1.0) +::STMT +FLOAT:lambda,beta +LITERAL_FLOAT:0.0 +sqrt(*(+(0.0,*(lambda,beta)),+(0.0,*(lambda,beta)))) +::STMT +MATRIX:g_Y,w +LITERAL_FLOAT:2.0 +/(^(g_Y,2.0),w) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +INT:int47,int606 +%*%(X,rand(int606,int47,0.0,0.0)) +::STMT +MATRIX:A +abs(-(A,t(A))) +::STMT +MATRIX:Y +sum(==(Y,max(Y))) +::STMT +MATRIX:determinants +FLOAT:nFeats +LITERAL_FLOAT:3.141592653589793,2.0 +*(^(*(2.0,3.141592653589793),nFeats),determinants) +::STMT +LITERAL_FLOAT:44.73253849269008 +44.73253849269008 +::STMT +MATRIX:L,m +FLOAT:sum +/(-(m,sum),cast.FLOAT(L)) +::STMT +MATRIX:parsertemp260755,Xd +FLOAT:dd,step_sz,wd +*(-(+(wd,*(step_sz,dd)),sum(*(parsertemp260755,Xd))),-(+(wd,*(step_sz,dd)),sum(*(parsertemp260755,Xd)))) +::STMT +MATRIX:ss +LITERAL_FLOAT:40.0 +/(40.0,ss) +::STMT +MATRIX:prec_chol,mu +FLOAT:int750 +LITERAL_FLOAT:2.0 +t(*(rowSums(^(mu,int750)),^(prec_chol,2.0))) +::STMT +MATRIX:means,Y_counts,ones_ctg +LITERAL_FLOAT:1.0 +<(*(means,%*%(Y_counts,t(ones_ctg))),1.0) +::STMT +FLOAT:int18 +LITERAL_FLOAT:0.0 +INT:int193,m +abs(rand(m,int193,0.0,int18)) +::STMT +MATRIX:probs,scores,unnorm_probs,dprobs +-(*(dprobs,/(exp(scores),rowSums(unnorm_probs))),*(/(exp(scores),rowSums(unnorm_probs)),rowSums(*(dprobs,probs)))) +::STMT +LITERAL_FLOAT:3.0,2000.0 +-(2000.0,3.0) +::STMT +FLOAT:parsertemp230731 +LITERAL_FLOAT:2.0 ++(parsertemp230731,2.0) +::STMT +MATRIX:labels +LITERAL_FLOAT:1.0 ++(labels,-(1.0,min(labels))) +::STMT +MATRIX:tmp,leftIdx +LITERAL_FLOAT:0.0 +>(%*%(tmp,%*%(t(tmp),leftIdx)),0.0) +::STMT +MATRIX:t_gp,parsertemp560875,linear_terms,parsertemp560867 +FLOAT:int721,float396 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(*(/(1.0,+(float396,parsertemp560867)),+(0.254829592,*(t_gp,parsertemp560875))),-(*(2.0,>=(linear_terms,int721)),1.0)) +::STMT +FLOAT:parsertemp191177,strideh,Hin,Hf +LITERAL_FLOAT:1.0 ++(/(-(+(Hin,parsertemp191177),Hf),strideh),1.0) +::STMT +MATRIX:parsertemp539203 +FLOAT:int975 +LITERAL_FLOAT:2.0,0.6666666666666666 +max(^(/(*(parsertemp539203,int975),2.0),0.6666666666666666)) +::STMT +MATRIX:pred,y +LITERAL_FLOAT:1.0,-1.0,1.0E-10 +*(*(/(1.0,nrow(y)),*(y,-1.0)),/(1.0,+(pred,1.0E-10))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:3.0 ++(KM_offset,3.0) +::STMT +MATRIX:parsertemp146972,parsertemp146970,W1,191_v +FLOAT:parsertemp146984,parsertemp146982,191_epsilon +-(W1,/(*(/(parsertemp146982,parsertemp146984),+(parsertemp146970,parsertemp146972)),+(sqrt(191_v),191_epsilon))) +::STMT +MATRIX:R,dssp,dsep +/(+(R,dsep),+(R,dssp)) +::STMT +MATRIX:e,X2 +LITERAL_FLOAT:0.0 +==(t(%*%(t(e),X2)),0.0) +::STMT +LITERAL_FLOAT:1.0E-6 +INT:int996,int118 +diag(rand(int996,int118,1.0E-6,1.0E-6)) +::STMT +FLOAT:parsertemp410218,parsertemp410219,N +LITERAL_FLOAT:-1.0 +exp(/(*(-(parsertemp410218,parsertemp410219),-1.0),N)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 ++(i,1.0) +::STMT +MATRIX:y_prob,elt +FLOAT:int410 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int410,elt),1.0E7)),-(1.0,y_prob)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-7 +INT:int802,m ++(%*%(t(X),X),diag(rand(m,int802,1.0E-7,1.0E-7))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:44.73253849269008,1.0005 +/(sqrt(*(1.0005,m2)),44.73253849269008) +::STMT +LITERAL_FLOAT:2.0,2000.0 +-(2000.0,2.0) +::STMT +LITERAL_FLOAT:3.42951E11 +3.42951E11 +::STMT +MATRIX:means,Y_counts,ones_ctg +LITERAL_FLOAT:5.0 +<(*(means,%*%(Y_counts,t(ones_ctg))),5.0) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.1651445647689541 +*(0.1651445647689541,W2_rand) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,1.0 +INT:int447,m +%*%(X,rand(m,int447,0.0,1.0)) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,1.0 +!=(+(Y,1.0),0.0) +::STMT +MATRIX:parsertemp379565,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0,2.0 +^(/(*(-(parsertemp379565,m_iter_err_sum),-1.0),i_process_item),2.0) +::STMT +MATRIX:252_X +FLOAT:252_X,float360 +LITERAL_FLOAT:1.0,4.5 +*(/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(1.0,/(-(float360,252_X),-(252_X,252_X)))) +::STMT +MATRIX:parsertemp1517,parsertemp1515 +FLOAT:int869,n +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(/(-(parsertemp1515,parsertemp1517),-(n,int869)),0.0)) +::STMT +FLOAT:_sbcvar1847 +LITERAL_FLOAT:11.0 +-(11.0,_sbcvar1847) +::STMT +FLOAT:i +LITERAL_FLOAT:1048.0 ++(i,1048.0) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +*(-(sum(WM),1.0),/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1024.0 ++(i,1024.0) +::STMT +MATRIX:p_CG,z +FLOAT:rr_CG,pq_CG ++(z,*(/(rr_CG,pq_CG),p_CG)) +::STMT +MATRIX:ot2 +FLOAT:int897 +LITERAL_FLOAT:1500.0,100.0 +/(*(sum(>(ot2,int897)),100.0),1500.0) +::STMT +MATRIX:X +FLOAT:int17 +LITERAL_FLOAT:0.0 +INT:m,int172 +%*%(X,rand(m,int172,0.0,int17)) +::STMT +MATRIX:lambda,B_new +LITERAL_FLOAT:2.0 +sum(*(lambda,^(B_new,2.0))) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0,5.0 +*(+(sum(round(W)),5.0),-(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp171326,is_lt_pos,parsertemp171330,Y,parsertemp171329 +FLOAT:one_over_sqrt_two_pi,float268 +*(one_over_sqrt_two_pi,+(-(Y,*(parsertemp171326,is_lt_pos)),*(*(parsertemp171329,parsertemp171330),-(is_lt_pos,float268)))) +::STMT +MATRIX:r,parsertemp44063,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(cast.FLOAT(%*%(parsertemp44063,grad)),cast.FLOAT(%*%(parsertemp44063,r)))) +::STMT +MATRIX:p,e,u +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),%*%(%*%(e,u),p)) +::STMT +MATRIX:p_CG +FLOAT:rr_CG,pq_CG +*(/(rr_CG,pq_CG),p_CG) +::STMT +LITERAL_FLOAT:-0.6931471805599453 +-0.6931471805599453 +::STMT +LITERAL_FLOAT:0.6931471805599453 +0.6931471805599453 +::STMT +LITERAL_FLOAT:1.0E-7 +INT:int329,m +diag(rand(m,int329,1.0E-7,1.0E-7)) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:X,Centering,ScaleFactor +FLOAT:N +/(colSums(/(-(X,Centering),ScaleFactor)),N) +::STMT +MATRIX:classFeatureCounts +FLOAT:numFeatures,laplaceCorrection +/(+(classFeatureCounts,laplaceCorrection),+(rowSums(classFeatureCounts),*(numFeatures,laplaceCorrection))) +::STMT +FLOAT:std +LITERAL_FLOAT:0.0,1.0 +INT:int654,int49 +*(cast.FLOAT(rand(int49,int654,0.0,1.0)),std) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 ++(max(X),1.0) +::STMT +MATRIX:xs +LITERAL_FLOAT:4.5 +sum(>=(xs,4.5)) +::STMT +MATRIX:parsertemp13624,_sbcvar11 +FLOAT:int284 +LITERAL_FLOAT:2.0,1000.0 +/(^(-(_sbcvar11,/(parsertemp13624,int284)),2.0),/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:1.0 +INT:parsertemp503363,int581 ++(R,diag(rand(parsertemp503363,int581,1.0,1.0))) +::STMT +LITERAL_FLOAT:2.22E-16 +2.22E-16 +::STMT +MATRIX:svUpBnd,R +<=(R,cast.FLOAT(svUpBnd)) +::STMT +MATRIX:vW1,dW1 +FLOAT:2727_mu,2727_lr +LITERAL_FLOAT:1.0 +*(+(1.0,2727_mu),-(*(2727_mu,vW1),*(2727_lr,dW1))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +/(sum(==(-(predicted_Y,Y),0.0)),nrow(Y)) +::STMT +LITERAL_FLOAT:0.025253813613805267 +0.025253813613805267 +::STMT +MATRIX:q,r +FLOAT:p,norm_r2 +t(+(r,*(/(norm_r2,p),+(q,q)))) +::STMT +MATRIX:codebook +FLOAT:j +LITERAL_FLOAT:1.0 ++(1.0,*(-(j,1.0),ncol(codebook))) +::STMT +FLOAT:_sbcvar1831 +LITERAL_FLOAT:10.0 +-(10.0,_sbcvar1831) +::STMT +FLOAT:sd_Y,sd_X +-(sqrt(sd_Y),sqrt(sd_X)) +::STMT +MATRIX:distT +LITERAL_FLOAT:0.0 +!=(distT,0.0) +::STMT +FLOAT:a,b +LITERAL_FLOAT:2.0 +*(2.0,*(a,b)) +::STMT +MATRIX:_sbcvar1006 +LITERAL_FLOAT:0.0 +>(t(_sbcvar1006),0.0) +::STMT +MATRIX:parsertemp31933,X2,parsertemp31935 +t(colSums(==(%*%(X2,parsertemp31935),t(parsertemp31933)))) +::STMT +LITERAL_FLOAT:999.0 +999.0 +::STMT +FLOAT:Hin +LITERAL_FLOAT:184.0 ++(Hin,184.0) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +!(<(leaf_ids,+(boundary_left,step_size))) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +-(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum))),Xm) +::STMT +FLOAT:i +LITERAL_FLOAT:64.0 ++(i,64.0) +::STMT +MATRIX:filled_matrix,aligned +t(%*%(t(filled_matrix),aligned)) +::STMT +MATRIX:m_active_flag_tmp +LITERAL_FLOAT:1.0 +!=(m_active_flag_tmp,1.0) +::STMT +LITERAL_FLOAT:1.01 +1.01 +::STMT +MATRIX:p,r,parsertemp1934,parsertemp1935,parsertemp1940 +FLOAT:norm_r2,eps ++(r,*(/(norm_r2,cast.FLOAT(parsertemp1940)),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,X),-(X,X))) +::STMT +FLOAT:parsertemp98,int764,var,m4,parsertemp99,int59,parsertemp93,parsertemp94,wt,parsertemp105,parsertemp104 +LITERAL_FLOAT:4.0 +/(-(*(*(parsertemp93,parsertemp94),m4),*(*(parsertemp98,parsertemp99),-(wt,int59))),*(*(*(parsertemp104,parsertemp105),-(wt,int764)),^(sqrt(var),4.0))) +::STMT +MATRIX:resp,mean,X,weight +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),*(X,X)),t(weight)),*(2.0,^(mean,2.0))) +::STMT +LITERAL_FLOAT:3.141592653589793,2.0 +*(2.0,3.141592653589793) +::STMT +MATRIX:X +LITERAL_FLOAT:10.0 +!=(X,10.0) +::STMT +MATRIX:X,ScaleFactor +FLOAT:N +%*%(t(/(colSums(X),N)),/(colSums(/(X,ScaleFactor)),N)) +::STMT +FLOAT:beg +LITERAL_FLOAT:512.0 ++(beg,512.0) +::STMT +MATRIX:border,parsertemp386449,neighbors,parsertemp386460 +FLOAT:int891,int557 +LITERAL_FLOAT:0.0 +>(+(*(>(parsertemp386449,int557),==(parsertemp386460,int891)),t(*(neighbors,border))),0.0) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int184 +LITERAL_FLOAT:1.0,2.0,7000.0 +^(/(-(colSums(parsertemp31186),*(int184,parsertemp31188)),-(7000.0,1.0)),2.0) +::STMT +MATRIX:X,Y ++(abs(X),abs(Y)) +::STMT +MATRIX:mean,weight +%*%(*(t(mean),weight),mean) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216 +FLOAT:numRows,level +/(numRows,-(+(R,rowSums(parsertemp40216)),rowSums(==(parsertemp40219,level)))) +::STMT +FLOAT:beg +LITERAL_FLOAT:256.0 ++(beg,256.0) +::STMT +FLOAT:i +LITERAL_FLOAT:253.0 ++(i,253.0) +::STMT +MATRIX:os,y,o +FLOAT:int917 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(-(int917,y),+(o,os)))) +::STMT +MATRIX:X,tS +FLOAT:l +==(%*%(X,tS),l) +::STMT +LITERAL_FLOAT:2.0,83.0 +/(83.0,2.0) +::STMT +MATRIX:parsertemp171348,is_too_small,parsertemp171346,parsertemp171344,parsertemp171353,Y,the_exp,parsertemp171349 +FLOAT:int124,int429 +/(-(*(rowSums(Y),exp(parsertemp171344)),Y),+(/(*(parsertemp171348,parsertemp171349),+(the_exp,is_too_small)),*(==(parsertemp171346,int429),-(int124,parsertemp171353)))) +::STMT +FLOAT:i +LITERAL_FLOAT:3000.0 +-(3000.0,i) +::STMT +MATRIX:parsertemp400664,parsertemp400661,W3_rand +LITERAL_FLOAT:0.2656844656620286 +t(%*%(*(0.2656844656620286,W3_rand),t(/(parsertemp400661,parsertemp400664)))) +::STMT +MATRIX:240_elt,240_ones_ctg +/(240_elt,%*%(rowSums(240_elt),t(240_ones_ctg))) +::STMT +MATRIX:Bxu,Bxd ++(Bxd,Bxu) +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.001001001001001 +*(42_m2X,1.001001001001001) +::STMT +MATRIX:parsertemp43634 +FLOAT:float614,int863,int687,float282,float925,float13 +INT:int241,int486,int281,int506 +sum(*(+(rand(int281,int486,float282,float614),*(int863,parsertemp43634)),+(rand(int506,int241,float925,float13),*(int687,parsertemp43634)))) +::STMT +MATRIX:221_present_domain_vals_mat,parsertemp27770 +sqrt(%*%(221_present_domain_vals_mat,parsertemp27770)) +::STMT +MATRIX:s +LITERAL_FLOAT:1.0,2.0 +*(1.0,sum(^(s,2.0))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +exp(-(0.0,linear_terms)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +-(0.0,exp(finite_linear_terms)) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0,1.0 ++(*(-(i,1.0),16.0),1.0) +::STMT +MATRIX:Y,parsertemp221025 +FLOAT:int526 +LITERAL_FLOAT:1.0 +sum(*(/(1.0,+(Y,int526)),+(diag(parsertemp221025),1.0))) +::STMT +MATRIX:logisticnew +LITERAL_FLOAT:1.0 +*(logisticnew,-(1.0,logisticnew)) +::STMT +MATRIX:parsertemp437238,parsertemp437237,mean,weight,parsertemp437242,avgMean +FLOAT:int92,reg_covar ++(+(-(/(parsertemp437237,parsertemp437238),*(int92,avgMean)),/(*(mean,parsertemp437242),t(weight))),reg_covar) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0,4.0 +*(2.0,/(-(rowSums(simplex),simplex),4.0)) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,100.0 +-(colSums(^(posSamples,2.0)),*(100.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:X2,85_s +FLOAT:alpha,int392 +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(*(/(int392,85_s),nrow(X2)),1.0)) +::STMT +MATRIX:shift_X,beta_unscaled +cast.FLOAT(%*%(t(shift_X),beta_unscaled)) +::STMT +MATRIX:Y +FLOAT:num_categories +LITERAL_FLOAT:-1.0 ++(*(Y,-1.0),num_categories) +::STMT +LITERAL_FLOAT:24.0,1.0 +*(24.0,1.0) +::STMT +MATRIX:Nc +/(Nc,sum(Nc)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int540,int910,int823,int161 ++(sum(rand(int161,int823,0.0,1.0)),sum(rand(int910,int540,0.0,1.0))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:10000.0 +/(10000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +MATRIX:Y_counts +FLOAT:num_features +LITERAL_FLOAT:1.0 +-(-(sum(Y_counts),num_features),1.0) +::STMT +LITERAL_FLOAT:1.0E-9,10.0 +-(10.0,1.0E-9) +::STMT +MATRIX:parsertemp570396,classVars +FLOAT:varSmoothing +*(*(diag(parsertemp570396),max(classVars)),varSmoothing) +::STMT +MATRIX:parsertemp460643 +LITERAL_FLOAT:0.025253813613805267 +*(parsertemp460643,0.025253813613805267) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2000.0 +*(4.0,-(^(2000.0,2.0),1.0)) +::STMT +MATRIX:Bx,Yd,Yu +/(-(Yu,Yd),*(Bx,Bx)) +::STMT +MATRIX:252_X +LITERAL_FLOAT:1.0,4.5 +-(1.0,/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X)))) +::STMT +LITERAL_FLOAT:0.35 +0.35 +::STMT +FLOAT:parsertemp40916,int333,m2 +LITERAL_FLOAT:2001.0 +/(sqrt(*(/(int333,parsertemp40916),m2)),sqrt(2001.0)) +::STMT +MATRIX:P,scale_X,X,Y +%*%(diag(scale_X),%*%(t(X),-(P,Y))) +::STMT +MATRIX:s,w +LITERAL_FLOAT:100.0 +*(100.0,cast.FLOAT(%*%(t(w),s))) +::STMT +FLOAT:o_init +LITERAL_FLOAT:-2.0,50.0 +exp(/(*(-2.0,o_init),50.0)) +::STMT +MATRIX:G,authorities +/(%*%(t(G),%*%(G,authorities)),max(%*%(t(G),%*%(G,authorities)))) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +*(is_natural_parameter_log_zero,abs(Y)) +::STMT +FLOAT:43_q +LITERAL_FLOAT:1.0,1000.0 +*(1000.0,-(43_q,1.0)) +::STMT +FLOAT:m2X,W +LITERAL_FLOAT:1.0 +*(m2X,/(W,-(W,1.0))) +::STMT +MATRIX:r,Hd +FLOAT:c +t(+(r,*(c,Hd))) +::STMT +MATRIX:TKC +/(cast.FLOAT(TKC),cast.FLOAT(TKC)) +::STMT +LITERAL_FLOAT:0.5,-0.5 +INT:rank,m +rand(m,rank,-0.5,0.5) +::STMT +MATRIX:parsertemp382917,U,W +t(%*%(t(U),*(W,%*%(U,parsertemp382917)))) +::STMT +LITERAL_FLOAT:1.0E8 +1.0E8 +::STMT +FLOAT:int384,i,Hin,Win +LITERAL_FLOAT:1.0 ++(*(*(-(i,int384),Hin),Win),1.0) +::STMT +MATRIX:X,weight +/(weight,nrow(X)) +::STMT +MATRIX:a,b,t,parsertemp32856,Y,parsertemp32827,parsertemp32824 +FLOAT:int228,int23 ++(+(*(-(int228,t),Y),*(/(parsertemp32824,parsertemp32827),Y)),*(*(/(parsertemp32824,parsertemp32827),-(int23,t)),+(*(a,parsertemp32856),*(b,t)))) +::STMT +MATRIX:parsertemp30951,G,authorities,hubs +-(/(%*%(t(G),%*%(G,authorities)),max(%*%(parsertemp30951,hubs))),authorities) +::STMT +FLOAT:_sbcvar1735 +LITERAL_FLOAT:12.0 +-(12.0,_sbcvar1735) +::STMT +FLOAT:i,num_centroids +LITERAL_FLOAT:2.0 ++(*(num_centroids,2.0),i) +::STMT +MATRIX:parsertemp150470,LT,parsertemp149320,parsertemp150469 +/(exp(-(LT,%*%(parsertemp149320,parsertemp150469))),%*%(rowSums(exp(LT)),parsertemp150470)) +::STMT +MATRIX:w,out +FLOAT:reg +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,sum(*(out,out))),*(/(reg,2.0),sum(*(w,w)))) +::STMT +MATRIX:H_inv +sqrt(diag(H_inv)) +::STMT +MATRIX:parsertemp220853,W,sum_Pi,beta +FLOAT:logU +-(+(parsertemp220853,*(beta,/(W,sum_Pi))),logU) +::STMT +MATRIX:meanDiff,parsertemp570372,parsertemp570375 +LITERAL_FLOAT:-1.0,1.0,2.0 +-(*(/(-1.0,2.0),parsertemp570372),*(/(1.0,2.0),%*%(%*%(meanDiff,parsertemp570375),t(meanDiff)))) +::STMT +MATRIX:W,parsertemp411198,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 ++(%*%(W,/(*(H,parsertemp411198),t(parsertemp411200))),1.0E-8) +::STMT +FLOAT:parsertemp190487,parsertemp190486,FN,TN,FP,TP +/(-(*(TP,TN),*(FP,FN)),sqrt(*(*(parsertemp190486,parsertemp190487),+(TN,FN)))) +::STMT +MATRIX:vW1,parsertemp146961,dout1 +FLOAT:191_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(191_beta2,vW1),*(-(1.0,191_beta2),^(%*%(parsertemp146961,dout1),2.0))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +LITERAL_FLOAT:2.0 +/(sum(^(+(r,parsertemp1945),2.0)),norm_r2) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +/(sum(WM),-(sum(WM),1.0)) +::STMT +MATRIX:output_values,initial_prediction +LITERAL_FLOAT:0.3 ++(initial_prediction,*(0.3,sum(output_values))) +::STMT +FLOAT:so_exact,so_linear_approx +LITERAL_FLOAT:-0.5 +/(*(-0.5,so_linear_approx),-(so_exact,so_linear_approx)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +rowSums(^(X,2.0)) +::STMT +MATRIX:p,z +LITERAL_FLOAT:-1.0 +*(sum(*(p,z)),-1.0) +::STMT +MATRIX:LT,Y,parsertemp149320,parsertemp150469 +sum(*(Y,-(LT,%*%(parsertemp149320,parsertemp150469)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,2.0 +-(0.0,^(finite_linear_terms,2.0)) +::STMT +LITERAL_FLOAT:40.0,20.0 +*(20.0,40.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,1.0,2.0 +-(*(2.0,>=(linear_terms,0.0)),1.0) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),1.0),-(sum(round(W)),2.0)) +::STMT +MATRIX:initial_prediction +INT:int744,parsertemp186173 +rand(parsertemp186173,int744,cast.FLOAT(initial_prediction),cast.FLOAT(initial_prediction)) +::STMT +MATRIX:s,w +sum(*(w,s)) +::STMT +MATRIX:252_X +LITERAL_FLOAT:4.5 +-(4.5,cast.FLOAT(252_X)) +::STMT +LITERAL_FLOAT:1.0,2.0,2003.0 +*(-(2003.0,2.0),+(2003.0,1.0)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.001 +diag(*(scale_lambda,0.001)) +::STMT +MATRIX:out1,187_dX,parsertemp146955 +FLOAT:int533 +LITERAL_FLOAT:2.0 +^(colSums(*(>(out1,int533),*(parsertemp146955,187_dX))),2.0) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,4.0 +&(>=(R,4.0),>(R,0.0)) +::STMT +MATRIX:precisions,X,parsertemp436695,bc_matrix,parsertemp436691 +LITERAL_FLOAT:2.0 +-(*(bc_matrix,t(*(parsertemp436691,precisions))),*(2.0,%*%(X,t(parsertemp436695)))) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,grad),2.0) +::STMT +MATRIX:id +==(id,t(id)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(/(1.0,link_power),1.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0,1.0 +-(/(-1.0,link_power),1.0) +::STMT +MATRIX:parsertemp10743,V,parsertemp10742,H,parsertemp10739,parsertemp10738 +FLOAT:Eps +%*%(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps))),t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +/(*(m2,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp44076,wnew,parsertemp44079 +LITERAL_FLOAT:-1.0,2.0,0.5 ++(*(0.5,cast.FLOAT(%*%(parsertemp44076,wnew))),*(2.0,*(-1.0,sum(parsertemp44079)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2001.0 +*(-(2001.0,2.0),+(2001.0,1.0)) +::STMT +MATRIX:A,foffb +LITERAL_FLOAT:0.0 +*(!=(A,0.0),+(A,foffb)) +::STMT +MATRIX:parsertemp397841,parsertemp397838,W4_rand +LITERAL_FLOAT:0.0873148795050037 +t(%*%(*(0.0873148795050037,W4_rand),t(/(parsertemp397838,parsertemp397841)))) +::STMT +MATRIX:parsertemp220900,parsertemp220899 +LITERAL_FLOAT:300.0,0.0,2.0 +^(-(0.0,*(300.0,-(parsertemp220899,parsertemp220900))),2.0) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:0.0 +-(0.0,+(g,*(lambda,beta))) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:0.5,4460.0 +round(+(0.5,/(parsertemp76118,4460.0))) +::STMT +MATRIX:knn_index +FLOAT:iter,i ++(*(iter,ncol(knn_index)),i) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +sum(>=(abs(-(output,output1)),0.1)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +FLOAT:int472 +LITERAL_FLOAT:1.0 +/(*(*(^(n_risk_stratum,int472),*(n_risk,n_event_stratum)),-(n_risk_stratum,n_event_stratum)),*(n_risk_stratum,-(n_risk_stratum,1.0))) +::STMT +MATRIX:X +FLOAT:parsertemp165083 +LITERAL_FLOAT:2.0 ++(*(2.0,ncol(X)),*(nrow(X),parsertemp165083)) +::STMT +FLOAT:float538,int243,42_m2X +LITERAL_FLOAT:1000.0 +sqrt(*(42_m2X,/(1000.0,-(int243,float538)))) +::STMT +MATRIX:C,Xm,parsertemp265706,parsertemp265702,parsertemp265701 +FLOAT:ss ++(%*%(t(%*%(Xm,parsertemp265702)),%*%(Xm,%*%(C,parsertemp265701))),*(parsertemp265706,ss)) +::STMT +FLOAT:parsertemp115814,sum_sq_y_test,n,ss_res +LITERAL_FLOAT:1.0 +-(1.0,/(ss_res,-(sum_sq_y_test,*(n,parsertemp115814)))) +::STMT +MATRIX:parsertemp560507,Y +sum(rowSums(*(Y,parsertemp560507))) +::STMT +FLOAT:parsertemp382948,parsertemp382956,loss_init,parsertemp382953 +LITERAL_FLOAT:0.5,5.0E-7 +-(loss_init,+(*(0.5,parsertemp382948),*(5.0E-7,+(parsertemp382953,parsertemp382956)))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285739,parsertemp285737,pp_CG +LITERAL_FLOAT:-1.0 +/(+(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285737,parsertemp285739))),pp_CG) +::STMT +MATRIX:p,q,A +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),%*%(t(A),%*%(A,p))) +::STMT +MATRIX:X +FLOAT:n +/(t(colSums(X)),n) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,0.231641888 +/(1.0,+(1.0,*(abs(finite_linear_terms),0.231641888))) +::STMT +MATRIX:parsertemp2832 +min(round(parsertemp2832)) +::STMT +MATRIX:parsertemp11277 +FLOAT:block_size +LITERAL_FLOAT:1.0 ++(1.0,*(block_size,parsertemp11277)) +::STMT +MATRIX:objvals +LITERAL_FLOAT:1.5000000000000002E-8 +*(1.5000000000000002E-8,cast.FLOAT(objvals)) +::STMT +FLOAT:std,float481,float926 +LITERAL_FLOAT:2.0 +INT:int300,int902 +^(*(cast.FLOAT(rand(int300,int902,float481,float926)),std),2.0) +::STMT +MATRIX:R,parsertemp40216 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,+(R,rowSums(parsertemp40216))),1.0) +::STMT +MATRIX:parsertemp147200,X_train +LITERAL_FLOAT:2.0 +*(parsertemp147200,sqrt(/(2.0,ncol(X_train)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2003.0 +-(^(2003.0,2.0),1.0) +::STMT +MATRIX:categorical,X_sys,freq,mask +LITERAL_FLOAT:0.0 ++(*(X_sys,==(mask,0.0)),*(>(categorical,0.0),freq)) +::STMT +MATRIX:id +diag(diag(==(id,t(id)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +*(-(2000.0,2.0),+(2000.0,1.0)) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:0.5,2358.0 +round(+(0.5,/(parsertemp77570,2358.0))) +::STMT +MATRIX:parsertemp379566,m_iter_err_sum,m_err +FLOAT:int404,i_process_item +LITERAL_FLOAT:2.0 +*(*(2.0,/(*(parsertemp379566,int404),i_process_item)),+(colSums(m_err),m_iter_err_sum)) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0005002501250626 +/(sqrt(*(1.0005002501250626,m2)),mu) +::STMT +MATRIX:r_CG,g_reg,z +LITERAL_FLOAT:0.5 +*(0.5,*(cast.FLOAT(z),+(cast.FLOAT(r_CG),cast.FLOAT(g_reg)))) +::STMT +LITERAL_FLOAT:0.231641888 +0.231641888 +::STMT +MATRIX:W +FLOAT:int553,m3,var,wt,int628 +LITERAL_FLOAT:2.0,3.0 +/(*(^(sum(W),2.0),m3),*(*(-(wt,int553),-(wt,int628)),^(sqrt(var),3.0))) +::STMT +MATRIX:p,r +FLOAT:norm_r2 +*(/(sum(*(r,r)),norm_r2),p) +::STMT +MATRIX:parsertemp116094,parsertemp116097 +LITERAL_FLOAT:0.0,32.0 +sum(|(<(t(parsertemp116094),32.0),==(t(parsertemp116097),0.0))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0,2.0 +-(/(1.0,link_power),2.0) +::STMT +MATRIX:A,scale_X +%*%(diag(scale_X),A) +::STMT +FLOAT:sv,rad,v2 +/(-(rad,sv),v2) +::STMT +MATRIX:B2,ytest,Xtest +t(-(ytest,%*%(Xtest,B2))) +::STMT +MATRIX:V +min(V) +::STMT +MATRIX:diff_nominal,diff,mask +FLOAT:num_std_median +LITERAL_FLOAT:0.0 ++(*(!=(diff_nominal,0.0),num_std_median),*(diff,==(mask,0.0))) +::STMT +MATRIX:s,parsertemp44016,d +*(%*%(t(-(s,parsertemp44016)),d),%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +LITERAL_FLOAT:0.5 +-(/(-(col,min_val),bin_width),0.5) +::STMT +LITERAL_FLOAT:0.7 +0.7 +::STMT +MATRIX:Y_counts,means,Y +%*%(Y_counts,/(colSums(-(Y,means)),sum(Y_counts))) +::STMT +FLOAT:p,P +LITERAL_FLOAT:1.0 ++(+(1.0,p),P) +::STMT +FLOAT:int494,parsertemp115813,sum_sq_y_test,n,ss_res +/(ss_res,-(sum_sq_y_test,*(n,^(parsertemp115813,int494)))) +::STMT +FLOAT:a,c +LITERAL_FLOAT:4.0 +*(*(4.0,a),c) +::STMT +LITERAL_FLOAT:0.95 +0.95 +::STMT +MATRIX:parsertemp409058,parsertemp409054,ctab +LITERAL_FLOAT:0.6 +*(parsertemp409058,>(/(parsertemp409054,rowSums(ctab)),0.6)) +::STMT +MATRIX:cov +LITERAL_FLOAT:1.0 +/(1.0,sqrt(cov)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:2.0 +^(m2,2.0) +::STMT +FLOAT:parsertemp459295 +LITERAL_FLOAT:1.0,128.0 ++(+(parsertemp459295,1.0),128.0) +::STMT +MATRIX:parsertemp472305,_funvar2708,Iright,_funvar2706,_funvar2707 +FLOAT:numI +-(-(cast.FLOAT(_funvar2706),*(/(parsertemp472305,numI),_funvar2707)),*(/(rowSums(Iright),numI),_funvar2708)) +::STMT +MATRIX:parsertemp170251,lt_pos_neg +FLOAT:int953 +LITERAL_FLOAT:2.0,0.5 +*(-(0.5,lt_pos_neg),exp(/(*(parsertemp170251,int953),2.0))) +::STMT +MATRIX:Xd,out +FLOAT:int515 +sum(*(*(Xd,>(out,int515)),Xd)) +::STMT +MATRIX:parsertemp500439,y +LITERAL_FLOAT:0.5 +*(0.5,sum(*(-(parsertemp500439,y),-(parsertemp500439,y)))) +::STMT +MATRIX:oldE +LITERAL_FLOAT:1.0 +/(sum(oldE),1.0) +::STMT +MATRIX:csgaps,csmask +*(csgaps,>(csgaps,csmask)) +::STMT +MATRIX:X_cluster_local,X_comp,X_sim +|(X_cluster_local,*(X_comp,X_sim)) +::STMT +MATRIX:2364_2360_Y_prime,W2,W3,2364_2359_Y,parsertemp389610 +FLOAT:int704 +LITERAL_FLOAT:1.0 +%*%(*(-(1.0,^(2364_2359_Y,int704)),%*%(*(2364_2360_Y_prime,parsertemp389610),W3)),W2) +::STMT +LITERAL_FLOAT:1.0E-8 +1.0E-8 +::STMT +MATRIX:Y,parsertemp2773,Xw +LITERAL_FLOAT:0.0,1.0 +>(-(1.0,*(Y,+(Xw,parsertemp2773))),0.0) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(W,H),1.0E-8) +::STMT +MATRIX:A,b +LITERAL_FLOAT:-1.0,2.0 +^(%*%(*(t(A),-1.0),b),2.0) +::STMT +MATRIX:C,C_old +LITERAL_FLOAT:2.0 +sum(^(-(C,C_old),2.0)) +::STMT +MATRIX:P,lambda,X,Y,B_new ++(%*%(t(X),-(P,Y)),*(lambda,B_new)) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0,1.0 +rowSums(*(<=(Xtest_dists,1.0),<(0.0,Xtest_dists))) +::STMT +LITERAL_FLOAT:16.0,15.0 +*(15.0,16.0) +::STMT +MATRIX:parsertemp414376,parsertemp414378 +LITERAL_FLOAT:0.0,1.0,199.0 +-(1.0,<=(/(-(parsertemp414376,parsertemp414378),199.0),0.0)) +::STMT +LITERAL_FLOAT:0.05473123640475826 +0.05473123640475826 +::STMT +FLOAT:parsertemp164939 +LITERAL_FLOAT:100.0 +*(100.0,parsertemp164939) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:-1.0 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),-1.0) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:0.8 +*(_sbcvar1716,0.8) +::STMT +MATRIX:A +rowSums(abs(A)) +::STMT +MATRIX:parsertemp30951,G,authorities,hubs +-(/(%*%(t(G),%*%(G,authorities)),max(%*%(parsertemp30951,hubs))),authorities) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int960,int292 +LITERAL_FLOAT:1.0,1500.0 +/(-(colSums(^(negSamples,int960)),*(1500.0,^(negSampleMeans,int292))),-(1500.0,1.0)) +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,X),-(X,X)),Y) +::STMT +LITERAL_FLOAT:1.0,10000.0,0.8 ++(*(10000.0,0.8),1.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int762,int537 +LITERAL_FLOAT:1.0E20 +==(+(*(>=(Hdiff,int537),betamax),*(<(Hdiff,int762),beta)),1.0E20) +::STMT +MATRIX:addedE +LITERAL_FLOAT:20.0 +/(sum(addedE),20.0) +::STMT +MATRIX:parsertemp570372 +LITERAL_FLOAT:-1.0,2.0 +*(/(-1.0,2.0),parsertemp570372) +::STMT +MATRIX:parsertemp43634 +FLOAT:int332 +LITERAL_FLOAT:0.0,2.0 +sum(^(+(0.0,*(int332,parsertemp43634)),2.0)) +::STMT +MATRIX:dotMissing,parsertemp553021,dotM2 +FLOAT:int159 +t(sqrt(-(+(dotM2,dotMissing),*(int159,parsertemp553021)))) +::STMT +MATRIX:parsertemp436043 +LITERAL_FLOAT:1.0 +INT:int684,n_col +%*%(parsertemp436043,rand(int684,n_col,1.0,1.0)) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +-(sqrt(parsertemp176418),*(3.0,+(%*%(features,beta_unscaled),intercept))) +::STMT +MATRIX:X,I +LITERAL_FLOAT:1.0 +-(/(nrow(X),t(colSums(I))),1.0) +::STMT +MATRIX:parsertemp506990 +LITERAL_FLOAT:0.7 +<(parsertemp506990,0.7) +::STMT +MATRIX:252_K +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(252_K)) +::STMT +MATRIX:addedE +LITERAL_FLOAT:40.0 +/(sum(addedE),40.0) +::STMT +LITERAL_FLOAT:8.674675786448736 +8.674675786448736 +::STMT +MATRIX:e,X,tS +FLOAT:l +%*%(t(e),==(%*%(X,tS),l)) +::STMT +MATRIX:_sbcvar332 +LITERAL_FLOAT:9999.0 +/(_sbcvar332,9999.0) +::STMT +MATRIX:TK +LITERAL_FLOAT:0.0 ++(TK,==(TK,0.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0,1.0 +-(exp(*(linear_terms,-1.0)),1.0) +::STMT +MATRIX:parsertemp31908,X +FLOAT:l +/(nrow(X),t(colSums(==(parsertemp31908,l)))) +::STMT +MATRIX:p,Z +cast.FLOAT(%*%(t(p),%*%(Z,p))) +::STMT +MATRIX:W +FLOAT:m2,int169 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(*(3.0,^(m2,int169)),^(sum(W),2.0)),-(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp43619 +LITERAL_FLOAT:1.0 +-(/(1.0,+(1.0,exp(parsertemp43619))),1.0) +::STMT +MATRIX:minD,parsertemp222602,parsertemp222599 +FLOAT:int967 +rowSums(<=(+(*(int967,parsertemp222599),t(parsertemp222602)),minD)) +::STMT +FLOAT:num_hidden1,m +LITERAL_FLOAT:6.0 +/(sqrt(6.0),sqrt(+(m,num_hidden1))) +::STMT +FLOAT:pad_size,Hin +LITERAL_FLOAT:1.0 +-(Hin,-(pad_size,1.0)) +::STMT +MATRIX:R,parsertemp500360,parsertemp500307,parsertemp500359 +FLOAT:int52 ++(%*%(rowSums(^(R,int52)),parsertemp500359),%*%(parsertemp500360,t(rowSums(parsertemp500307)))) +::STMT +MATRIX:RDMean,parsertemp265748 +LITERAL_FLOAT:2.0 +-(parsertemp265748,^(RDMean,2.0)) +::STMT +FLOAT:float503,float111 +LITERAL_FLOAT:1.0 +INT:int154,int585 +/(1.0,+(1.0,exp(rand(int585,int154,float503,float111)))) +::STMT +MATRIX:parsertemp460642 +LITERAL_FLOAT:0.05 +*(parsertemp460642,0.05) +::STMT +MATRIX:Y,missing_mask_Y +LITERAL_FLOAT:0.0,1.0 ++(*(missing_mask_Y,+(max(Y),1.0)),*(Y,==(missing_mask_Y,0.0))) +::STMT +LITERAL_FLOAT:1.0,1000.0 +-(1000.0,1.0) +::STMT +MATRIX:vW2,dW2 +FLOAT:193_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(193_beta2,vW2),*(-(1.0,193_beta2),^(dW2,2.0))) +::STMT +MATRIX:F +%*%(rowSums(F),colSums(F)) +::STMT +MATRIX:parsertemp146940,184_dtemp,mb3 +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mb3),*(-(1.0,beta1),colSums(-(184_dtemp,parsertemp146940)))) +::STMT +MATRIX:S,V +LITERAL_FLOAT:2.0 +^(sum(*(S,V)),2.0) +::STMT +MATRIX:tmp,X ++(%*%(t(X),X),diag(tmp)) +::STMT +MATRIX:P,gradients,Theta +FLOAT:alpha +*(alpha,%*%(t(gradients),%*%(P,Theta))) +::STMT +MATRIX:parsertemp389212 +LITERAL_FLOAT:1058.0 +/(parsertemp389212,1058.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +^(linear_terms,-(/(0.0,link_power),1.0)) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +LITERAL_FLOAT:2.0 ++(parsertemp22485,*(2.0,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:parsertemp10964,C +==(parsertemp10964,C) +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936,W3 +%*%(-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp))),t(W3)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +^(linear_terms,-(/(2.0,link_power),2.0)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,2.0 +^(linear_terms,-(/(0.0,link_power),2.0)) +::STMT +FLOAT:s_rows,h +LITERAL_FLOAT:2.0 +/(-(s_rows,h),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:314.0 ++(314.0,i) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(linear_terms))) +::STMT +LITERAL_FLOAT:1.0,100.0 +INT:int212,int982 +rand(int212,int982,1.0,100.0) +::STMT +MATRIX:parsertemp181045 +FLOAT:window_size,q,parsertemp181038 +LITERAL_FLOAT:1.0 +-(1.0,/(-(q,*(window_size,parsertemp181038)),*(window_size,cast.FLOAT(parsertemp181045)))) +::STMT +MATRIX:col_nonzeros,parsertemp383019,parsertemp383016,row_nonzeros +FLOAT:reg +*(reg,+(sum(*(parsertemp383016,row_nonzeros)),sum(*(parsertemp383019,col_nonzeros)))) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +sum(>=(abs(-(output1,dataset)),0.16)) +::STMT +LITERAL_FLOAT:1.0,2.0,7000.0 +*(^(7000.0,2.0),-(7000.0,1.0)) +::STMT +MATRIX:P,scale_X,shift_X,X,Y,Grad ++(%*%(diag(scale_X),%*%(t(X),-(P,Y))),%*%(shift_X,Grad)) +::STMT +MATRIX:g_new,s,g_old +*(/(sum(*(g_new,g_new)),sum(*(g_old,g_old))),s) +::STMT +MATRIX:centroid_placer,All_Centroids,X_samples ++(All_Centroids,%*%(centroid_placer,%*%(centroid_placer,X_samples))) +::STMT +MATRIX:C,tmp,XtZ +FLOAT:ZtZ_sum +trace(*(tmp,%*%(t(C),/(XtZ,ZtZ_sum)))) +::STMT +MATRIX:ytest +FLOAT:mean_y_test,int501,int192 +LITERAL_FLOAT:1.0 +/(-(sum(^(ytest,int501)),*($1:nrow(ytest),^(mean_y_test,int192))),-($1,1.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,44.75488800120049 +/(sqrt(*(1.0004995004995005,m2)),44.75488800120049) +::STMT +LITERAL_FLOAT:0.5107539184552492 +0.5107539184552492 +::STMT +FLOAT:Woutc20,Houtc20,F1 +LITERAL_FLOAT:1.0 ++(*(*(F1,Houtc20),Woutc20),1.0) +::STMT +LITERAL_FLOAT:1.0005 +1.0005 +::STMT +MATRIX:e_r_rev_agg,Xi_agg_rev_agg,X_agg +LITERAL_FLOAT:2.0 +/(*(X_agg,Xi_agg_rev_agg),^(e_r_rev_agg,2.0)) +::STMT +LITERAL_FLOAT:12.0,4.0 +*(12.0,4.0) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(sum(*(z,z)),trust_delta_sq) +::STMT +LITERAL_FLOAT:1.0E-12 +INT:int210,int691 +rand(int691,int210,1.0E-12,1.0E-12) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,cast.MATRIX(sum(X))) +::STMT +MATRIX:parsertemp443530,parsertemp443534,resp,parsertemp443533,X +FLOAT:float582 +LITERAL_FLOAT:2.22E-16 +%*%(*(t(/(parsertemp443533,parsertemp443534)),+(colSums(resp),2.22E-16)),/(%*%(t(resp),X),t(+(parsertemp443530,float582)))) +::STMT +FLOAT:i,j +LITERAL_FLOAT:1.0,10.0 ++(*(-(i,1.0),10.0),j) +::STMT +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +/(norm_r2_LS,*(cast.FLOAT(p_LS),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:q +LITERAL_FLOAT:1.0,10000.0 +*(10000.0,-(q,1.0)) +::STMT +LITERAL_FLOAT:12.0,8.0 +*(12.0,8.0) +::STMT +MATRIX:parsertemp472359,I +LITERAL_FLOAT:0.0 +*(I,==(*(t(parsertemp472359),I),0.0)) +::STMT +MATRIX:Y +sum(==(Y,min(Y))) +::STMT +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 +INT:int818,int723 +rand(int818,int723,+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag)),+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag))) +::STMT +MATRIX:means,parsertemp560530 +LITERAL_FLOAT:1.0 +/(sum(<(*(means,parsertemp560530),1.0)),*(nrow(means),ncol(means))) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:10000.0 +/(classCounts,10000.0) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +%*%(+(rowSums(classFeatureCounts),*(750.0,1.0)),ones) +::STMT +MATRIX:Y_prob +LITERAL_FLOAT:0.0,1.0 +*(Y_prob,-(1.0,<=(Y_prob,0.0))) +::STMT +LITERAL_FLOAT:12.0 +*(12.0,12.0) +::STMT +MATRIX:P,R,I,L +LITERAL_FLOAT:0.0 +*(==(%*%(P,I),0.0),%*%(%*%(P,L),R)) +::STMT +MATRIX:E +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(E,2.0))) +::STMT +LITERAL_FLOAT:12.0,40.0 +*(12.0,40.0) +::STMT +MATRIX:P,X,Y +LITERAL_FLOAT:2.0 +^(%*%(t(X),-(P,Y)),2.0) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(^(linear_terms,/(2.0,link_power)),2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +/(linear_terms,-(1.0,var_power)) +::STMT +MATRIX:Y_prob,Y +-(*(Y,Y_prob),*(Y,Y_prob)) +::STMT +MATRIX:P +LITERAL_FLOAT:1.0,100.0 +INT:int801,int859 +%*%(P,rand(int859,int801,1.0,100.0)) +::STMT +FLOAT:502_strideh,502_padh,int986,502_Hin,502_Hf +LITERAL_FLOAT:2.0 ++(-(*(502_strideh,-(502_Hin,int986)),*(2.0,502_padh)),502_Hf) +::STMT +MATRIX:parsertemp195899 +FLOAT:center +LITERAL_FLOAT:1.0 +t(-(1.0,abs(-(parsertemp195899,center)))) +::STMT +MATRIX:parsertemp539203 +FLOAT:int999 +LITERAL_FLOAT:2.0,0.6666666666666666 +min(^(/(*(parsertemp539203,int999),2.0),0.6666666666666666)) +::STMT +MATRIX:parsertemp32833,parsertemp32842,X,Y,parsertemp32827,parsertemp32824,K,parsertemp32839 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(*(K,parsertemp32833),-(Y,Y)),-(1.0,/(parsertemp32824,parsertemp32827))),*(+(*(parsertemp32839,parsertemp32842),-(Y,Y)),/(-(x,X),-(X,X)))) +::STMT +MATRIX:X,Y,out +%*%(t(X),*(out,Y)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,33.0 ++(*(-(i,1.0),33.0),1.0) +::STMT +MATRIX:lambda,parsertemp149248,V,X,P_1K,parsertemp149251 ++(%*%(t(X),-(*(P_1K,parsertemp149248),*(P_1K,parsertemp149251))),*(lambda,V)) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:0.5 +/(0.5,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:X,Y,K +-(*(cast.FLOAT(K),-(cast.FLOAT(X),cast.FLOAT(X))),-(cast.FLOAT(Y),cast.FLOAT(Y))) +::STMT +LITERAL_FLOAT:110.0,3000.0 +*(3000.0,110.0) +::STMT +MATRIX:s +FLOAT:int741,alpha,n +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(*(/(int741,s),n),1.0)) +::STMT +LITERAL_FLOAT:3.0,5.0,2000.0 +*(+(2000.0,5.0),-(2000.0,3.0)) +::STMT +MATRIX:the_exp +FLOAT:int91,int490 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int91,the_exp),1.0E7)),-(1.0,exp(-(int490,the_exp)))) +::STMT +FLOAT:parsertemp557354,parsertemp557356,prob_true +/(*(prob_true,parsertemp557354),parsertemp557356) +::STMT +MATRIX:parsertemp42288,_sbcvar332,parsertemp42289 +FLOAT:meanX +LITERAL_FLOAT:9999.0,0.5 +*(/(_sbcvar332,9999.0),-(+(-(parsertemp42288,parsertemp42289),0.5),meanX)) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0) +::STMT +MATRIX:parsertemp436682 +FLOAT:d +t(*(d,parsertemp436682)) +::STMT +MATRIX:parsertemp31023,parsertemp31025 +LITERAL_FLOAT:2.0,99.0,990000.0 +/(^(/(-(parsertemp31023,parsertemp31025),99.0),2.0),990000.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,32.0 ++(*(-(i,1.0),32.0),1.0) +::STMT +FLOAT:alpha_LS,r_LS,norm_r2_LS,p_LS,int933 +LITERAL_FLOAT:0.0 ++(-(0.0,+(r_LS,*(alpha_LS,p_LS))),*(/(^(r_LS,int933),norm_r2_LS),cast.FLOAT(p_LS))) +::STMT +MATRIX:resp,mean,X +*(mean,%*%(t(resp),X)) +::STMT +MATRIX:mW2,dW2 +FLOAT:193_beta1 +LITERAL_FLOAT:1.0 ++(*(193_beta1,mW2),*(-(1.0,193_beta1),dW2)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,12.0 ++(-(12.0,idx),1.0) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:30.0 ++(30.0,nrow(_sbcvar1716)) +::STMT +FLOAT:sig,q,mu,int505 +LITERAL_FLOAT:1.0,4.0 +-(1.0,/(-(q,*(int505,mu)),*(4.0,*(sig,sig)))) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int950,int417 +LITERAL_FLOAT:6999.0,7000.0 +/(-(colSums(^(posSamples,int950)),*(7000.0,^(posSampleMeans,int417))),6999.0) +::STMT +MATRIX:dout,X +LITERAL_FLOAT:0.0 +*(>(X,0.0),dout) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept ++(%*%(features,beta_unscaled),intercept) +::STMT +MATRIX:X_batch,mW1,parsertemp146957,187_dX +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mW1),*(-(1.0,beta1),%*%(t(X_batch),*(parsertemp146957,187_dX)))) +::STMT +FLOAT:parsertemp40813,m2,mu +LITERAL_FLOAT:5.0 +-(mu,*(5.0,sqrt(*(parsertemp40813,m2)))) +::STMT +MATRIX:Y,linear_terms +-(Y,exp(linear_terms)) +::STMT +LITERAL_FLOAT:61.0,4.0 +/(61.0,4.0) +::STMT +MATRIX:qLow,length +<(length,qLow) +::STMT +MATRIX:inactive_set,w +FLOAT:int224 +sum(abs(-(inactive_set,!=(w,int224)))) +::STMT +MATRIX:W1_rand,stds,parsertemp393478 +LITERAL_FLOAT:0.07261134713572442 +t(%*%(*(0.07261134713572442,W1_rand),t(/(parsertemp393478,stds)))) +::STMT +LITERAL_FLOAT:1.0004995004995005 +1.0004995004995005 +::STMT +LITERAL_FLOAT:12.0,2.0 +*(12.0,2.0) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +cast.MATRIX(*(cast.FLOAT(parsertemp496901),std)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2003.0 +*(/(2003.0,-(2003.0,1.0)),m2) +::STMT +MATRIX:Y,parsertemp2796,Xw +LITERAL_FLOAT:0.0,1.0 +*(>(-(1.0,*(Y,Xw)),0.0),-(1.0,*(Y,+(Xw,parsertemp2796)))) +::STMT +LITERAL_FLOAT:3.4011973816621555 +3.4011973816621555 +::STMT +MATRIX:parsertemp396420,W4_rand,parsertemp396423 +LITERAL_FLOAT:0.08681986202598489 +t(%*%(*(0.08681986202598489,W4_rand),t(/(parsertemp396420,parsertemp396423)))) +::STMT +LITERAL_FLOAT:Infinity +INT:int207,parsertemp163324 +rand(parsertemp163324,int207,Infinity,Infinity) +::STMT +LITERAL_FLOAT:1.0 +INT:int223,int713 +rand(int223,int713,1.0,1.0) +::STMT +LITERAL_FLOAT:-1.0 +INT:int121,n +rand(n,int121,-1.0,-1.0) +::STMT +LITERAL_FLOAT:-1.0,1.0 +INT:num_hidden1,m +rand(num_hidden1,m,-1.0,1.0) +::STMT +MATRIX:parsertemp16858 +LITERAL_FLOAT:1.0E-6 +*(<(sqrt(rowSums(parsertemp16858)),1.0E-6),1.0E-6) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0,32.0 ++(*(-(i,1.0),32.0),3.0) +::STMT +MATRIX:parsertemp129018 +LITERAL_FLOAT:2.0 +*(max(parsertemp129018),2.0) +::STMT +LITERAL_FLOAT:2.0,64.0 +/(64.0,2.0) +::STMT +MATRIX:p,parsertemp477949,parsertemp477948 +FLOAT:norm_r2 +/(norm_r2,sum(*(p,%*%(parsertemp477948,parsertemp477949)))) +::STMT +MATRIX:residual_matrix +FLOAT:273_lambda +LITERAL_FLOAT:2.0 +/(^(sum(residual_matrix),2.0),+(nrow(residual_matrix),273_lambda)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,-1.0 ++(1.0,exp(*(X,-1.0))) +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:1.0 +/(*(/(1.0,nrow(target)),-(prediction,target)),*(prediction,-(1.0,prediction))) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +LITERAL_FLOAT:2.0 +^(+(wnew,*(2.0,%*%(parsertemp44107,parsertemp44109))),2.0) +::STMT +LITERAL_FLOAT:1.0,2.0 +INT:int199,parsertemp282730 +rand(parsertemp282730,int199,1.0,2.0) +::STMT +MATRIX:R,parsertemp40216,parsertemp40215,parsertemp40225 +FLOAT:level +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),+(R,rowSums(==(parsertemp40215,level)))) +::STMT +MATRIX:r,d +FLOAT:r2 +*(/(cast.FLOAT(%*%(r,r)),r2),d) +::STMT +MATRIX:parsertemp130418 +LITERAL_FLOAT:4.0 +*(max(parsertemp130418),4.0) +::STMT +MATRIX:lambda,scale_X,gXY,beta +FLOAT:int164 +t(+(*(scale_X,-(int164,gXY)),*(lambda,beta))) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg,130_alpha +LITERAL_FLOAT:1.0 +*(130_alpha,-(/(/(se,ss),130_eAvg),1.0)) +::STMT +MATRIX:D,parsertemp570375,classMeans +%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans))) +::STMT +FLOAT:nc +LITERAL_FLOAT:1.0,10.0 +*(+(10.0,1.0),-(nc,1.0)) +::STMT +LITERAL_FLOAT:3.0,5.0,2003.0 +*(+(2003.0,5.0),-(2003.0,3.0)) +::STMT +FLOAT:FN,FP,TN,TP +*(*(*(+(TP,FP),+(TP,FN)),+(TN,FP)),+(TN,FN)) +::STMT +LITERAL_FLOAT:64.0,8.0 +/(64.0,8.0) +::STMT +MATRIX:parsertemp170238 +FLOAT:float74 +LITERAL_FLOAT:1.0,1.061405429 +*(/(1.0,+(1.0,*(parsertemp170238,float74))),1.061405429) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(%*%(t(W),W),H),1.0E-8) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 ++(rowSums(classFeatureCounts),*(750.0,1.0)) +::STMT +MATRIX:X,outlierFilter +LITERAL_FLOAT:0.0 +*(==(outlierFilter,0.0),X) +::STMT +MATRIX:Y,linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(Y,^(linear_terms,/(1.0,link_power))) +::STMT +LITERAL_FLOAT:4.0,64.0 +/(64.0,4.0) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0005 +sqrt(*(m2X,1.0005)) +::STMT +MATRIX:parsertemp460644 +LITERAL_FLOAT:0.0625,1.4142135623730951 +/(*(parsertemp460644,0.0625),1.4142135623730951) +::STMT +MATRIX:_sbcvar415,X2 +LITERAL_FLOAT:0.050000000000000044,1.0 +*(0.050000000000000044,-(/(nrow(X2),_sbcvar415),1.0)) +::STMT +MATRIX:lambda,scale_X,p_CG,w,X,parsertemp285715 ++(*(lambda,p_CG),%*%(diag(scale_X),%*%(t(X),*(w,parsertemp285715)))) +::STMT +MATRIX:X +FLOAT:2917_N,2917_split +LITERAL_FLOAT:1.0 ++(-(nrow(X),round(*(2917_N,2917_split))),1.0) +::STMT +MATRIX:C,X +FLOAT:int301 +LITERAL_FLOAT:-2.0 ++(*(-2.0,%*%(X,t(C))),t(rowSums(^(C,int301)))) +::STMT +MATRIX:Y_counts,Y,avg_tot_Y +LITERAL_FLOAT:2.0 +colSums(^(-(Y,%*%(Y_counts,avg_tot_Y)),2.0)) +::STMT +MATRIX:parsertemp555766,target +LITERAL_FLOAT:1.0 +*(-(1.0,target),parsertemp555766) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +LITERAL_FLOAT:2.0 +%*%(samples_vs_runs_map,rowSums(^(%*%(centroid_placer,X_samples),2.0))) +::STMT +MATRIX:parsertemp285718,p_CG,shift_X,parsertemp285720,temp_CG +sum(*(p_CG,+(+(parsertemp285718,parsertemp285720),%*%(shift_X,temp_CG)))) +::STMT +LITERAL_FLOAT:3.0,5.0,2001.0 +*(+(2001.0,5.0),-(2001.0,3.0)) +::STMT +MATRIX:parsertemp386457,parsertemp386448,parsertemp386451,parsertemp386453,withinEps +FLOAT:int257,int227 +LITERAL_FLOAT:0.0 +*(*(>(*(parsertemp386448,withinEps),0.0),&(==(parsertemp386451,int257),>(parsertemp386453,int227))),parsertemp386457) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,6.0 +*(*(6.0,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:var_X_cols,parsertemp1522 +FLOAT:int590 +LITERAL_FLOAT:1.0 +/(1.0,sqrt(+(*(var_X_cols,parsertemp1522),<=(var_X_cols,int590)))) +::STMT +LITERAL_FLOAT:1.0,2003.0 +/(2003.0,-(2003.0,1.0)) +::STMT +MATRIX:mu +cast.FLOAT(*(mu,mu)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int131,int672 +LITERAL_FLOAT:1.0,2000.0 +/(-(colSums(^(posSamples,int672)),*(2000.0,^(posSampleMeans,int131))),-(2000.0,1.0)) +::STMT +MATRIX:parsertemp43993,d,Hd,parsertemp44001 +*(cast.FLOAT(/(sum(parsertemp43993),%*%(parsertemp44001,Hd))),d) +::STMT +MATRIX:parsertemp399256,W4_rand,parsertemp399259 +LITERAL_FLOAT:0.08725945907447251 +t(%*%(*(0.08725945907447251,W4_rand),t(/(parsertemp399256,parsertemp399259)))) +::STMT +MATRIX:d,X,logisticD +*(logisticD,%*%(X,d)) +::STMT +MATRIX:P,I,X2 +LITERAL_FLOAT:0.0 +!=(*(t(%*%(X2,P)),I),0.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116 +-(parsertemp171113,*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr))) +::STMT +MATRIX:b,X +*(X,exp(%*%(X,b))) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08725945907447251 +*(0.08725945907447251,W4_rand) +::STMT +FLOAT:i,n +LITERAL_FLOAT:-1.0,3.0 +*(n,^(3.0,*(i,-1.0))) +::STMT +MATRIX:2700_X,2700_W,2726_dpred,parsertemp459177,2699_probs +LITERAL_FLOAT:5.0E-4 ++(%*%(t(2700_X),-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177))),*(5.0E-4,2700_W)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int840,int752,int382,int905 ++(%*%(rand(int382,int905,0.0,1.0),rand(int840,int752,0.0,1.0)),0.0) +::STMT +MATRIX:ts +LITERAL_FLOAT:4.0 +-(length(ts),4.0) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),1.0),-(Y,exp(linear_terms))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-1.0 +*(^(exp(linear_terms),-1.0),-(Y,exp(linear_terms))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.5107539184552492 +*(0.5107539184552492,W2_rand) +::STMT +MATRIX:r +LITERAL_FLOAT:0.0,9.999999999999998E-15 +*(-(0.0,cast.FLOAT(%*%(r,r))),9.999999999999998E-15) +::STMT +FLOAT:p,i +LITERAL_FLOAT:1.0 +-(+(p,1.0),i) +::STMT +LITERAL_FLOAT:1.0,6.0,2000.0 +*(*(6.0,2000.0),-(2000.0,1.0)) +::STMT +MATRIX:s,g_old +FLOAT:step_sz +*(step_sz,cast.FLOAT(%*%(t(s),g_old))) +::STMT +MATRIX:lambda,parsertemp171604,beta,parsertemp171603 +LITERAL_FLOAT:2.0 +sum(^(+(+(parsertemp171603,parsertemp171604),*(lambda,beta)),2.0)) +::STMT +FLOAT:parsertemp40812,m2,int666 +LITERAL_FLOAT:5.0 +*(5.0,sqrt(*(/(int666,parsertemp40812),m2))) +::STMT +MATRIX:output,outputR,leading_NA ++(*(outputR,leading_NA),output) +::STMT +MATRIX:scale_X,parsertemp274081 +FLOAT:N +LITERAL_FLOAT:0.0 +*(-(0.0,/(t(parsertemp274081),N)),scale_X) +::STMT +MATRIX:parsertemp389187,parsertemp389190 +FLOAT:int284,int38 +LITERAL_FLOAT:1.0,2.0 +-(1.0,^(/(-(parsertemp389187,int284),+(parsertemp389190,int38)),2.0)) +::STMT +MATRIX:p,q,parsertemp1939 +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),p) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0,2.0 +^(*(t(colSums(X)),-1.0),2.0) +::STMT +MATRIX:key_unique,key +t(==(key_unique,t(key))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,42.0 ++(*(-(i,1.0),42.0),1.0) +::STMT +MATRIX:P ++(P,t(P)) +::STMT +MATRIX:ss +FLOAT:130_n +/(130_n,ss) +::STMT +MATRIX:Xm,Z,parsertemp265713 +cast.FLOAT(%*%(colSums(%*%(Z,parsertemp265713)),rowSums(t(Xm)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),1.0),-(sum(round(W)),2.0)) +::STMT +MATRIX:out3,parsertemp146931,parsertemp146929,184_unnorm_probs,parsertemp146936,184_scores,parsertemp146933 +*(/(exp(-(out3,parsertemp146933)),rowSums(exp(184_scores))),rowSums(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)))) +::STMT +MATRIX:p_LS,parsertemp170552 +FLOAT:lambda_LS +sum(*(p_LS,+(%*%(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:ss3,ss2,int486,ssPrev,Fn,m,n +/(/(-(+(Fn,ss2),*(int486,ss3)),*(n,m)),ssPrev) +::STMT +FLOAT:a,b,c +LITERAL_FLOAT:2.0,4.0 +-(^(b,2.0),*(*(4.0,a),c)) +::STMT +MATRIX:parsertemp16858,parsertemp16867,parsertemp16865,77_X_row_norm +FLOAT:float257,float144 +LITERAL_FLOAT:1.0E-6 +%*%(+(sqrt(rowSums(parsertemp16858)),*(<(77_X_row_norm,float144),1.0E-6)),t(+(sqrt(parsertemp16865),*(parsertemp16867,float257)))) +::STMT +MATRIX:WM +sum(WM) +::STMT +MATRIX:X +FLOAT:parsertemp78,parsertemp80 +/(-(X,parsertemp78),sqrt(parsertemp80)) +::STMT +MATRIX:Train,2342_m_colmin +LITERAL_FLOAT:2.0 +*(2.0,-(Train,2342_m_colmin)) +::STMT +MATRIX:E,O +*(sum(-(O,E)),sum(-(O,E))) +::STMT +MATRIX:D,parsertemp10958 +%*%(D,t(parsertemp10958)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:96.0 +*(96.0,run_index) +::STMT +FLOAT:padh,int343,parsertemp195863,strideh,out_padh,Hf ++(+(-(*(strideh,parsertemp195863),*(int343,padh)),Hf),out_padh) +::STMT +MATRIX:P,Z,ZERODIAG,parsertemp220891 +FLOAT:int1,parsertemp220894 +rowSums(*(-(P,/(Z,parsertemp220894)),*(/(int1,parsertemp220891),ZERODIAG))) +::STMT +MATRIX:parsertemp386457,parsertemp386459,parsertemp386449,parsertemp386452,parsertemp386454 +FLOAT:int981 +-(*(*(>(parsertemp386449,int981),&(parsertemp386452,parsertemp386454)),parsertemp386457),parsertemp386459) +::STMT +MATRIX:p_CG,z +*(cast.FLOAT(%*%(t(p_CG),z)),cast.FLOAT(%*%(t(p_CG),z))) +::STMT +MATRIX:Q1,X,IQR +FLOAT:k +<(X,-(Q1,*(k,IQR))) +::STMT +MATRIX:Q3,X,IQR +FLOAT:k +>(X,+(Q3,*(k,IQR))) +::STMT +MATRIX:ubScores,fSizes,parsertemp31451 +FLOAT:int463,minsc,level,int864 +&(&(fSizes,&(>(ubScores,minsc),>(ubScores,int463))),==(rowSums(!=(parsertemp31451,int864)),level)) +::STMT +LITERAL_FLOAT:53.0,8.0 +/(53.0,8.0) +::STMT +MATRIX:pearson_residual_sq +LITERAL_FLOAT:900.0 +/(sum(pearson_residual_sq),900.0) +::STMT +MATRIX:W +FLOAT:int267,wt,int283 +LITERAL_FLOAT:1.0,3.0,6.0 +/(*(*(6.0,sum(W)),-(sum(W),1.0)),*(*(-(wt,int283),+(wt,int267)),+(sum(W),3.0))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +^(linear_terms,/(-(2.0,var_power),link_power)) +::STMT +FLOAT:m2X,W,float189 +sqrt(*(m2X,/(W,-(W,float189)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +/(exp(*(linear_terms,2.0)),2.0) +::STMT +LITERAL_FLOAT:7.996E9 +7.996E9 +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +INT:int259,int839 +%*%(+(rowSums(classFeatureCounts),*(500.0,1.0)),rand(int839,int259,1.0,1.0)) +::STMT +FLOAT:522_strideh,parsertemp193444,522_Hin +LITERAL_FLOAT:1.0 ++(/(-(+(522_Hin,parsertemp193444),1.0),522_strideh),1.0) +::STMT +MATRIX:R,dssp,parsertemp40220 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,-(+(R,dssp),rowSums(parsertemp40220))),1.0) +::STMT +MATRIX:parsertemp171377,Y_prob,Y,parsertemp171381 +FLOAT:float771 +LITERAL_FLOAT:2.0 +/(^(rowSums(Y),2.0),*(*(*(parsertemp171377,Y_prob),Y_prob),^(*(parsertemp171381,float771),2.0))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 ++(1.0E7,exp(finite_linear_terms)) +::STMT +MATRIX:pt_gp,Y,linear_terms,the_gauss_exp +FLOAT:int79,int185 +LITERAL_FLOAT:0.5 ++(-(Y,*(rowSums(Y),>=(linear_terms,int185))),*(*(*(the_gauss_exp,pt_gp),rowSums(Y)),-(>=(linear_terms,int79),0.5))) +::STMT +MATRIX:parsertemp1516,parsertemp1514 +FLOAT:n +LITERAL_FLOAT:0.0,1.0 +<=(/(-(t(parsertemp1514),*(n,parsertemp1516)),-(n,1.0)),0.0) +::STMT +MATRIX:err,ncCnts,maxsc,cCnts +FLOAT:int684,int597,float897,minSup +sum(&(&(>=(cCnts,minSup),>(err,int684)),|(>(ncCnts,int597),>(maxsc,float897)))) +::STMT +FLOAT:i1 +LITERAL_FLOAT:1.0,2.0 ++(1.0,*(i1,2.0)) +::STMT +LITERAL_FLOAT:-1.453152027 +-1.453152027 +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 ++(s,2.0) +::STMT +FLOAT:i,cols,n +LITERAL_FLOAT:1.0 ++(-(n,-(+(i,cols),1.0)),1.0) +::STMT +MATRIX:means,parsertemp560511,parsertemp560515 +FLOAT:int468 +LITERAL_FLOAT:2.0 +-(rowSums(*(means,^(parsertemp560515,int468))),^(rowSums(*(means,parsertemp560511)),2.0)) +::STMT +MATRIX:X +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(nrow(X),-(nrow(X),1.0))) +::STMT +MATRIX:parsertemp222331 +FLOAT:sample_block_size +LITERAL_FLOAT:0.5 ++(0.5,/(parsertemp222331,sample_block_size)) +::STMT +MATRIX:parsertemp387405,Ks,Kss +abs(-(cast.FLOAT(Kss),cast.FLOAT(%*%(parsertemp387405,Ks)))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 ++(ncol(X),1.0) +::STMT +MATRIX:imputed_Y +LITERAL_FLOAT:NaN ++(imputed_Y,NaN) +::STMT +MATRIX:X_batch,parsertemp389604,parsertemp389600,parsertemp389601 +FLOAT:int708,int998 +LITERAL_FLOAT:1.0,2.0 +*(-(/(-(parsertemp389600,int708),+(parsertemp389600,int998)),X_batch),-(1.0,^(/(parsertemp389601,parsertemp389604),2.0))) +::STMT +MATRIX:parsertemp146961,dout1,mW1 +FLOAT:191_t,191_lr,191_beta1,parsertemp146980,int721 +LITERAL_FLOAT:1.0 +*(/(*(191_lr,sqrt(parsertemp146980)),-(1.0,^(191_beta1,191_t))),+(*(191_beta1,mW1),*(-(int721,191_beta1),%*%(parsertemp146961,dout1)))) +::STMT +MATRIX:q_CG,z +FLOAT:parsertemp170094,pp_CG,pq_CG +LITERAL_FLOAT:0.5 ++(*(*(0.5,/(parsertemp170094,pp_CG)),pq_CG),*(cast.FLOAT(z),cast.FLOAT(q_CG))) +::STMT +MATRIX:Y +FLOAT:minv +sum(==(Y,minv)) +::STMT +FLOAT:i +LITERAL_FLOAT:100.0 +*(i,100.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 ++(ncol(X),0.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int175,int467 +LITERAL_FLOAT:1.0E20 +!=(+(*(>=(Hdiff,int467),betamax),*(<(Hdiff,int175),beta)),1.0E20) +::STMT +MATRIX:B +FLOAT:ncolX +-(ncolX,nrow(B)) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +LITERAL_FLOAT:-1.0 +/(*(*(z_alpha_2,-1.0),se_surv),surv) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +<(X,1.0) +::STMT +MATRIX:parsertemp170239 +FLOAT:float481 +LITERAL_FLOAT:1.0,1.061405429,-1.453152027 ++(-1.453152027,*(/(1.0,+(float481,parsertemp170239)),1.061405429)) +::STMT +MATRIX:R,parsertemp503780 +%*%(t(+(R,diag(parsertemp503780))),+(R,diag(parsertemp503780))) +::STMT +FLOAT:var_power +LITERAL_FLOAT:2.0 +-(2.0,var_power) +::STMT +FLOAT:featureCorrection +LITERAL_FLOAT:0.0 +-(0.0,featureCorrection) +::STMT +MATRIX:parsertemp500606,parsertemp500607,parsertemp500604,w,parsertemp500610 +FLOAT:int952 +%*%(t(-(*(parsertemp500607,parsertemp500610),w)),-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500606,int952)),w)) +::STMT +MATRIX:parsertemp472316,parsertemp472314,ig +FLOAT:min_leaf +rev(*(&(>=(parsertemp472314,min_leaf),>=(parsertemp472316,min_leaf)),ig)) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:parsertemp31034,parsertemp31027 +LITERAL_FLOAT:150.0,100.0 +sqrt(+(/(/(parsertemp31026,parsertemp31027),100.0),/(/(parsertemp31033,parsertemp31034),150.0))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626 +sqrt(*(1.0005002501250626,m2)) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +FLOAT:int849 +LITERAL_FLOAT:0.0,1.0 +/(*(>(Y,0.0),is_natural_parameter_log_zero),-(1.0,*(>(Y,int849),is_natural_parameter_log_zero))) +::STMT +MATRIX:P,parsertemp222624,X +/(%*%(t(/(P,parsertemp222624)),X),t(colSums(/(P,parsertemp222624)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,5.0 +*(5.0,sqrt(*(1.0004995004995005,m2))) +::STMT +MATRIX:Xd,out +FLOAT:int853 +sum(*(*(Xd,>(out,int853)),Xd)) +::STMT +MATRIX:id +diag(==(id,t(id))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq) +::STMT +MATRIX:X,Y,out,parsertemp2798 +FLOAT:int662,int861 +%*%(t(X),*(*(>(out,int861),-(int662,parsertemp2798)),Y)) +::STMT +MATRIX:d,exp_Xb,X +*(X,*(%*%(X,d),exp_Xb)) +::STMT +MATRIX:output_values +FLOAT:log_odds +LITERAL_FLOAT:0.3 ++(log_odds,*(0.3,cast.FLOAT(output_values))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0) +::STMT +MATRIX:parsertemp403509,W4_rand +FLOAT:int45,int391 +LITERAL_FLOAT:0.086386842558136 +%*%(*(0.086386842558136,W4_rand),t(/(-(parsertemp403509,int391),+(parsertemp403509,int45)))) +::STMT +MATRIX:X,parsertemp32827,Y,parsertemp32824 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(parsertemp32824,parsertemp32827)),Y),*(/(-(x,X),-(X,X)),Y)) +::STMT +MATRIX:W,X ++(%*%(X,W),W) +::STMT +MATRIX:lambda,parsertemp170067,parsertemp170065,p_CG,shift_X,parsertemp170060,temp_CG ++(+(*(cast.FLOAT(lambda),cast.FLOAT(p_CG)),*(cast.FLOAT(parsertemp170060),cast.FLOAT(temp_CG))),*(cast.FLOAT(shift_X),cast.FLOAT(%*%(parsertemp170065,parsertemp170067)))) +::STMT +MATRIX:parsertemp115858,X,parsertemp115860 +FLOAT:n +LITERAL_FLOAT:0.0,1.0 +<=(/(-(t(parsertemp115858),*(n,parsertemp115860)),-(nrow(X),1.0)),0.0) +::STMT +MATRIX:I,y2 +/(%*%(I,y2),sum(I)) +::STMT +MATRIX:termination_bitmap,parsertemp441285,tmp +==(*(parsertemp441285,termination_bitmap),min(tmp)) +::STMT +MATRIX:the_exp,linear_terms,Y +FLOAT:int894 +*(*(exp(-(int894,the_exp)),exp(linear_terms)),rowSums(Y)) +::STMT +MATRIX:_sbcvar1156 +FLOAT:num_records +LITERAL_FLOAT:1.0 +*(+(num_records,1.0),-(1.0,_sbcvar1156)) +::STMT +MATRIX:parsertemp383010,U,X,X_nonzero_ind +LITERAL_FLOAT:2.0 +*(X_nonzero_ind,^(-(X,%*%(U,parsertemp383010)),2.0)) +::STMT +MATRIX:G,authorities +max(%*%(t(G),%*%(G,authorities))) +::STMT +FLOAT:i +LITERAL_FLOAT:42.0 ++(42.0,i) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1000.0 +*(parsertemp13703,1000.0) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:-1.0 +*(*(D,-1.0),beta) +::STMT +LITERAL_FLOAT:1.0E-15 +1.0E-15 +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +/(/(1.0,linear_terms),-(1.0,var_power)) +::STMT +FLOAT:parsertemp380175,interval,i_process_item +LITERAL_FLOAT:1.0 ++(-(i_process_item,+(*(parsertemp380175,interval),1.0)),1.0) +::STMT +MATRIX:X2 +FLOAT:parsertemp31772 +-(ncol(X2),parsertemp31772) +::STMT +MATRIX:parsertemp132035,left,parsertemp132041,right +==(%*%(parsertemp132035,left),%*%(parsertemp132041,right)) +::STMT +FLOAT:int252,a,b,c,x ++(+(*(a,^(x,int252)),*(b,x)),c) +::STMT +MATRIX:parsertemp40482,totalE,l +/(t(%*%(t(totalE),==(parsertemp40482,l))),t(colSums(==(parsertemp40482,l)))) +::STMT +MATRIX:X_Train,X_Test +FLOAT:float605,float128,float454,float355 +INT:int571,int543,int998,int370 +-(+(sum(rand(int571,int370,float454,float128)),sum(rand(int998,int543,float605,float355))),+(sum(X_Train),sum(X_Test))) +::STMT +FLOAT:s_err_vars,s_err_mean +LITERAL_FLOAT:-0.001 +/(-(-0.001,s_err_mean),s_err_vars) +::STMT +FLOAT:qmle_val,_funvar2930 +LITERAL_FLOAT:1.0E-5 +/(-(_funvar2930,qmle_val),1.0E-5) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:2.0,3.0 +*(*(3.0,^(m2,2.0)),^(sum(round(W)),2.0)) +::STMT +MATRIX:parsertemp31338,_sbcvar264 +FLOAT:parsertemp31331,float537 +LITERAL_FLOAT:9999.0,1.0 +-(1.0,/(sum(*(parsertemp31338,_sbcvar264)),*(9999.0,/(parsertemp31331,float537)))) +::STMT +MATRIX:s,parsertemp44016 +FLOAT:delta2 +-(delta2,cast.FLOAT(%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:6.0,2000.0 +*(6.0,2000.0) +::STMT +MATRIX:parsertemp467657,Xd,parsertemp467661 +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(*(parsertemp467657,Xd))),+(dd,sum(*(parsertemp467661,Xd)))) +::STMT +MATRIX:Y_counts,parsertemp560606,Y +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(-(Y,parsertemp560606),2.0)),-(sum(Y_counts),1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0 ++(sum(round(W)),3.0) +::STMT +MATRIX:K1 +cast.FLOAT(K1) +::STMT +MATRIX:proposer_pointers +LITERAL_FLOAT:1.0 ++(cast.FLOAT(proposer_pointers),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 +==(+(1.0E7,exp(finite_linear_terms)),1.0E7) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0 ++(sum(round(W)),1.0) +::STMT +MATRIX:parsertemp31277 +FLOAT:parsertemp31279,varY +LITERAL_FLOAT:1.0 +sqrt(-(1.0,/(sum(parsertemp31277),*(parsertemp31279,varY)))) +::STMT +MATRIX:2792_NID +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,2792_NID),1.0) +::STMT +MATRIX:p,parsertemp116065,lambda,shift_X +sum(*(p,+(+(parsertemp116065,shift_X),*(lambda,p)))) +::STMT +FLOAT:191_beta2,191_t,int124 +LITERAL_FLOAT:1.0 +sqrt(-(1.0,^(191_beta2,+(191_t,int124)))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0,479.0 +/(^(diag(S),2.0),479.0) +::STMT +FLOAT:parsertemp164939,n +LITERAL_FLOAT:2.0 ++(2.0,*(n,parsertemp164939)) +::STMT +MATRIX:leaf_ids,out +FLOAT:boundary_right,boundary_left,step_size +-(+(out,&(>=(leaf_ids,boundary_left),<(leaf_ids,boundary_right))),&(!(<(leaf_ids,boundary_right)),<(leaf_ids,+(boundary_right,step_size)))) +::STMT +FLOAT:int313,int889 +LITERAL_FLOAT:0.0 +INT:int69,int17 +*(rand(int69,int17,int889,int313),0.0) +::STMT +MATRIX:X +FLOAT:x +cast.FLOAT(-(x,X)) +::STMT +MATRIX:w,yt,Xt +LITERAL_FLOAT:0.0 +sum(>(*(yt,%*%(Xt,w)),0.0)) +::STMT +MATRIX:ytest,yhat +/(sum(-(ytest,yhat)),nrow(ytest)) +::STMT +MATRIX:W,X,H +LITERAL_FLOAT:1.0E-8 +/(X,+(%*%(W,H),1.0E-8)) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0 +*(index,2.0) +::STMT +MATRIX:parsertemp399243,parsertemp399246,W3_rand +LITERAL_FLOAT:0.6546536707079771 +t(%*%(*(0.6546536707079771,W3_rand),t(/(parsertemp399243,parsertemp399246)))) +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(-(X,Centering),2.0)),-(nrow(X),1.0)) +::STMT +MATRIX:X2p,maxsc +LITERAL_FLOAT:0.0 +|(>(t(colSums(X2p)),0.0),>(maxsc,0.0)) +::STMT +LITERAL_FLOAT:1.0,0.7 +-(1.0,0.7) +::STMT +MATRIX:_sbcvar92,parsertemp27718,parsertemp27720,220_E +FLOAT:220_W,float561 +LITERAL_FLOAT:2.0 +sum(/(^(-(_sbcvar92,220_E),2.0),+(*(parsertemp27720,float561),/(parsertemp27718,220_W)))) +::STMT +MATRIX:X_batch,dout1 +FLOAT:191_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,191_beta2),^(%*%(t(X_batch),dout1),2.0)) +::STMT +MATRIX:fP +FLOAT:max_values +/(^($1:ncol(fP),max_values),$1) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 +*(-(g,1.0),2.0) +::STMT +MATRIX:p,q,r,parsertemp1597,lambda +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp1597)),+(q,*(lambda,p)))) +::STMT +MATRIX:parsertemp389212,parsertemp389214 +FLOAT:n +*(-(/(colSums(parsertemp389214),n),*(/(parsertemp389212,n),/(parsertemp389212,n))),n) +::STMT +MATRIX:y_hat,b,R +LITERAL_FLOAT:2.0 +^(-(-(b,%*%(R,y_hat)),y_hat),2.0) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:3.0 +*(sample_block_size,3.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp317435 +FLOAT:float284 +LITERAL_FLOAT:1.0 +-(+(parsertemp317435,/(is_one_y_corr,-(float284,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,Hneg,beta,betamin,Hpos +LITERAL_FLOAT:0.0,3.4011973816621555 +*(<(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),+(beta,+(*(Hneg,betamin),*(Hpos,beta)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,7.0 ++(*(-(i,1.0),7.0),1.0) +::STMT +FLOAT:check_max,check_min +-(check_max,check_min) +::STMT +FLOAT:mantissa +LITERAL_FLOAT:-1.0 +*(mantissa,-1.0) +::STMT +FLOAT:m_orig +LITERAL_FLOAT:1.0 +*(m_orig,1.0) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:1.0E-10 ++(+(abs(X),abs(Y)),1.0E-10) +::STMT +MATRIX:p,lambda,X +%*%(t(p),+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:R,parsertemp40215 +FLOAT:numRows,level +/(numRows,+(R,rowSums(==(parsertemp40215,level)))) +::STMT +MATRIX:p,Z +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),%*%(Z,p)))) +::STMT +FLOAT:odds +LITERAL_FLOAT:1.0 +/(odds,-(1.0,odds)) +::STMT +MATRIX:parsertemp131906,parsertemp132092,outBucket +==(outBucket,%*%(parsertemp132092,t(parsertemp131906))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),*(%*%(t(V),y),-1.0)) +::STMT +MATRIX:p_CG +FLOAT:parsertemp254766,int972,parsertemp254749,int767,z +*(parsertemp254766,/(+(*(z,int972),sqrt(parsertemp254749)),sum(^(p_CG,int767)))) +::STMT +MATRIX:parsertemp122290,X2 +LITERAL_FLOAT:0.0,4.0 +&(>=(t(colSums(X2)),4.0),>(t(%*%(parsertemp122290,X2)),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:8.0 +*(i,8.0) +::STMT +MATRIX:Y,parsertemp221025 +LITERAL_FLOAT:1.0 +*(/(1.0,+(Y,1.0)),+(diag(parsertemp221025),1.0)) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 +-(1.0,<=(sample_rec_ids,num_records)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,7.0 +*(-(i,1.0),7.0) +::STMT +FLOAT:i +LITERAL_FLOAT:7.0 +*(i,7.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +/(linear_terms,-(2.0,var_power)) +::STMT +MATRIX:parsertemp171084,parsertemp171083,parsertemp171091 +FLOAT:float122 +LITERAL_FLOAT:-2.0,1.432788 +*(sqrt(*(-2.0,parsertemp171083)),+(1.432788,*(sqrt(parsertemp171084),+(float122,parsertemp171091)))) +::STMT +MATRIX:neighbors +LITERAL_FLOAT:0.0 +<(0.0,-(neighbors,diag(diag(neighbors)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,8.0 +*(-(i,1.0),8.0) +::STMT +LITERAL_FLOAT:2.302585092994046 +2.302585092994046 +::STMT +MATRIX:y_corr +LITERAL_FLOAT:3.141592653589793,0.5 +*(-(y_corr,0.5),3.141592653589793) +::STMT +MATRIX:m +FLOAT:sum +sqrt(-(m,sum)) +::STMT +MATRIX:z +LITERAL_FLOAT:2.0 +^(cast.FLOAT(z),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:12.0 +*(i,12.0) +::STMT +MATRIX:y_batch +LITERAL_FLOAT:0.0,1.0 +*(/(1.0,nrow(y_batch)),-(0.0,y_batch)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:10.0 +*(num_records,10.0) +::STMT +MATRIX:parsertemp43631,parsertemp43633 +LITERAL_FLOAT:0.0,2.0 +INT:int81,int873,int500,int493 +*(+(rand(int493,int500,0.0,0.0),*(2.0,%*%(parsertemp43631,parsertemp43633))),+(rand(int81,int873,0.0,0.0),*(2.0,%*%(parsertemp43631,parsertemp43633)))) +::STMT +LITERAL_FLOAT:0.1651445647689541 +0.1651445647689541 +::STMT +FLOAT:p_CG,parsertemp170088,z,pp_CG,parsertemp170090 +LITERAL_FLOAT:-1.0 +/(+(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170088,parsertemp170090))),pp_CG) +::STMT +FLOAT:index +LITERAL_FLOAT:4.0 +*(index,4.0) +::STMT +FLOAT:FN,TN,FP,TP +-(*(TP,TN),*(FP,FN)) +::STMT +MATRIX:R,S,parsertemp382932,HS +FLOAT:norm_R2,alpha ++(-(R,*(alpha,HS)),*(/(sum(parsertemp382932),norm_R2),S)) +::STMT +MATRIX:P1,P2,S ++(%*%(P1,S),%*%(P2,S)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +<(linear_terms,0.0) +::STMT +MATRIX:S,V +FLOAT:norm_R2,parsertemp149264 +LITERAL_FLOAT:2.0 +^(+(S,*(/(norm_R2,parsertemp149264),V)),2.0) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:1.0E-7 +*(scale_lambda,1.0E-7) +::STMT +MATRIX:r +FLOAT:norm_r2_initial,int736 +sqrt(/(sum(^(r,int736)),norm_r2_initial)) +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:2.0 +^(-(X,%*%(U,t(V))),2.0) +::STMT +LITERAL_FLOAT:0.0,1.0,2.0 +INT:int48,parsertemp282730 +>(rand(parsertemp282730,int48,1.0,2.0),0.0) +::STMT +FLOAT:int710,n +LITERAL_FLOAT:1.0,2.0,0.6 +*(-(+(-(n,int710),1.0),2.0),0.6) +::STMT +FLOAT:x_to_truncate +abs(x_to_truncate) +::STMT +MATRIX:R,dssp,dsep +FLOAT:4_eAvg +/(/(+(R,dsep),+(R,dssp)),4_eAvg) +::STMT +FLOAT:i +LITERAL_FLOAT:32.0 +*(i,32.0) +::STMT +MATRIX:_sbcvar2306 +max(t(_sbcvar2306)) +::STMT +MATRIX:class_counts +LITERAL_FLOAT:50000.0 +/(class_counts,50000.0) +::STMT +FLOAT:i +LITERAL_FLOAT:33.0 +*(i,33.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,33.0 +*(-(i,1.0),33.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,32.0 +*(-(i,1.0),32.0) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,int862,int622,z +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(sum(^(p_CG,int622)),-(^(z,int862),trust_delta_sq))) +::STMT +FLOAT:k +LITERAL_FLOAT:40.0 +*(k,40.0) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +-(/(-(1.0,var_power),link_power),1.0) +::STMT +MATRIX:simplex +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0 +-(+(num_func_invoc,ncol(simplex)),1.0) +::STMT +MATRIX:a,b,t,parsertemp32856,Y,parsertemp32827,parsertemp32824 +FLOAT:int277,int378 ++(+(*(-(int378,t),Y),*(/(parsertemp32824,parsertemp32827),Y)),*(*(/(parsertemp32824,parsertemp32827),-(int277,t)),+(*(a,parsertemp32856),*(b,t)))) +::STMT +FLOAT:i +LITERAL_FLOAT:42.0 +*(i,42.0) +::STMT +MATRIX:W +LITERAL_FLOAT:2.0 +^(sum(round(W)),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0 +*(i,16.0) +::STMT +FLOAT:df,int687 +LITERAL_FLOAT:4.890349128221754 ++(int687,*(df,4.890349128221754)) +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605,X +FLOAT:lambda +LITERAL_FLOAT:0.0 +%*%(X,*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0))) +::STMT +MATRIX:parsertemp459793,parsertemp459795 +FLOAT:val_loss +LITERAL_FLOAT:50.0 ++(val_loss,/(sum(*(parsertemp459793,parsertemp459795)),50.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0 ++(classFeatureCounts,1.0) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:1.6583123951777 +/(1.6583123951777,max(sqrt(rowSums_X_sq))) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0,1.0 +*(-(i,1.0),16.0) +::STMT +MATRIX:Q,parsertemp500360 +FLOAT:int245 +%*%(parsertemp500360,t(rowSums(^(Q,int245)))) +::STMT +MATRIX:X +LITERAL_FLOAT:7.0 +<(X,7.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 +*(-(i,1.0),11.0) +::STMT +MATRIX:prediction,target +sum(rowSums(abs(-(prediction,target)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,10.0 +*(-(i,1.0),10.0) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:my +LITERAL_FLOAT:2.0 +*(CFreqs,^(-(CMeans,my),2.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 +*(-(i,1.0),12.0) +::STMT +MATRIX:qLow,length,qUp +|(<(length,qLow),>(length,qUp)) +::STMT +MATRIX:G,authorities +/(%*%(G,authorities),max(%*%(G,authorities))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,float356 +LITERAL_FLOAT:2.0 +/(exp(*(linear_terms,-(float356,var_power))),-(2.0,var_power)) +::STMT +FLOAT:log_ten,parsertemp169812 +LITERAL_FLOAT:0.5 +-(/(parsertemp169812,log_ten),0.5) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamin +LITERAL_FLOAT:0.0,3.4011973816621555 +*(<(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),betamin) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0 +*(-(i,1.0),128.0) +::STMT +MATRIX:R,S,parsertemp40214 +FLOAT:level ++(R,rowSums(==(%*%(S,parsertemp40214),level))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +==(-(predicted_Y,Y),0.0) +::STMT +MATRIX:parsertemp31046,parsertemp31051,parsertemp31042,parsertemp31043 +FLOAT:parsertemp31049,parsertemp31054 +LITERAL_FLOAT:2.0 +round(/(^(+(parsertemp31042,parsertemp31043),2.0),+(/(parsertemp31046,parsertemp31049),/(parsertemp31051,parsertemp31054)))) +::STMT +MATRIX:is_one_y_corr,parsertemp317435 +LITERAL_FLOAT:1.0 ++(parsertemp317435,/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,100.0 +*(-(i,1.0),100.0) +::STMT +MATRIX:Q,R,parsertemp500308,parsertemp500300 +FLOAT:int213,int786,int864,int854 +LITERAL_FLOAT:2.0 +INT:int279,parsertemp500306,int987,parsertemp500303 +-(+(%*%(rowSums(parsertemp500300),rand(int279,parsertemp500303,int854,int213)),%*%(rand(parsertemp500306,int987,int864,int786),t(parsertemp500308))),*(2.0,%*%(R,t(Q)))) +::STMT +FLOAT:s,parsertemp454319 +LITERAL_FLOAT:3.0 +*(parsertemp454319,^(3.0,s)) +::STMT +MATRIX:parsertemp553013,M2,parsertemp553121,parsertemp553122 ++(%*%(rowSums(*(M2,M2)),parsertemp553121),t(%*%(rowSums(parsertemp553013),parsertemp553122))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,24.0 +-(+(nrow(Y),0.0),24.0) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0 +>(*(*(neighbors,corePts),withinEps),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0 +*(-(i,1.0),61.0) +::STMT +MATRIX:log_prob,log_det_chol +FLOAT:parsertemp436710,float252 +LITERAL_FLOAT:-0.5 ++(*(-0.5,+(*(parsertemp436710,float252),log_prob)),log_det_chol) +::STMT +MATRIX:linear_terms +FLOAT:int709 +LITERAL_FLOAT:1.0 +/(1.0,-(exp(-(int709,linear_terms)),1.0)) +::STMT +MATRIX:w,parsertemp43626 +FLOAT:int89 +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,%*%(t(w),w)),*(2.0,sum(*(parsertemp43626,int89)))) +::STMT +MATRIX:sq_sums,mu +LITERAL_FLOAT:2.0,4.0 +-(/(sq_sums,4.0),^(cast.FLOAT(mu),2.0)) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171318,parsertemp171306 +FLOAT:float174,int607 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(exp(/(-(int607,parsertemp171318),2.0)),*(/(1.0,+(float174,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,64.0 +*(-(i,1.0),64.0) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0,1.0 +*(>(*(*(neighbors,corePts),withinEps),0.0),&(==(t(corePts),0.0),>(colSums(neighbors),1.0))) +::STMT +MATRIX:parsertemp220853,Ws,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +<(-(+(parsertemp220853,*(beta,Ws)),3.4011973816621555),0.0) +::STMT +MATRIX:r,parsertemp500439,y +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(r),-(parsertemp500439,y)))) +::STMT +MATRIX:parsertemp1510 +FLOAT:n +LITERAL_FLOAT:2.0 +*(n,^(/(t(parsertemp1510),n),2.0)) +::STMT +MATRIX:parsertemp31910,parsertemp31913 +FLOAT:eAvg +LITERAL_FLOAT:1.0 +-(/(/(t(parsertemp31913),t(parsertemp31910)),eAvg),1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,42.0 +*(-(i,1.0),42.0) +::STMT +MATRIX:shift_X,w,ssX_p_CG,X +*(cast.FLOAT(shift_X),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),50.0)) +::STMT +MATRIX:X +FLOAT:parsertemp78,parsertemp80 +LITERAL_FLOAT:3.0 +^(/(-(X,parsertemp78),sqrt(parsertemp80)),3.0) +::STMT +MATRIX:W,H,X,parsertemp410975 +FLOAT:eps +*(H,%*%(t(W),/(X,+(parsertemp410975,eps)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +/(1.0,cast.FLOAT(A)) +::STMT +FLOAT:i +LITERAL_FLOAT:133.0 +*(133.0,i) +::STMT +FLOAT:parsertemp40812,m2,int416 +LITERAL_FLOAT:2000.0 +/(sqrt(*(/(int416,parsertemp40812),m2)),sqrt(2000.0)) +::STMT +MATRIX:parsertemp410978,W,X,H,parsertemp410980 +FLOAT:eps +%*%(/(X,+(%*%(W,H),eps)),t(/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:1.0E-6 +*(*(1.0E-6,U),row_nonzeros) +::STMT +MATRIX:A,B,C,D,X +==(%*%(<=(%*%(X,A),B),C),D) +::STMT +FLOAT:i +LITERAL_FLOAT:3.0 +-(3.0,i) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),105.0)) +::STMT +FLOAT:Hin +LITERAL_FLOAT:2.0 +/(/(Hin,2.0),2.0) +::STMT +MATRIX:parsertemp24102 +LITERAL_FLOAT:1.0 +-(1.0,<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:parsertemp150470,parsertemp149323,LT +%*%(rowSums(exp(-(LT,parsertemp149323))),parsertemp150470) +::STMT +MATRIX:tpr,fpr +LITERAL_FLOAT:2.0 +/(*(-(fpr,fpr),+(tpr,tpr)),2.0) +::STMT +FLOAT:float878,m2,int725 +LITERAL_FLOAT:2001.0 +sqrt(*(/(2001.0,-(int725,float878)),m2)) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +*(sum(^(p_CG,2.0)),-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq)) +::STMT +MATRIX:simplex +-(rowSums(simplex),simplex) +::STMT +FLOAT:m2,wt,float618 +LITERAL_FLOAT:5.0 +*(5.0,sqrt(/(*(m2,wt),-(wt,float618)))) +::STMT +MATRIX:parsertemp383172,X_nonzero_ind +FLOAT:parsertemp383177,reg,parsertemp383180,loss_init +-(loss_init,+(sum(*(X_nonzero_ind,parsertemp383172)),*(reg,+(parsertemp383177,parsertemp383180)))) +::STMT +MATRIX:C,parsertemp11064 +LITERAL_FLOAT:10000.0 +/(sum(==(parsertemp11064,C)),10000.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),500.0)) +::STMT +LITERAL_FLOAT:2.7182818284 +2.7182818284 +::STMT +FLOAT:217_a22,int533,parsertemp22450,parsertemp22451,parsertemp22485 +/(parsertemp22485,sqrt(+(+(parsertemp22450,parsertemp22451),/(int533,217_a22)))) +::STMT +MATRIX:Grad +FLOAT:int907 +LITERAL_FLOAT:2.0 +sqrt(sum(^(*(Grad,int907),2.0))) +::STMT +MATRIX:parsertemp553017,M2,parsertemp553121,parsertemp553020,parsertemp553009 +LITERAL_FLOAT:2.0 +sqrt(-(+(%*%(parsertemp553009,parsertemp553121),t(parsertemp553017)),*(2.0,%*%(M2,parsertemp553020)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604 +FLOAT:int192 +sum(abs(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int192)))) +::STMT +MATRIX:R,dssp,dsep,dssm,dsem +/(-(+(R,dsep),dsem),-(+(R,dssp),dssm)) +::STMT +MATRIX:parsertemp131907,offset,parsertemp131910,parsertemp132092,rightHist,mask,outBucket +LITERAL_FLOAT:1.0 +/(-(-(offset,%*%(mask,parsertemp131910)),1.0),%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),rightHist)) +::STMT +MATRIX:r,Hd +FLOAT:c +LITERAL_FLOAT:-1.0 +*(+(r,*(c,Hd)),-1.0) +::STMT +MATRIX:X +FLOAT:parsertemp496694,a0 +LITERAL_FLOAT:2.0 ++(parsertemp496694,/(^(cast.FLOAT(X),2.0),a0)) +::STMT +MATRIX:parsertemp379560,m_iter_err_sum,m_err +LITERAL_FLOAT:-1.0 +*(-(t(+(parsertemp379560,m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum)),-1.0) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:999.0,1000.0 +/(*(parsertemp13703,1000.0),999.0) +::STMT +MATRIX:W +FLOAT:parsertemp112,int190,parsertemp91 +LITERAL_FLOAT:2.0,3.0,4.0,5.0 +/(*(*(4.0,-(parsertemp112,int190)),^(sqrt(parsertemp91),2.0)),*(+(sum(W),5.0),-(sum(W),3.0))) +::STMT +MATRIX:parsertemp379566 +FLOAT:int699,i_process_item +LITERAL_FLOAT:2.0 +*(^(/(*(parsertemp379566,int699),i_process_item),2.0),i_process_item) +::STMT +MATRIX:Xm,Z,parsertemp265732 +/(sum(-(%*%(Z,parsertemp265732),Xm)),sum(Xm)) +::STMT +MATRIX:parsertemp396406,W3_rand +FLOAT:int564,int269 +LITERAL_FLOAT:0.16823164622761327 +%*%(*(0.16823164622761327,W3_rand),t(/(-(parsertemp396406,int564),+(parsertemp396406,int269)))) +::STMT +MATRIX:D,ZERODIAG,beta +FLOAT:int694 +*(exp(*(-(int694,D),beta)),ZERODIAG) +::STMT +LITERAL_FLOAT:3352500.0 +3352500.0 +::STMT +MATRIX:parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793,0.5 ++(0.5,/(%*%(parsertemp171366,p_one_m_one),3.141592653589793)) +::STMT +FLOAT:K +LITERAL_FLOAT:151.0 +*(151.0,K) +::STMT +MATRIX:r,c,E,F +FLOAT:int785 +LITERAL_FLOAT:1.0E-4 +-(F,+(*(==(E,int785),1.0E-4),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),-(sum(round(W)),3.0)) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:0.0 +*(-(0.0,y),+(o,os)) +::STMT +MATRIX:B1 +FLOAT:nc +LITERAL_FLOAT:1.0 +/(nrow(B1),-(nc,1.0)) +::STMT +MATRIX:cumLens +FLOAT:i +LITERAL_FLOAT:1.0 +/(-(i,1.0),cumLens) +::STMT +MATRIX:W,H,parsertemp411100,parsertemp411104,parsertemp411105 +%*%(W,%*%(*(H,/(parsertemp411100,parsertemp411104)),t(*(H,parsertemp411105)))) +::STMT +MATRIX:p,z +FLOAT:pp,parsertemp169870,pz +LITERAL_FLOAT:-1.0 +-(*(sum(*(p,z)),-1.0),sqrt(-(*(pz,pz),*(pp,parsertemp169870)))) +::STMT +MATRIX:parsertemp185168,parsertemp185169,parsertemp185166,parsertemp185165 +>(-(parsertemp185165,parsertemp185166),-(parsertemp185168,parsertemp185169)) +::STMT +MATRIX:d_r,parsertemp409781 +sum(*(rev(d_r),parsertemp409781)) +::STMT +FLOAT:norm_grad_initial +LITERAL_FLOAT:0.001 +*(0.001,norm_grad_initial) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),2.0) +::STMT +MATRIX:r_CG,g_reg,z +*(cast.FLOAT(z),+(cast.FLOAT(r_CG),cast.FLOAT(g_reg))) +::STMT +MATRIX:selCols,selCols2 +-(sum(selCols),sum(selCols2)) +::STMT +MATRIX:_sbcvar92,220_r,220_c,220_E +FLOAT:int65 +LITERAL_FLOAT:1.0E-4 +-(_sbcvar92,+(*(==(220_E,int65),1.0E-4),/(%*%(220_r,220_c),sum(_sbcvar92)))) +::STMT +MATRIX:parsertemp16875 +FLOAT:epsilon +*(<(sqrt(rowSums(parsertemp16875)),epsilon),epsilon) +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 +^(s,2.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-0.0 +^(linear_terms,-0.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-2.0 +^(linear_terms,-2.0) +::STMT +MATRIX:t_gp,pt_gp,parsertemp171320,Y,the_gauss_exp,parsertemp171316 +FLOAT:one_over_sqrt_two_pi,int5 +LITERAL_FLOAT:2.0,0.25 +/(*(*(exp(parsertemp171320),^(one_over_sqrt_two_pi,int5)),rowSums(Y)),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +t(colSums(^(X,2.0))) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p)))) +::STMT +MATRIX:resp,X,weight +/(%*%(t(resp),*(X,X)),t(weight)) +::STMT +MATRIX:parsertemp472180,I,yhat +LITERAL_FLOAT:2.0 +rowSums(^(*(I,-(yhat,parsertemp472180)),2.0)) +::STMT +MATRIX:p,parsertemp285529,g +FLOAT:pp,pq,int710,pz,parsertemp285543,parsertemp285521 +*(+(+(*(parsertemp285543,pq),sum(parsertemp285529)),sum(*(g,p))),/(-(*(pz,int710),sqrt(parsertemp285521)),pp)) +::STMT +MATRIX:parsertemp220902,parsertemp220903 +FLOAT:tol +LITERAL_FLOAT:2.0 +*(sum(^(-(parsertemp220902,parsertemp220903),2.0)),tol) +::STMT +FLOAT:ssPrev,parsertemp265725,parsertemp265724 +LITERAL_FLOAT:1.0,4000.0 +-(1.0,/(/(-(parsertemp265724,parsertemp265725),4000.0),ssPrev)) +::STMT +LITERAL_FLOAT:0.0,1.0,2.0 +INT:D,M +*(rand(D,M,0.0,1.0),sqrt(/(2.0,D))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +^(linear_terms,-1.0) +::STMT +MATRIX:e_r_rev_agg,select,d_r_rev,X_rev_agg +colSums(/(*(%*%(select,X_rev_agg),d_r_rev),e_r_rev_agg)) +::STMT +MATRIX:Y +FLOAT:num_categories +LITERAL_FLOAT:0.0,-1.0 +*(+(*(Y,-1.0),num_categories),<=(Y,0.0)) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(X,X)) +::STMT +MATRIX:G,authorities,hubs +-(/(%*%(G,authorities),max(%*%(G,authorities))),hubs) +::STMT +MATRIX:W1_rand,stds,parsertemp396314 +LITERAL_FLOAT:0.07808688094430302 +t(%*%(*(0.07808688094430302,W1_rand),t(/(parsertemp396314,stds)))) +::STMT +MATRIX:dist +FLOAT:i +LITERAL_FLOAT:1.0 +-(+(i,cast.FLOAT(dist)),1.0) +::STMT +MATRIX:residual_matrix +FLOAT:273_lambda ++(nrow(residual_matrix),273_lambda) +::STMT +MATRIX:diff_nominal,diff,_sbcvar1151 +FLOAT:num_std_median +LITERAL_FLOAT:0.0 ++(*(!=(diff_nominal,0.0),num_std_median),*(diff,_sbcvar1151)) +::STMT +MATRIX:Xd,parsertemp2775 +FLOAT:int811 +LITERAL_FLOAT:0.0 +*(*(Xd,>(-(int811,parsertemp2775),0.0)),Xd) +::STMT +MATRIX:Y_counts,means,parsertemp560511 +*(Y_counts,rowSums(*(means,parsertemp560511))) +::STMT +MATRIX:col,parsertemp24101,parsertemp24103 +FLOAT:int720,num_bins,float276,int627 +LITERAL_FLOAT:1.0 +*(-(-(1.0,<(col,int720)),>(+(parsertemp24103,int627),num_bins)),+(round(-(parsertemp24101,float276)),1.0)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:1.0 +/(*(*(n_risk,n_event_stratum),-(n_risk_stratum,n_event_stratum)),*(n_risk_stratum,-(n_risk_stratum,1.0))) +::STMT +MATRIX:Y +FLOAT:num_categories,int206 +LITERAL_FLOAT:0.0 ++(Y,*(+(*(Y,int206),num_categories),<=(Y,0.0))) +::STMT +MATRIX:parsertemp409723,R +LITERAL_FLOAT:1.0 +-(+(cast.FLOAT(parsertemp409723),cast.FLOAT(R)),1.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +exp(*(linear_terms,-(2.0,var_power))) +::STMT +MATRIX:parsertemp195898 +FLOAT:int22,parsertemp195894,factor_up +abs(-(/(parsertemp195898,factor_up),/(/(parsertemp195894,int22),factor_up))) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0,3.0,4.0 ++(+(*(index,4.0),2.0),3.0) +::STMT +MATRIX:x,y +LITERAL_FLOAT:2.0 +cast.FLOAT(/(+(x,y),2.0)) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(1.0,cast.FLOAT(w)))) +::STMT +MATRIX:V,W,H,parsertemp10738 +LITERAL_FLOAT:1.0E-8 +/(%*%(t(W),V),+(%*%(%*%(parsertemp10738,W),H),1.0E-8)) +::STMT +LITERAL_FLOAT:1.0 ++(1.0,1.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0 +/(1.0,link_power) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0 +/(-1.0,link_power) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0 +sum(^(-(beta,y),2.0)) +::STMT +MATRIX:dout,parsertemp555660,parsertemp555659 +FLOAT:int582,int684 +LITERAL_FLOAT:1.0 +*(*(/(1.0,+(int582,parsertemp555659)),-(1.0,/(int684,parsertemp555660))),dout) +::STMT +MATRIX:p,e,u,G +FLOAT:alpha +LITERAL_FLOAT:1.0 ++(*(alpha,%*%(G,p)),*(-(1.0,alpha),%*%(%*%(e,u),p))) +::STMT +MATRIX:X +FLOAT:val +<=(X,val) +::STMT +MATRIX:prob,pred,test_Y +FLOAT:threshold ++(*(pred,>(prob,threshold)),*(test_Y,<=(prob,threshold))) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:0.5,1270.0 ++(0.5,/(parsertemp79022,1270.0)) +::STMT +MATRIX:X +FLOAT:397_C +*(nrow(X),/(ncol(X),397_C)) +::STMT +MATRIX:output_values +FLOAT:log_odds +LITERAL_FLOAT:0.3,2.7182818284 +^(2.7182818284,+(log_odds,*(0.3,cast.FLOAT(output_values)))) +::STMT +LITERAL_FLOAT:0.0,1.0 ++(1.0,0.0) +::STMT +MATRIX:X2p +LITERAL_FLOAT:0.0 +>(t(colSums(X2p)),0.0) +::STMT +MATRIX:p,parsertemp169865,z +FLOAT:pp,trust_delta_sq +-(*(sum(*(p,z)),sum(*(p,z))),*(pp,-(sum(parsertemp169865),trust_delta_sq))) +::STMT +MATRIX:s,d,alpha_deno +FLOAT:norm_r2 ++(s,*(cast.FLOAT(/(norm_r2,alpha_deno)),d)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int434,int699,int424,int815 +%*%(rand(int434,int699,0.0,1.0),rand(int424,int815,0.0,1.0)) +::STMT +MATRIX:p,p2 +LITERAL_FLOAT:1.0E8 +sum(>(abs(-(p2,p)),1.0E8)) +::STMT +MATRIX:parsertemp171090,is_one_y_corr,t,parsertemp171099,parsertemp171096 +FLOAT:int352,float868 +LITERAL_FLOAT:1.0 ++(*(+(*(t,int352),/(parsertemp171090,parsertemp171096)),-(1.0,*(float868,parsertemp171099))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:parsertemp387409,Ks,Kss +abs(cast.FLOAT(-(Kss,%*%(parsertemp387409,Ks)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,1.0 +-(+(nrow(Y),0.0),1.0) +::STMT +MATRIX:parsertemp170247,t_gp,parsertemp170252,lt_pos_neg,parsertemp170239 +FLOAT:float539,float739 +LITERAL_FLOAT:1.0,0.5,0.254829592 +*(*(-(0.5,lt_pos_neg),exp(/(parsertemp170252,float739))),*(/(1.0,+(float539,parsertemp170239)),+(0.254829592,*(t_gp,parsertemp170247)))) +::STMT +MATRIX:T_1,parsertemp410245,event,parsertemp410248 +FLOAT:float847,int506 +LITERAL_FLOAT:0.6666666666666666 +/(^(/(-(int506,parsertemp410245),*(float847,parsertemp410248)),0.6666666666666666),/(-(max(T_1),min(T_1)),sum(event))) +::STMT +MATRIX:classes +LITERAL_FLOAT:0.19999999999999996 +*(cast.FLOAT(classes),0.19999999999999996) +::STMT +FLOAT:ytest,int816 +LITERAL_FLOAT:1.0,2.0 +-(^(cast.FLOAT(ytest),2.0),*(1.0,^(/(ytest,int816),2.0))) +::STMT +LITERAL_FLOAT:2.0,7000.0 +^(7000.0,2.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116003 +FLOAT:int428 ++(*(scale_X,%*%(-(int428,parsertemp116003),y)),*(cast.FLOAT(r),shift_X)) +::STMT +MATRIX:d,X,logisticD +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:tmp_Xw,Y,Xd +LITERAL_FLOAT:0.0,1.0 +*(Xd,>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(%*%(t(W),W),H),1.0E-8) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.95 +*(0.95,upd_W1) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp222418,parsertemp222424 +FLOAT:sample_block_size +LITERAL_FLOAT:1.0 ++(*(sample_block_size,parsertemp222424),+(t(colSums(parsertemp222418)),1.0)) +::STMT +MATRIX:parsertemp254737 +FLOAT:parsertemp254766,2124_sq_root_d,float33,parsertemp254751 ++(float33,*(parsertemp254766,/(+(parsertemp254751,2124_sq_root_d),sum(parsertemp254737)))) +::STMT +MATRIX:parsertemp389328,parsertemp389331 +LITERAL_FLOAT:1.0 +t(/(-(exp(parsertemp389328),1.0),+(exp(parsertemp389331),1.0))) +::STMT +MATRIX:M +-(M,max(M)) +::STMT +MATRIX:img_in1,img_in2 +FLOAT:weight +LITERAL_FLOAT:1.0 ++(*(-(1.0,weight),img_in1),*(weight,img_in2)) +::STMT +MATRIX:parsertemp43993,os,d,X,alpha_deno ++(os,*(/(sum(parsertemp43993),cast.FLOAT(alpha_deno)),%*%(X,d))) +::STMT +FLOAT:n,norm +LITERAL_FLOAT:-2.0 +*(*(-2.0,norm),n) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +*($1:ncol(X),+($1,1.0)) +::STMT +MATRIX:parsertemp265709,tmp,parsertemp265718,parsertemp265714 +FLOAT:Xm +LITERAL_FLOAT:2.0 +-(+(Xm,trace(*(tmp,parsertemp265714))),*(2.0,cast.FLOAT(%*%(parsertemp265718,parsertemp265709)))) +::STMT +MATRIX:minD,D,parsertemp222603,parsertemp222600 +colSums(/(<=(+(parsertemp222600,parsertemp222603),minD),rowSums(<=(D,minD)))) +::STMT +FLOAT:i +LITERAL_FLOAT:48.0 ++(48.0,i) +::STMT +MATRIX:log_prob,X +FLOAT:parsertemp436712 ++(*(ncol(X),parsertemp436712),log_prob) +::STMT +MATRIX:176_mask,W2,175_out +FLOAT:p +%*%(/(*(175_out,176_mask),p),W2) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:2.0 +^(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),2.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:2.0,0.5 +*(2.0,>(y_corr,0.5)) +::STMT +FLOAT:m2,float572,wt +LITERAL_FLOAT:4.0 +^(sqrt(/(*(m2,wt),-(wt,float572))),4.0) +::STMT +MATRIX:C,Xm,parsertemp265707,parsertemp265705,parsertemp265713 ++(sum(*(Xm,Xm)),trace(*(+(parsertemp265705,parsertemp265707),%*%(parsertemp265713,C)))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:2000.0 +/(2000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +FLOAT:float15,m2,wt +LITERAL_FLOAT:3.0 +^(sqrt(/(*(m2,wt),-(wt,float15))),3.0) +::STMT +MATRIX:n_risk_stratum +LITERAL_FLOAT:1.0 +*(n_risk_stratum,-(n_risk_stratum,1.0)) +::STMT +MATRIX:parsertemp498242,m_iter_err_sum,m_err +LITERAL_FLOAT:0.0 +-(0.0,-(t(+(parsertemp498242,m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum))) +::STMT +MATRIX:col,more_than_ub,parsertemp24107,parsertemp24102,parsertemp24103 +FLOAT:int33,num_bins +LITERAL_FLOAT:1.0 ++(+(*(-(parsertemp24107,more_than_ub),+(parsertemp24103,int33)),*(>(col,num_bins),num_bins)),<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:R,S,Grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(sum(*(S,Grad)),sum(*(S,R)))) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +*(sample_rec_ids,<=(sample_rec_ids,num_records)) +::STMT +MATRIX:X +LITERAL_FLOAT:8.0 +==(X,8.0) +::STMT +LITERAL_FLOAT:990000.0 +990000.0 +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 ++(*(-(i,1.0),11.0),11.0) +::STMT +MATRIX:lambda,parsertemp171475 +FLOAT:new_log_l +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,sum(*(lambda,parsertemp171475)))) +::STMT +MATRIX:parsertemp31112,parsertemp31114,parsertemp31105,parsertemp31107 +FLOAT:int146,int788,int637,int150 +LITERAL_FLOAT:1500.0,2000.0 ++(/(/(-(parsertemp31105,parsertemp31107),-(int150,int637)),2000.0),/(/(-(parsertemp31112,parsertemp31114),-(int788,int146)),1500.0)) +::STMT +MATRIX:Xi,X_rev_2 +*(X_rev_2,rev(Xi)) +::STMT +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 ++(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag)) +::STMT +MATRIX:minD,parsertemp72030,parsertemp72033,parsertemp72034,parsertemp72031 +FLOAT:int588 +/(<=(+(*(int588,parsertemp72030),t(parsertemp72033)),minD),rowSums(<=(+(parsertemp72031,parsertemp72034),minD))) +::STMT +MATRIX:G +!=(rowSums(G),t(colSums(G))) +::STMT +MATRIX:e,X,tS +FLOAT:l +*(==(%*%(X,tS),l),e) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.0,10000.0 +*(cmLabels,/(10000.0,-(10000.0,1.0))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +/(==(y_corr,0.0),-(1.0,==(y_corr,0.0))) +::STMT +LITERAL_FLOAT:3.37275E9 +3.37275E9 +::STMT +FLOAT:i +LITERAL_FLOAT:96.0 ++(96.0,i) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170089,z,pp_CG +LITERAL_FLOAT:-1.0 ++(*(*(cast.FLOAT(z),sum(p_CG)),-1.0),sqrt(-(*(z,z),*(pp_CG,parsertemp170089)))) +::STMT +MATRIX:V +t(V) +::STMT +MATRIX:ssX_p_CG,shift_X,p_CG ++(ssX_p_CG,cast.FLOAT(%*%(t(shift_X),p_CG))) +::STMT +MATRIX:U,V,X,parsertemp382841,row_nonzeros +FLOAT:int259 +LITERAL_FLOAT:1.0E-6 ++(%*%(*(!=(X,int259),-(parsertemp382841,X)),V),*(*(1.0E-6,U),row_nonzeros)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,11.0 +-(n,-(+(i,11.0),1.0)) +::STMT +LITERAL_FLOAT:1.061405429 +1.061405429 +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,nrow(X)) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +sum(-(>=(+(m_active_flag,m_active_flag_tmp),1.0),1.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(*(-(i,1.0),12.0),1.0) +::STMT +LITERAL_FLOAT:0.10938070012761454 +0.10938070012761454 +::STMT +MATRIX:prevTK2,totalE,X2 +%*%(t(totalE),==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2)))) +::STMT +FLOAT:X +LITERAL_FLOAT:50.0,1.0E-6 +/(*(1.0E-6,X),50.0) +::STMT +MATRIX:os,d,X,alpha_deno +FLOAT:norm_r2 ++(os,*(cast.FLOAT(/(norm_r2,alpha_deno)),%*%(X,d))) +::STMT +MATRIX:M2 +LITERAL_FLOAT:0.0 +!(!=(M2,0.0)) +::STMT +MATRIX:S,parsertemp175056 +exp(-(S,parsertemp175056)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0 +==(colSums(!=(R,0.0)),0.0) +::STMT +MATRIX:parsertemp472147,I,y2 +%*%(/(%*%(I,y2),sum(I)),parsertemp472147) +::STMT +MATRIX:lambda,parsertemp149401,parsertemp149400,B_new +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp149400,parsertemp149401),*(lambda,B_new)),2.0)) +::STMT +MATRIX:lambda +FLOAT:newbeta,new_log_l,int183 +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,*(cast.FLOAT(lambda),^(newbeta,int183)))) +::STMT +MATRIX:2846_Q,X +LITERAL_FLOAT:2.0 ++(rowSums(^(X,2.0)),sum(^(2846_Q,2.0))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:Infinity +==(linear_terms,Infinity) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-Infinity +==(linear_terms,-Infinity) +::STMT +MATRIX:mask +LITERAL_FLOAT:1.0 +==(mask,1.0) +::STMT +FLOAT:X +LITERAL_FLOAT:1.0E-6,100.0 +/(*(1.0E-6,X),100.0) +::STMT +FLOAT:j +LITERAL_FLOAT:4.0 +-(4.0,j) +::STMT +MATRIX:parsertemp195898 +FLOAT:parsertemp195893,int52,factor_up +LITERAL_FLOAT:2.0 +-(/(parsertemp195898,factor_up),/(/(-(parsertemp195893,int52),2.0),factor_up)) +::STMT +MATRIX:T,parsertemp537734 +LITERAL_FLOAT:0.0 +sum(==(%*%(parsertemp537734,T),0.0)) +::STMT +MATRIX:X +FLOAT:m2X,float920,W +sqrt(*(m2X,/(nrow(X),-(W,float920)))) +::STMT +MATRIX:parsertemp385504 +LITERAL_FLOAT:0.0,6.0 +-(6.0,sum(!=(t(parsertemp385504),0.0))) +::STMT +MATRIX:w_X,z_LS,X +*(/(nrow(X),sum(*(w_X,z_LS))),z_LS) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int975 +LITERAL_FLOAT:1.0,2.0,2000.0 +^(/(-(colSums(parsertemp31104),*(int975,parsertemp31106)),-(2000.0,1.0)),2.0) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,12.0 +-(n,-(+(i,12.0),1.0)) +::STMT +MATRIX:parsertemp195899,parsertemp195900 +FLOAT:center +LITERAL_FLOAT:1.0 +%*%(-(1.0,abs(-(parsertemp195899,center))),t(-(1.0,abs(parsertemp195900)))) +::STMT +MATRIX:p,parsertemp1597,beta_unscaled +FLOAT:norm_r2 ++(beta_unscaled,*(/(norm_r2,sum(parsertemp1597)),p)) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +==(parsertemp174552,0.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,linear_terms),-(2.0,var_power)) +::STMT +MATRIX:ss +LITERAL_FLOAT:20.0 +/(20.0,ss) +::STMT +MATRIX:X,Y +FLOAT:eps ++(+(abs(X),abs(Y)),eps) +::STMT +MATRIX:parsertemp146974,mW1,190_dW,parsertemp146977 +FLOAT:parsertemp146983,191_lr,parsertemp146981,int10,191_beta1,parsertemp146971,191_epsilon +/(*(/(*(191_lr,parsertemp146981),-(int10,parsertemp146983)),+(*(191_beta1,mW1),*(parsertemp146971,190_dW))),+(sqrt(+(parsertemp146974,parsertemp146977)),191_epsilon)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 ++(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:parsertemp31030,parsertemp31032 +FLOAT:int387,int994 +LITERAL_FLOAT:1.0,2.0,150.0 +/(^(/(-(parsertemp31030,parsertemp31032),-(int994,int387)),2.0),*(^(150.0,2.0),-(150.0,1.0))) +::STMT +MATRIX:C,Xm,parsertemp265701 +%*%(t(%*%(Xm,%*%(C,parsertemp265701))),%*%(Xm,%*%(C,parsertemp265701))) +::STMT +MATRIX:g_reg,g,parsertemp285556 +sqrt(cast.FLOAT(%*%(t(g_reg),+(g,parsertemp285556)))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +^(linear_terms,-(/(1.0,link_power),1.0)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0,1.0 +^(linear_terms,-(/(-1.0,link_power),1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0E7 ++(exp(linear_terms),==(+(1.0E7,exp(linear_terms)),1.0E7)) +::STMT +MATRIX:parsertemp10744,W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),1.0E-8) +::STMT +FLOAT:int53 +LITERAL_FLOAT:0.0 +INT:int403,m +rand(m,int403,0.0,int53) +::STMT +MATRIX:Xi_X_rev_agg,e_r_rev_agg,select,Xi_agg_rev_agg,X_agg +LITERAL_FLOAT:2.0 +-(/(%*%(select,Xi_X_rev_agg),e_r_rev_agg),/(*(X_agg,Xi_agg_rev_agg),^(e_r_rev_agg,2.0))) +::STMT +MATRIX:err,cCnts +FLOAT:minSup +LITERAL_FLOAT:0.0 +sum(|(<(cCnts,minSup),==(err,0.0))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,4.0 +^(sqrt(*(1.0004995004995005,m2)),4.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0,2.0 +^(linear_terms,-(/(1.0,link_power),2.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,3.0 +^(sqrt(*(1.0004995004995005,m2)),3.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116 +LITERAL_FLOAT:1.0 ++(-(parsertemp171113,*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:ZtZ,Xm,parsertemp265709,parsertemp265706,Z,parsertemp265702,XtZ +FLOAT:ss,ZtZ_sum +*(+(%*%(t(Z),%*%(Xm,parsertemp265702)),*(parsertemp265706,ss)),%*%(t(/(XtZ,ZtZ_sum)),/(%*%(parsertemp265709,Z),sum(ZtZ)))) +::STMT +MATRIX:tmp +FLOAT:N +LITERAL_FLOAT:0.0,1.0 +<=(/(tmp,-(N,1.0)),0.0) +::STMT +MATRIX:CFreqs1 +LITERAL_FLOAT:0.0,1.0 +diag(-(1.0,==(CFreqs1,0.0))) +::STMT +MATRIX:y_hat,X +sum(*(-(X,y_hat),-(X,y_hat))) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int346 +LITERAL_FLOAT:99.0,100.0 +/(/(-(colSums(parsertemp31022),*(int346,parsertemp31024)),99.0),100.0) +::STMT +FLOAT:D +LITERAL_FLOAT:2.0 +sqrt(/(2.0,D)) +::STMT +MATRIX:lengths +abs(-(cast.FLOAT(lengths),cast.FLOAT(lengths))) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +^(-(X,Y),2.0) +::STMT +MATRIX:resp,Y,parsertemp506189 +==(+(resp,t(parsertemp506189)),Y) +::STMT +MATRIX:e_r_rev_agg,parsertemp409787,parsertemp409796 +LITERAL_FLOAT:-1.0 ++(*(t(colSums(parsertemp409787)),-1.0),t(colSums(/(parsertemp409796,e_r_rev_agg)))) +::STMT +MATRIX:X,Centering,ScaleFactor +colSums(/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:parsertemp402079,W3_rand,parsertemp402082 +LITERAL_FLOAT:0.1092173494617922 +t(%*%(*(0.1092173494617922,W3_rand),t(/(parsertemp402079,parsertemp402082)))) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:4460.0 +/(parsertemp76118,4460.0) +::STMT +MATRIX:W,Y,sumW +LITERAL_FLOAT:300.0,0.0 +-(0.0,*(300.0,-(*(Y,sumW),%*%(W,Y)))) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +sum(*(*(grad,-1.0),*(grad,-1.0))) +::STMT +MATRIX:Kss,parsertemp387410 +sqrt(abs(cast.FLOAT(-(Kss,parsertemp387410)))) +::STMT +MATRIX:img +FLOAT:Hf,Wf +*(*(nrow(img),Hf),Wf) +::STMT +MATRIX:z +sqrt(sum(*(z,z))) +::STMT +MATRIX:p,V +FLOAT:eps ++(%*%(t(V),%*%(V,p)),*(eps,p)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int892,int522 +LITERAL_FLOAT:1999.0,2000.0 +/(-(colSums(^(posSamples,int892)),*(2000.0,^(posSampleMeans,int522))),1999.0) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),min(round(parsertemp2832))) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:2358.0 +/(parsertemp77570,2358.0) +::STMT +FLOAT:factor_up,parsertemp195891,parsertemp195892 +LITERAL_FLOAT:1.0,2.0 +/(/(-(-(parsertemp195891,parsertemp195892),1.0),2.0),factor_up) +::STMT +MATRIX:439_Ranks,parsertemp42225 +FLOAT:parsertemp42214,parsertemp42216,parsertemp42218,meanY,parsertemp42220 +/(sum(*(t(parsertemp42225),-(439_Ranks,meanY))),*(sqrt(*(parsertemp42214,parsertemp42216)),sqrt(*(parsertemp42218,parsertemp42220)))) +::STMT +FLOAT:ssPrev,parsertemp265725,parsertemp265724,m,n +LITERAL_FLOAT:1.0 +-(1.0,/(/(-(parsertemp265724,parsertemp265725),*(n,m)),ssPrev)) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:2.0 +^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +-(0.0,exp(-(0.0,linear_terms))) +::STMT +MATRIX:parsertemp170240,parsertemp170238 +FLOAT:float911,float541 +LITERAL_FLOAT:1.0,1.061405429,-1.453152027 +*(/(1.0,+(1.0,*(parsertemp170238,float541))),+(-1.453152027,*(/(float911,parsertemp170240),1.061405429))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(*(-(i,1.0),12.0),12.0) +::STMT +MATRIX:parsertemp389215,parsertemp389217 +LITERAL_FLOAT:1057.0,1058.0 +sqrt(/(*(-(parsertemp389215,parsertemp389217),1058.0),1057.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:2.0 +/(exp(finite_linear_terms),2.0) +::STMT +MATRIX:A,CFreqs +-(nrow(A),nrow(CFreqs)) +::STMT +MATRIX:parsertemp129186,parsertemp129185,key_unique,key +t(==(%*%(key_unique,parsertemp129185),%*%(parsertemp129186,t(key)))) +::STMT +MATRIX:F +-(F,/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:S,V,W +*(W,%*%(S,t(V))) +::STMT +MATRIX:parsertemp220853,Ws,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +>=(-(+(parsertemp220853,*(beta,Ws)),logU),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0,8.0 ++(*(-(i,1.0),12.0),8.0) +::STMT +MATRIX:grad +FLOAT:psi +*(psi,sqrt(sum(*(grad,grad)))) +::STMT +MATRIX:r,parsertemp44063,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(cast.FLOAT(%*%(parsertemp44063,grad)),cast.FLOAT(%*%(parsertemp44063,r)))) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0,5.0 +*(+(sum(round(W)),5.0),-(sum(round(W)),3.0)) +::STMT +MATRIX:p,q,lambda +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),+(q,*(lambda,p))) +::STMT +MATRIX:Q1,IQR +LITERAL_FLOAT:2.0 +-(Q1,*(2.0,IQR)) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +*(10000.0,rho) +::STMT +MATRIX:r,parsertemp44063,parsertemp44065,grad +LITERAL_FLOAT:-0.5 +cast.FLOAT(*(-0.5,-(%*%(parsertemp44063,grad),%*%(parsertemp44065,r)))) +::STMT +FLOAT:cols,parsertemp451837 +LITERAL_FLOAT:1.0 ++(+(*(parsertemp451837,cols),1.0),cols) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0 +exp(*(exp(finite_linear_terms),-1.0)) +::STMT +MATRIX:parsertemp115947,TK +*(rowSums(TK),parsertemp115947) +::STMT +MATRIX:scale_lambda,parsertemp150455 +FLOAT:reg +*(%*%(scale_lambda,parsertemp150455),reg) +::STMT +MATRIX:inactive_set,w +LITERAL_FLOAT:0.0 +-(inactive_set,!=(w,0.0)) +::STMT +FLOAT:m2,mu,float907,wt +/(sqrt(/(*(m2,wt),-(wt,float907))),mu) +::STMT +MATRIX:valueCount,parsertemp552530,Y +FLOAT:int866,int933 +INT:parsertemp552529,idx +*(==(+(rand(parsertemp552529,idx,int933,int866),t(parsertemp552530)),Y),valueCount) +::STMT +MATRIX:prediction,target +/(-(prediction,target),nrow(target)) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,100.0 +*(100.0,^(posSampleMeans,2.0)) +::STMT +MATRIX:252_Y +FLOAT:252_X,float555 +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(float555,252_X),-(252_X,252_X))),cast.FLOAT(252_Y)) +::STMT +FLOAT:vicinity,a0 +LITERAL_FLOAT:1.0 +*(-(1.0,vicinity),a0) +::STMT +MATRIX:Y +-(nrow(Y),sum(Y)) +::STMT +MATRIX:mu +FLOAT:window_size +*(window_size,cast.FLOAT(*(mu,mu))) +::STMT +MATRIX:parsertemp459193,2701_dX,vb3 +FLOAT:lr,mu +-(*(mu,vb3),*(lr,colSums(*(parsertemp459193,2701_dX)))) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:0.0 +-(0.0,+(g,*(cast.FLOAT(lambda),cast.FLOAT(beta)))) +::STMT +MATRIX:parsertemp555752 +FLOAT:int398 +LITERAL_FLOAT:0.5 +sum(*(0.5,rowSums(^(parsertemp555752,int398)))) +::STMT +MATRIX:Xm,parsertemp265707,parsertemp265705,parsertemp265702 +t(/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(+(parsertemp265705,parsertemp265707)))) +::STMT +MATRIX:parsertemp191275 +FLOAT:397_C +*(397_C,t(parsertemp191275)) +::STMT +MATRIX:ts +FLOAT:q ++(-(q,*(cast.FLOAT(ts),cast.FLOAT(ts))),*(cast.FLOAT(ts),cast.FLOAT(ts))) +::STMT +FLOAT:Z_logl +LITERAL_FLOAT:-1.0 +*(abs(Z_logl),-1.0) +::STMT +MATRIX:classFeatureCounts +FLOAT:numFeatures,laplaceCorrection ++(rowSums(classFeatureCounts),*(numFeatures,laplaceCorrection)) +::STMT +MATRIX:X +FLOAT:2917_split +round(*(nrow(X),2917_split)) +::STMT +FLOAT:parsertemp557354,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:0.6931471805599453 ++(/(*(prob_true,parsertemp557354),0.6931471805599453),/(*(prob_false,parsertemp557358),0.6931471805599453)) +::STMT +MATRIX:mn,mx +LITERAL_FLOAT:1.0 ++(-(mx,mn),1.0) +::STMT +MATRIX:parsertemp409803 +FLOAT:D +LITERAL_FLOAT:0.5 +/(*(0.5,sqrt(D)),max(sqrt(rowSums(parsertemp409803)))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +/(sum(*(+(r,parsertemp1945),+(r,parsertemp1945))),norm_r2) +::STMT +FLOAT:x1,x2 +LITERAL_FLOAT:-1.0,2.0 +*(-1.0,^(-(x1,x2),2.0)) +::STMT +MATRIX:R,parsertemp40226 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),R),eAvg) +::STMT +MATRIX:V +max(V) +::STMT +MATRIX:Y_prob,Y,linear_terms +FLOAT:int926 +LITERAL_FLOAT:3.141592653589793,1.0 +*(*(*(rowSums(Y),Y_prob),Y_prob),*(+(1.0,^(linear_terms,int926)),3.141592653589793)) +::STMT +MATRIX:obj,objnew,gs +-(-(cast.FLOAT(objnew),cast.FLOAT(obj)),cast.FLOAT(gs)) +::STMT +MATRIX:prob,pred,test_Y +FLOAT:threshold +LITERAL_FLOAT:0.0 ++(*(pred,>(prob,threshold)),*(test_Y,==(>(prob,threshold),0.0))) +::STMT +FLOAT:K +LITERAL_FLOAT:300.0 +*(300.0,K) +::STMT +FLOAT:acc +LITERAL_FLOAT:1.0,100.0 +cast.MATRIX(-(1.0,/(acc,100.0))) +::STMT +MATRIX:u,minDist +!=(u,minDist) +::STMT +MATRIX:N_T,tmp,X +<=(rowSums(*(X,tmp)),%*%(tmp,t(N_T))) +::STMT +MATRIX:parsertemp32006,simplex +LITERAL_FLOAT:2.0,4.0 +-(*(2.0,/(-(parsertemp32006,simplex),4.0)),simplex) +::STMT +MATRIX:s,parsertemp44005,d +FLOAT:parsertemp44004 +cast.FLOAT(%*%(t(+(s,parsertemp44005)),+(s,*(parsertemp44004,d)))) +::STMT +MATRIX:parsertemp171348,is_too_small,parsertemp171346,parsertemp171344,parsertemp171353,linear_terms,Y,the_exp,parsertemp171349 +FLOAT:int369,int803 +/(*(*(exp(parsertemp171344),exp(linear_terms)),rowSums(Y)),+(/(*(parsertemp171348,parsertemp171349),+(the_exp,is_too_small)),*(==(parsertemp171346,int803),-(int369,parsertemp171353)))) +::STMT +MATRIX:betamax,parsertemp220870,Hpos,beta +FLOAT:INF,int237 +LITERAL_FLOAT:2.0 ++(*(*(*(int237,Hpos),==(betamax,INF)),beta),/(*(*(Hpos,parsertemp220870),+(beta,betamax)),2.0)) +::STMT +MATRIX:_sbcvar1782 +FLOAT:_sbcvar1783 +LITERAL_FLOAT:8.0 +/(_sbcvar1782,-(8.0,_sbcvar1783)) +::STMT +MATRIX:y_hat +FLOAT:parsertemp176421,k +-(sqrt(parsertemp176421),*(k,y_hat)) +::STMT +MATRIX:F +LITERAL_FLOAT:0.0 +==(/(%*%(rowSums(F),colSums(F)),sum(F)),0.0) +::STMT +MATRIX:w,X,y +FLOAT:int485,int701 +INT:int178,m +%*%(t(-(%*%(X,w),y)),-(%*%(X,rand(m,int178,int485,int701)),y)) +::STMT +MATRIX:X +LITERAL_FLOAT:480.0 +/(colSums(X),480.0) +::STMT +MATRIX:Yhat_prime,E +t(colSums(*(E,Yhat_prime))) +::STMT +MATRIX:t_gp,parsertemp171320,Y,linear_terms,parsertemp171316 +LITERAL_FLOAT:0.0,0.5 +*(*(*(exp(parsertemp171320),*(t_gp,parsertemp171316)),rowSums(Y)),-(>=(linear_terms,0.0),0.5)) +::STMT +FLOAT:prob_true,prob_false +LITERAL_FLOAT:1.0,2.0 +-(1.0,+(^(prob_true,2.0),^(prob_false,2.0))) +::STMT +LITERAL_FLOAT:1.0,100000.0 +-(100000.0,1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,8.0 ++(*(-(i,1.0),8.0),1.0) +::STMT +MATRIX:2700_dX,parsertemp459190,2702_X +FLOAT:int389,lr +*(lr,colSums(*(>(2702_X,int389),*(parsertemp459190,2700_dX)))) +::STMT +MATRIX:R,B,parsertemp503364 +LITERAL_FLOAT:-1.0 +*(%*%(t(+(R,parsertemp503364)),B),-1.0) +::STMT +MATRIX:parsertemp230374 +t(t(parsertemp230374)) +::STMT +MATRIX:parsertemp409216,parsertemp409212,ctab +LITERAL_FLOAT:0.45 +*(parsertemp409216,>(/(parsertemp409212,rowSums(ctab)),0.45)) +::STMT +MATRIX:out2,184_probs,183_dpred,parsertemp146939,W3 +LITERAL_FLOAT:0.0 +*(>(out2,0.0),%*%(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)),t(W3))) +::STMT +FLOAT:n_components,n_features +LITERAL_FLOAT:1.0 +*(*(n_components,n_features),+(n_features,1.0)) +::STMT +MATRIX:parsertemp472298,I +LITERAL_FLOAT:0.0 +*(==(*(t(parsertemp472298),I),0.0),I) +::STMT +MATRIX:p,q,lambda +cast.FLOAT(%*%(t(p),+(q,*(lambda,p)))) +::STMT +LITERAL_FLOAT:0.999 +0.999 +::STMT +MATRIX:X +FLOAT:x +/(cast.FLOAT(-(x,X)),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +-(1.0,^(linear_terms,2.0)) +::STMT +MATRIX:output_values +LITERAL_FLOAT:0.3 +*(0.3,cast.FLOAT(output_values)) +::STMT +MATRIX:X,_sbcvar2948 +cast.FLOAT(%*%(t(-(X,_sbcvar2948)),-(X,_sbcvar2948))) +::STMT +MATRIX:P,parsertemp220889,Z,ZERODIAG,parsertemp220891 +FLOAT:int302,int765 +LITERAL_FLOAT:4.0 +-(*(P,4.0),/(*(/(int302,parsertemp220891),+(parsertemp220889,int765)),sum(*(Z,ZERODIAG)))) +::STMT +FLOAT:nc +LITERAL_FLOAT:1.0,20.0 +*(+(20.0,1.0),-(nc,1.0)) +::STMT +MATRIX:_sbcvar78,parsertemp22266 +FLOAT:int315 +LITERAL_FLOAT:2.0,10000.0 +/(^(-(_sbcvar78,/(parsertemp22266,int315)),2.0),/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int318,int839 +LITERAL_FLOAT:149.0,150.0 +/(-(colSums(^(negSamples,int839)),*(150.0,^(negSampleMeans,int318))),149.0) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(+(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:P,Q,parsertemp220896,Y,Z,ZERODIAG +-(*(Y,rowSums(*(parsertemp220896,Z))),%*%(*(-(P,Q),*(Z,ZERODIAG)),Y)) +::STMT +FLOAT:int496,parsertemp98,var,m4,parsertemp99,int864,parsertemp93,parsertemp94,wt,parsertemp105,parsertemp104 +LITERAL_FLOAT:4.0 +/(-(*(*(parsertemp93,parsertemp94),m4),*(*(parsertemp98,parsertemp99),-(wt,int496))),*(*(*(parsertemp104,parsertemp105),-(wt,int864)),^(sqrt(var),4.0))) +::STMT +MATRIX:parsertemp24101 +FLOAT:float99 +LITERAL_FLOAT:1.0 +<(+(round(-(parsertemp24101,float99)),1.0),1.0) +::STMT +MATRIX:parsertemp145796,y +LITERAL_FLOAT:-1.0 +rowSums(*(*(y,-1.0),parsertemp145796)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,float123 +LITERAL_FLOAT:1.0 +/(^(linear_terms,/(-(float123,var_power),link_power)),-(1.0,var_power)) +::STMT +FLOAT:_sbcvar1751 +LITERAL_FLOAT:6.0 +-(6.0,_sbcvar1751) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,W3 +LITERAL_FLOAT:0.0 +colSums(*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3)))) +::STMT +MATRIX:parsertemp555766,parsertemp555764,parsertemp555762,parsertemp555761,target +/(sum(-(*(parsertemp555761,parsertemp555762),*(parsertemp555764,parsertemp555766))),nrow(target)) +::STMT +MATRIX:parsertemp437192,parsertemp437191,parsertemp437190,mean,parsertemp437236,X,weight,parsertemp437188 +FLOAT:float202,int107 +LITERAL_FLOAT:2.0 ++(-(/(%*%(parsertemp437190,parsertemp437236),t(weight)),*(2.0,^(mean,int107))),/(*(/(parsertemp437191,parsertemp437192),%*%(parsertemp437190,X)),t(+(parsertemp437188,float202)))) +::STMT +MATRIX:parsertemp220896,W,Y,Z +LITERAL_FLOAT:300.0 +*(300.0,-(*(Y,rowSums(W)),%*%(*(parsertemp220896,Z),Y))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(exp(*(exp(linear_terms),-1.0)),exp(linear_terms)) +::STMT +MATRIX:s +FLOAT:n +LITERAL_FLOAT:1.0 +*(/(1.0,s),n) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:3.141592653589793,1.0,2.0 +*(+(1.0,^(linear_terms,2.0)),3.141592653589793) +::STMT +MATRIX:CVars,CFreqs +LITERAL_FLOAT:1.0 +*(-(CFreqs,1.0),CVars) +::STMT +MATRIX:s,parsertemp44016,d +*(%*%(t(-(s,parsertemp44016)),d),%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:P +sum(+(P,t(P))) +::STMT +MATRIX:A +FLOAT:a11,a12,int33,int524 +LITERAL_FLOAT:1.0 ++(+(+(/(int524,a11),/(int33,a12)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:lambda,g,beta ++(g,*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:Y +FLOAT:num_features,num_records +LITERAL_FLOAT:1.0 +*(-(num_records,num_features),-(ncol(Y),1.0)) +::STMT +MATRIX:L,m +FLOAT:sum +/(-(m,sum),L) +::STMT +FLOAT:e,initial_lr,decay +LITERAL_FLOAT:1.0 +*(initial_lr,/(1.0,+(1.0,*(decay,e)))) +::STMT +FLOAT:new_log_l,log_l +LITERAL_FLOAT:-1.0 ++(*(new_log_l,-1.0),log_l) +::STMT +MATRIX:r_CG,p_CG +FLOAT:rr_CG,old_rr_CG +LITERAL_FLOAT:0.0 ++(-(0.0,r_CG),*(/(rr_CG,old_rr_CG),p_CG)) +::STMT +LITERAL_FLOAT:1.0,2.0,100.0 +*(^(100.0,2.0),-(100.0,1.0)) +::STMT +MATRIX:parsertemp220911,dY,Y +-(+(Y,dY),parsertemp220911) +::STMT +MATRIX:X_train +LITERAL_FLOAT:2.0 +/(2.0,ncol(X_train)) +::STMT +MATRIX:parsertemp389218 +FLOAT:int620 +LITERAL_FLOAT:1.0E-17,1057.0 ++(sqrt(/(*(parsertemp389218,int620),1057.0)),1.0E-17) +::STMT +MATRIX:S,U,W +%*%(t(U),*(W,%*%(U,t(S)))) +::STMT +FLOAT:int602,avg_tot,sum_sq_y_test,n +LITERAL_FLOAT:1.0 +/(-(sum_sq_y_test,*(n,^(avg_tot,int602))),-(n,1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,4.0 +*(4.0,-(^(sum(W),2.0),1.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +^(linear_terms,/(-(1.0,var_power),link_power)) +::STMT +MATRIX:H +-(+(H,t(H)),diag(diag(H))) +::STMT +MATRIX:col +FLOAT:min_val +-(col,min_val) +::STMT +MATRIX:parsertemp146930,184_unnorm_probs,parsertemp146928,184_scores +FLOAT:int210,parsertemp146927 +rowSums(*(*(*(parsertemp146927,parsertemp146928),/(int210,parsertemp146930)),/(exp(184_scores),rowSums(184_unnorm_probs)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:P,parsertemp220889,Y,Z,parsertemp220891 +FLOAT:int593,int40,parsertemp220894,int923 +%*%(*(-(*(P,int923),/(Z,parsertemp220894)),*(/(int593,parsertemp220891),+(parsertemp220889,int40))),Y) +::STMT +MATRIX:W +LITERAL_FLOAT:5.0 ++(sum(round(W)),5.0) +::STMT +MATRIX:parsertemp437192,parsertemp437191,resp,X,parsertemp437188 +FLOAT:float205,int295 +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),^(X,int295)),t(+(parsertemp437188,float205))),*(2.0,^(/(parsertemp437191,parsertemp437192),2.0))) +::STMT +MATRIX:parsertemp386844,parsertemp386845 +LITERAL_FLOAT:0.0,2.0 +&(>(rowSums(|(parsertemp386844,parsertemp386845)),0.0),<(rowSums(|(parsertemp386844,parsertemp386845)),2.0)) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +rowSums(/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:lambda,scale_X,gXY,beta +FLOAT:int58 +%*%(t(+(*(scale_X,gXY),*(lambda,beta))),+(*(scale_X,-(int58,gXY)),*(lambda,beta))) +::STMT +MATRIX:scale_X,X +%*%(diag(scale_X),%*%(t(X),X)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores +FLOAT:int741 +LITERAL_FLOAT:2.0 +^(colSums(*(>(out2,int741),%*%(184_dscores,parsertemp146942))),2.0) +::STMT +MATRIX:p,q,lambda +%*%(t(p),+(q,*(lambda,p))) +::STMT +MATRIX:parsertemp220988,parsertemp220989,dY +LITERAL_FLOAT:300.0,2.0,0.9 +^(-(*(0.9,dY),*(300.0,-(parsertemp220988,parsertemp220989))),2.0) +::STMT +MATRIX:_sbcvar1674 +FLOAT:int964 +LITERAL_FLOAT:0.0,2.0 +INT:int411,parsertemp282730 +*(>(rand(parsertemp282730,int411,int964,2.0),0.0),_sbcvar1674) +::STMT +MATRIX:parsertemp555753,target +LITERAL_FLOAT:0.5 +/(sum(*(0.5,rowSums(parsertemp555753))),nrow(target)) +::STMT +MATRIX:W +sum(round(W)) +::STMT +MATRIX:one_featureX +FLOAT:287_x,287_y +LITERAL_FLOAT:2.0 +!(<(one_featureX,/(+(287_x,287_y),2.0))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:n_components,parsertemp506195 +rowSums(rand(parsertemp506195,n_components,0.0,1.0)) +::STMT +MATRIX:R +FLOAT:s,i8 +-(ncol(R),*(s,i8)) +::STMT +MATRIX:p,r +FLOAT:norm_r2 +*(/(sum(*(r,r)),norm_r2),p) +::STMT +LITERAL_FLOAT:1.0,750.0 +*(750.0,1.0) +::STMT +MATRIX:ss,X2 +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(/(nrow(X2),ss),1.0)) +::STMT +FLOAT:float634,parsertemp254709,parsertemp254694,2123_sq_root_d,pp_CG ++(float634,*(parsertemp254709,/(+(parsertemp254694,2123_sq_root_d),pp_CG))) +::STMT +FLOAT:b,int894,rad +/(int894,+(b,rad)) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(0.5,cast.FLOAT(w)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +colSums(!=(X,0.0)) +::STMT +LITERAL_FLOAT:0.1092173494617922 +0.1092173494617922 +::STMT +MATRIX:r_CG,g_reg,z +%*%(t(z),+(r_CG,g_reg)) +::STMT +MATRIX:parsertemp498247,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:0.0,2.0 +*(2.0,/(-(0.0,-(parsertemp498247,m_iter_err_sum)),i_process_item)) +::STMT +MATRIX:D,ZERODIAG +LITERAL_FLOAT:1.0 +*(/(1.0,+(D,1.0)),ZERODIAG) +::STMT +MATRIX:intercept,X,beta +LITERAL_FLOAT:1.0 +INT:num_records,int303 ++(%*%(X,beta),%*%(rand(num_records,int303,1.0,1.0),intercept)) +::STMT +MATRIX:t,parsertemp171088,parsertemp171083,parsertemp171094 +FLOAT:float536 +LITERAL_FLOAT:-1.0,1.0,2.515517 ++(*(sqrt(*(float536,parsertemp171083)),-1.0),/(+(2.515517,*(t,parsertemp171088)),+(1.0,*(t,parsertemp171094)))) +::STMT +MATRIX:y_batch,parsertemp146892 +LITERAL_FLOAT:0.0 +sum(*(-(0.0,y_batch),parsertemp146892)) +::STMT +MATRIX:output_values,current_prediction +LITERAL_FLOAT:0.3 ++(current_prediction,*(0.3,cast.FLOAT(output_values))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:1000.0 +/(1000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +MATRIX:B,S,X +%*%(X,+(B,S)) +::STMT +MATRIX:s,d,alpha +-(s,*(cast.FLOAT(alpha),d)) +::STMT +MATRIX:M +FLOAT:parsertemp178174 +cast.MATRIX(+(max(M),parsertemp178174)) +::STMT +MATRIX:parsertemp394988,W3_rand +FLOAT:int204,int625 +LITERAL_FLOAT:0.21483446221182986 +%*%(*(0.21483446221182986,W3_rand),t(/(-(parsertemp394988,int625),+(parsertemp394988,int204)))) +::STMT +MATRIX:F,parsertemp42207,parsertemp42208,438_Ranks +FLOAT:parsertemp42222,int325,meanY,meanX,int938 +*(t(*(/(F,parsertemp42222),-(438_Ranks,meanX))),-(+(-(parsertemp42207,parsertemp42208),/(int325,int938)),meanY)) +::STMT +FLOAT:2344_s_err_mean +LITERAL_FLOAT:-1.0,0.001 +-(*(0.001,-1.0),2344_s_err_mean) +::STMT +MATRIX:history +FLOAT:float452 +-(max(history),float452) +::STMT +MATRIX:colSD,colMean +LITERAL_FLOAT:3.0 +-(colMean,*(3.0,colSD)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr +LITERAL_FLOAT:-0.36651292058166435 +*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr)) +::STMT +MATRIX:H,betamin,beta +FLOAT:logU +LITERAL_FLOAT:0.0 ++(*(<(-(H,logU),0.0),betamin),*(>=(-(H,logU),0.0),beta)) +::STMT +MATRIX:parsertemp171087,parsertemp171084,parsertemp171093 +FLOAT:float298,float780 +LITERAL_FLOAT:1.0,2.515517 +/(+(2.515517,*(sqrt(parsertemp171084),+(float780,parsertemp171087))),+(1.0,*(sqrt(parsertemp171084),+(float298,parsertemp171093)))) +::STMT +MATRIX:parsertemp235660,parsertemp235671 +FLOAT:parsertemp235661 +LITERAL_FLOAT:0.0 +sum(*(-(0.0,/(parsertemp235660,parsertemp235661)),parsertemp235671)) +::STMT +MATRIX:qLow,length,qUp +rowSums(|(<(length,qLow),>(length,qUp))) +::STMT +MATRIX:_sbcvar1750 +FLOAT:_sbcvar1751 +LITERAL_FLOAT:6.0 +/(_sbcvar1750,-(6.0,_sbcvar1751)) +::STMT +MATRIX:intercept +LITERAL_FLOAT:1.0 +INT:num_records,int615 +%*%(rand(num_records,int615,1.0,1.0),intercept) +::STMT +FLOAT:cmLabels,int624,float396 +LITERAL_FLOAT:10000.0 +sqrt(*(cmLabels,/(10000.0,-(int624,float396)))) +::STMT +MATRIX:parsertemp98,X,Y +LITERAL_FLOAT:2.0 +/(abs(-(X,Y)),/(parsertemp98,2.0)) +::STMT +MATRIX:V +FLOAT:std_dev,int434,mu +*(>(V,+(mu,*(int434,std_dev))),V) +::STMT +MATRIX:V +FLOAT:std_dev,int654,mu +*(<(V,-(mu,*(int654,std_dev))),V) +::STMT +MATRIX:X,y +LITERAL_FLOAT:-1.0 +%*%(*(t(X),-1.0),y) +::STMT +MATRIX:X_batch,186_dX,parsertemp146949,parsertemp146957,parsertemp146955 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(X_batch),*(*(parsertemp146957,parsertemp146955),%*%(186_dX,parsertemp146949)))) +::STMT +MATRIX:neighbors,corePts,withinEps +*(*(neighbors,corePts),withinEps) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS ++(cast.FLOAT(r_LS),*(/(norm_r2_LS,*(p_LS,p_LS)),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:parsertemp496735,var,arch_coef,var_coef,a0 +sqrt(+(+(a0,*(arch_coef,parsertemp496735)),*(var_coef,var))) +::STMT +LITERAL_FLOAT:0.0,2.0 +*(2.0,0.0) +::STMT +MATRIX:parsertemp171084,parsertemp171083 +LITERAL_FLOAT:-2.0,0.001308,0.189269 +*(sqrt(*(-2.0,parsertemp171083)),+(0.189269,*(sqrt(parsertemp171084),0.001308))) +::STMT +FLOAT:277_sq_root_d,parsertemp170093,pp_CG,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(-(parsertemp170093,277_sq_root_d),pp_CG)),pq_CG) +::STMT +FLOAT:int199,s,num_groups,int805 +LITERAL_FLOAT:1.0 ++(+(*(-(s,int199),-(num_groups,int805)),1.0),num_groups) +::STMT +MATRIX:grad +FLOAT:int842 +LITERAL_FLOAT:0.1 +*(0.1,sqrt(sum(^(grad,int842)))) +::STMT +MATRIX:X,y +FLOAT:int276,int931,int845,int559 +INT:int786,m,int690 +*(-(%*%(X,rand(m,int690,int845,int559)),y),-(%*%(X,rand(m,int786,int931,int276)),y)) +::STMT +MATRIX:W +FLOAT:int221,int797,wt +LITERAL_FLOAT:1.0,3.0,6.0 +/(*(*(6.0,sum(W)),-(sum(W),1.0)),*(*(-(wt,int221),+(wt,int797)),+(sum(W),3.0))) +::STMT +MATRIX:scale_X,z,beta +%*%(diag(scale_X),+(beta,z)) +::STMT +MATRIX:X +-(X,round(X)) +::STMT +MATRIX:u,minDist +sum(!=(u,minDist)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int374,int767 +sum(rand(int374,int767,0.0,1.0)) +::STMT +MATRIX:parsertemp539203 +FLOAT:int106 +LITERAL_FLOAT:1.0,2.0,1.5 +min(^(/(*(parsertemp539203,int106),2.0),/(1.0,1.5))) +::STMT +FLOAT:ID +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,ID),1.0) +::STMT +MATRIX:parsertemp472412,fP +FLOAT:max_values +<=(parsertemp472412,/(^($1:ncol(fP),max_values),$1)) +::STMT +MATRIX:prediction,target +rowSums(abs(-(prediction,target))) +::STMT +MATRIX:parsertemp382905,S,V,W,row_nonzeros +FLOAT:reg +*(S,+(%*%(*(W,parsertemp382905),V),*(*(reg,S),row_nonzeros))) +::STMT +MATRIX:X +cast.MATRIX(sum(X)) +::STMT +MATRIX:parsertemp43618,o +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(parsertemp43618,o)))) +::STMT +MATRIX:lambda,parsertemp286535,beta +FLOAT:float296 +LITERAL_FLOAT:0.0 +cast.FLOAT(%*%(t(+(float296,parsertemp286535)),+(0.0,*(lambda,beta)))) +::STMT +FLOAT:Hin +LITERAL_FLOAT:2.0,64.0 +*(64.0,/(/(Hin,2.0),2.0)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +*(cast.FLOAT(diag(scale_X)),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:s,w +FLOAT:step_sz +LITERAL_FLOAT:2.0 +^(+(w,*(step_sz,s)),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +==(rowSums(Y),0.0) +::STMT +MATRIX:Y,the_exp +FLOAT:int549 +-(*(rowSums(Y),exp(*(the_exp,int549))),Y) +::STMT +MATRIX:images +LITERAL_FLOAT:1.0,2.0,255.0 +-(*(/(images,255.0),2.0),1.0) +::STMT +FLOAT:parsertemp459295 +LITERAL_FLOAT:1.0,128.0 +-(+(+(parsertemp459295,1.0),128.0),1.0) +::STMT +MATRIX:negSampleMeans +LITERAL_FLOAT:2.0,150.0 +*(150.0,^(negSampleMeans,2.0)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.0,0.1 +==(<(abs(-(output,output1)),0.1),0.0) +::STMT +MATRIX:_sbcvar1734 +FLOAT:_sbcvar1735 +LITERAL_FLOAT:12.0 +/(_sbcvar1734,-(12.0,_sbcvar1735)) +::STMT +MATRIX:r +FLOAT:int12 +LITERAL_FLOAT:9.999999999999998E-15 +sqrt(*(sum(^(r,int12)),9.999999999999998E-15)) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),max(round(parsertemp2832))) +::STMT +MATRIX:A +-(A,t(A)) +::STMT +MATRIX:X,Y,K +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(*(K,-(X,X)),-(Y,Y)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +rowSums(*(X_samples,%*%(samples_vs_runs_map,%*%(centroid_placer,X_samples)))) +::STMT +MATRIX:R,parsertemp72406,parsertemp72323 +LITERAL_FLOAT:2.0 +sum(^(-(%*%(parsertemp72323,R),diag(parsertemp72406)),2.0)) +::STMT +MATRIX:parsertemp220845,D,ZERODIAG +rowSums(*(*(exp(parsertemp220845),ZERODIAG),D)) +::STMT +MATRIX:X,y +FLOAT:int621,int319 +INT:int505,m +t(-(%*%(X,rand(m,int505,int319,int621)),y)) +::STMT +FLOAT:s_err_mean +LITERAL_FLOAT:-0.001 +-(-0.001,s_err_mean) +::STMT +FLOAT:batch,i,int558 +LITERAL_FLOAT:1.0 ++(+(*(-(i,int558),batch),1.0),batch) +::STMT +MATRIX:parsertemp447181,strings +/(parsertemp447181,length(strings)) +::STMT +FLOAT:a,b +LITERAL_FLOAT:2.0 +/(*(2.0,*(a,b)),+(a,b)) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +%*%(t(X),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:g_Y,lambda,scale_X,parsertemp286673,beta +LITERAL_FLOAT:0.0 ++(*(scale_X,-(0.0,%*%(parsertemp286673,g_Y))),*(lambda,beta)) +::STMT +MATRIX:parsertemp170263,finite_linear_terms,parsertemp170261,the_exp +FLOAT:int120,int98,int745 +LITERAL_FLOAT:1.0 ++(*(-(1.0,==(parsertemp170263,int120)),-(1.0,exp(parsertemp170261))),*(*(==(parsertemp170263,int98),exp(finite_linear_terms)),-(1.0,/(the_exp,int745)))) +::STMT +MATRIX:r,c,E,_sbcvar78 +LITERAL_FLOAT:2.0,10000.0 +sum(/(^(-(_sbcvar78,E),2.0),/(%*%(r,c),10000.0))) +::STMT +MATRIX:X +FLOAT:val +>(X,val) +::STMT +MATRIX:parsertemp498248 +FLOAT:int60,i_process_item +LITERAL_FLOAT:2.0 +*(^(/(-(int60,parsertemp498248),i_process_item),2.0),i_process_item) +::STMT +MATRIX:R,w,ones_ns ++(R,diag(*(ones_ns,cast.FLOAT(w)))) +::STMT +MATRIX:foffb +LITERAL_FLOAT:1.0 +*(ncol(foffb),1.0) +::STMT +MATRIX:selCols2,maxscub +FLOAT:parsertemp31797 +LITERAL_FLOAT:-Infinity +&(selCols2,|(>=(maxscub,parsertemp31797),==(maxscub,-Infinity))) +::STMT +MATRIX:A +FLOAT:a11,a12,int566,int260 +LITERAL_FLOAT:1.0 ++(+(+(/(int260,a11),/(int566,a12)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:t_gp,parsertemp170239 +FLOAT:float801,float299 +LITERAL_FLOAT:1.0,1.421413741,-1.453152027 ++(1.421413741,*(/(1.0,+(float299,parsertemp170239)),+(-1.453152027,*(t_gp,float801)))) +::STMT +MATRIX:275_X,275_curr_X +FLOAT:275_value +&(==(275_X,275_curr_X),>=(275_X,275_value)) +::STMT +MATRIX:float999,is_zero_y_corr,is_one_y_corr,parsertemp317445,parsertemp317451,parsertemp317462 +FLOAT:float898 +LITERAL_FLOAT:1.0 +-(+(*(+(parsertemp317451,parsertemp317462),1-*(float999,parsertemp317445)),/(is_one_y_corr,-(float898,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:vW1,parsertemp146976 +FLOAT:parsertemp146975,191_beta2,191_epsilon ++(sqrt(+(*(191_beta2,vW1),*(parsertemp146975,parsertemp146976))),191_epsilon) +::STMT +MATRIX:X +LITERAL_FLOAT:3.0 +*(ncol(X),3.0) +::STMT +MATRIX:H,betamax,Hneg,Hpos,beta +FLOAT:float761 +LITERAL_FLOAT:0.0,2.0,1.0E20 +*(*(2.0,>=(-(H,float761),0.0)),==(+(*(Hpos,betamax),*(Hneg,beta)),1.0E20)) +::STMT +MATRIX:B +LITERAL_FLOAT:2.0 +*(ncol(B),2.0) +::STMT +MATRIX:parsertemp11251 +LITERAL_FLOAT:2.0 +^(2.0,parsertemp11251) +::STMT +MATRIX:parsertemp220853,Ws,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +>=(-(+(parsertemp220853,*(beta,Ws)),3.4011973816621555),0.0) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0,1764.0 +/(colSums(^(X,2.0)),1764.0) +::STMT +MATRIX:r,scale_X,shift_X ++(*(scale_X,r),%*%(shift_X,r)) +::STMT +MATRIX:_sbcvar1846 +FLOAT:_sbcvar1847 +LITERAL_FLOAT:11.0 +/(_sbcvar1846,-(11.0,_sbcvar1847)) +::STMT +LITERAL_FLOAT:2.0,2000.0 +^(2000.0,2.0) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +*(ncol(B),4.0) +::STMT +MATRIX:parsertemp410118,g0_1,parsertemp410117 +%*%(t(+(g0_1,t(parsertemp410118))),+(g0_1,t(colSums(parsertemp410117)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),p) +::STMT +FLOAT:float778,int543,parsertemp171819,parsertemp171815,sim_score_parent,int9,parsertemp171824,float6 +-(+(/(^(parsertemp171815,int543),+(parsertemp171819,float6)),/(^(parsertemp171824,int9),+(parsertemp171819,float778))),sim_score_parent) +::STMT +MATRIX:lambda,g,beta +sum(*(+(g,*(lambda,beta)),+(g,*(lambda,beta)))) +::STMT +FLOAT:dd,step_sz,wd ++(wd,*(step_sz,dd)) +::STMT +MATRIX:B +LITERAL_FLOAT:8.0 +*(ncol(B),8.0) +::STMT +MATRIX:R,parsertemp40220 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,-(R,rowSums(parsertemp40220))),1.0) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int947 +LITERAL_FLOAT:1.0,7000.0 +/(/(-(colSums(parsertemp31186),*(int947,parsertemp31188)),-(7000.0,1.0)),7000.0) +::STMT +MATRIX:f,parsertemp472177,I,parsertemp472179 +*(I,-(%*%(f,parsertemp472177),t(parsertemp472179))) +::STMT +FLOAT:m2X,W +LITERAL_FLOAT:1.0 +*(m2X,/(W,-(W,1.0))) +::STMT +FLOAT:m2X,parsertemp4,m2Y,parsertemp8,int635,int492 +*(sqrt(*(m2X,/(int492,parsertemp4))),sqrt(*(m2Y,/(int635,parsertemp8)))) +::STMT +MATRIX:y_corr +FLOAT:float657 +LITERAL_FLOAT:1.0,0.5 ++(y_corr,*(-(1.0,*(float657,y_corr)),>(y_corr,0.5))) +::STMT +MATRIX:parsertemp400660,W3_rand +FLOAT:int364,int747 +LITERAL_FLOAT:0.2656844656620286 +%*%(*(0.2656844656620286,W3_rand),t(/(-(parsertemp400660,int747),+(parsertemp400660,int364)))) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +%*%(+(rowSums(classFeatureCounts),*(50.0,1.0)),ones) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27485 +FLOAT:my +LITERAL_FLOAT:2.0 +*(%*%(present_domain_vals_mat,CFreqs1),^(-(%*%(present_domain_vals_mat,parsertemp27485),my),2.0)) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:999.0,1000.0 +*(999.0,/(*(parsertemp13703,1000.0),999.0)) +::STMT +MATRIX:p_LS,X +*(cast.FLOAT(%*%(t(X),X)),cast.FLOAT(p_LS)) +::STMT +MATRIX:cm,FD +FLOAT:int406,n +LITERAL_FLOAT:0.0 +!=(+(+(FD,==(cm,int406)),==(t(cm),n)),0.0) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +t(+(0.0,*(lambda,beta))) +::STMT +MATRIX:r,d,alpha,parsertemp44052,Hd +FLOAT:norm_r2 ++(-(r,*(cast.FLOAT(alpha),Hd)),*(/(sum(parsertemp44052),norm_r2),d)) +::STMT +MATRIX:fdom,parsertemp1688 +-(t(parsertemp1688),fdom) +::STMT +MATRIX:d,parsertemp410054 +FLOAT:r2 +/(r2,cast.FLOAT(%*%(t(d),t(parsertemp410054)))) +::STMT +MATRIX:p,lambda,scale_X,shift_X +FLOAT:q +*(p,+(+(*(scale_X,q),*(q,shift_X)),*(lambda,p))) +::STMT +MATRIX:B2,ytest,Xtest +%*%(t(-(ytest,%*%(Xtest,B2))),-(ytest,%*%(Xtest,B2))) +::STMT +MATRIX:parsertemp43632,X,y +LITERAL_FLOAT:0.0,2.0 ++(0.0,*(2.0,%*%(t(X),*(parsertemp43632,y)))) +::STMT +FLOAT:df,parsertemp437302,n,norm +LITERAL_FLOAT:-2.0 ++(*(*(-2.0,norm),n),*(df,parsertemp437302)) +::STMT +FLOAT:cols,parsertemp451837 +LITERAL_FLOAT:1.0 +-(+(+(*(parsertemp451837,cols),1.0),cols),1.0) +::STMT +MATRIX:codebook +FLOAT:j +*(j,ncol(codebook)) +::STMT +MATRIX:is_LT_infinite +LITERAL_FLOAT:1.0 +-(1.0,rowSums(is_LT_infinite)) +::STMT +LITERAL_FLOAT:1.02 +1.02 +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(1.0,^(linear_terms,/(1.0,link_power))) +::STMT +MATRIX:_sbcvar1830 +FLOAT:_sbcvar1831 +LITERAL_FLOAT:10.0 +/(_sbcvar1830,-(10.0,_sbcvar1831)) +::STMT +MATRIX:tmp +LITERAL_FLOAT:50.0 +*(50.0,cast.FLOAT(%*%(t(tmp),tmp))) +::STMT +MATRIX:y_hat +FLOAT:k,parsertemp176418 ++(sqrt(parsertemp176418),*(k,y_hat)) +::STMT +MATRIX:m_iter_err_sum,parsertemp498242 +FLOAT:i_process_item +LITERAL_FLOAT:0.0 +/(-(0.0,-(t(m_iter_err_sum),+(parsertemp498242,m_iter_err_sum))),i_process_item) +::STMT +LITERAL_FLOAT:1.0E-16 +1.0E-16 +::STMT +FLOAT:D,o +LITERAL_FLOAT:-2.0,-1.0,2.0 ++(*(-2.0,*(o,-1.0)),*(2.0,D)) +::STMT +MATRIX:W,parsertemp411099,X,H +LITERAL_FLOAT:1.0E-8 +/(%*%(t(W),X),+(%*%(%*%(parsertemp411099,W),H),1.0E-8)) +::STMT +MATRIX:g_reg,g,parsertemp285556 +FLOAT:parsertemp285562 +*(cast.FLOAT(%*%(t(g_reg),+(g,parsertemp285556))),parsertemp285562) +::STMT +MATRIX:X2 +LITERAL_FLOAT:4.0 +>=(t(colSums(X2)),4.0) +::STMT +MATRIX:select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev +*(/(%*%(select,X_exp_Xb_rev_agg),D_r_rev),d_r_rev) +::STMT +MATRIX:w,ones_ns +diag(*(ones_ns,cast.FLOAT(w))) +::STMT +MATRIX:g +FLOAT:lambda,beta +LITERAL_FLOAT:2.0 +sum(^(+(g,*(lambda,beta)),2.0)) +::STMT +LITERAL_FLOAT:1.0,500.0 +*(500.0,1.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:1.0 +*(1.0,sum(*(w,s))) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +>=(+(m_active_flag,m_active_flag_tmp),1.0) +::STMT +MATRIX:vW1,190_dW +FLOAT:191_beta2,int129,int49 +sqrt(+(*(191_beta2,vW1),*(-(int129,191_beta2),^(190_dW,int49)))) +::STMT +MATRIX:A +FLOAT:parsertemp12882 +LITERAL_FLOAT:1.0 +*(-(nrow(A),1.0),/(*(parsertemp12882,nrow(A)),-(nrow(A),1.0))) +::STMT +LITERAL_FLOAT:0.08709382882250233 +0.08709382882250233 +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(50.0,1.0))) +::STMT +MATRIX:parsertemp170665,residual_matrix,curr_prediction +FLOAT:282_lambda +LITERAL_FLOAT:2.0 +/(^(sum(residual_matrix),2.0),+(sum(*(curr_prediction,parsertemp170665)),282_lambda)) +::STMT +MATRIX:dY,g,parsertemp221002,Y +FLOAT:float831,float422 +-(+(Y,-(*(float422,dY),*(float831,g))),parsertemp221002) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +*(*(grad,-1.0),*(grad,-1.0)) +::STMT +MATRIX:parsertemp379560,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0 +/(*(-(t(m_iter_err_sum),+(parsertemp379560,m_iter_err_sum)),-1.0),i_process_item) +::STMT +MATRIX:grad +FLOAT:int204,int415 +sqrt(sum(*(*(grad,int204),*(grad,int415)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(1.0,^(linear_terms,2.0)) +::STMT +MATRIX:log_l_part_saturated +LITERAL_FLOAT:2.0 +*(2.0,sum(log_l_part_saturated)) +::STMT +FLOAT:eta,s,parsertemp454319 +*(parsertemp454319,^(eta,s)) +::STMT +MATRIX:output,Mask +LITERAL_FLOAT:1.0 +*(output,-(1.0,Mask)) +::STMT +MATRIX:paramLens,parsertemp387457 +/(parsertemp387457,rev(paramLens)) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +*(parsertemp31268,sum(WM)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +*(2.0,ncol(X)) +::STMT +LITERAL_FLOAT:1.0,2000.0 ++(2000.0,1.0) +::STMT +MATRIX:p,parsertemp1934,parsertemp1935 +FLOAT:eps +cast.FLOAT(%*%(t(p),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:parsertemp410246,parsertemp410249 +FLOAT:float218,int106,int527,float484 +-(max(^(/(parsertemp410246,parsertemp410249),/(int106,float218))),min(^(/(parsertemp410246,parsertemp410249),/(int527,float484)))) +::STMT +MATRIX:XY_pairs_local,XY_pairs +|(XY_pairs,t(XY_pairs_local)) +::STMT +MATRIX:ssX_V,X,parsertemp150463,P_1K,parsertemp149251 +%*%(t(X),-(*(P_1K,%*%(X,ssX_V)),*(P_1K,%*%(parsertemp149251,parsertemp150463)))) +::STMT +MATRIX:parsertemp235671,I,y2 +LITERAL_FLOAT:0.0 +*(-(0.0,/(%*%(I,y2),sum(I))),parsertemp235671) +::STMT +MATRIX:X +FLOAT:N +%*%(t(/(colSums(X),N)),/(colSums(X),N)) +::STMT +MATRIX:parsertemp27746,parsertemp27872 +FLOAT:featureCorrection +-(%*%(parsertemp27872,t(parsertemp27746)),featureCorrection) +::STMT +MATRIX:_sbcvar1798 +FLOAT:_sbcvar1799 +LITERAL_FLOAT:9.0 +/(_sbcvar1798,-(9.0,_sbcvar1799)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,100.0 +-(1.0,/(100.0,num_records)) +::STMT +MATRIX:d,X,logisticD +FLOAT:C +*(C,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:0.0,-2.0 +-(0.0,sqrt(*(-2.0,parsertemp171083))) +::STMT +MATRIX:S,addedE,addedX2 +FLOAT:level +*(==(%*%(S,t(addedX2)),level),t(addedE)) +::STMT +LITERAL_FLOAT:409.0 +409.0 +::STMT +FLOAT:parsertemp40812,m2,int727 +LITERAL_FLOAT:4.0 +^(sqrt(*(/(int727,parsertemp40812),m2)),4.0) +::STMT +FLOAT:int960,parsertemp285740,p_CG,pp_CG,parsertemp285757 +*(parsertemp285757,/(+(*(p_CG,int960),sqrt(parsertemp285740)),pp_CG)) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,2.0 +/(1.0,*(2.0,n)) +::STMT +LITERAL_FLOAT:5.0,2000.0 ++(2000.0,5.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-12 +INT:int829,int420 ++(%*%(t(X),X),diag(rand(int829,int420,1.0E-12,1.0E-12))) +::STMT +MATRIX:is_row_in_samples,parsertemp79018 +LITERAL_FLOAT:3811.0 +-(3811.0,*(is_row_in_samples,parsertemp79018)) +::STMT +MATRIX:W,H +sum(%*%(W,H)) +::STMT +LITERAL_FLOAT:750.0 +750.0 +::STMT +LITERAL_FLOAT:0.08725945907447251 +0.08725945907447251 +::STMT +LITERAL_FLOAT:3.0,2000.0 ++(2000.0,3.0) +::STMT +MATRIX:scores,unnorm_probs,dprobs +rowSums(*(dprobs,/(exp(scores),rowSums(unnorm_probs)))) +::STMT +MATRIX:parsertemp472316,parsertemp472314,ig +FLOAT:min_leaf +max(*(&(>=(parsertemp472314,min_leaf),>=(parsertemp472316,min_leaf)),ig)) +::STMT +FLOAT:FN,FP,TP +*(+(TP,FP),+(TP,FN)) +::STMT +MATRIX:tmp,X,Y,out +t(-(%*%(t(X),*(out,Y)),tmp)) +::STMT +FLOAT:alpha +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(alpha,2.0)) +::STMT +MATRIX:A,B +LITERAL_FLOAT:-1.0,2.0 +^(*(%*%(t(A),B),-1.0),2.0) +::STMT +LITERAL_FLOAT:1.432788 +1.432788 +::STMT +MATRIX:surv +LITERAL_FLOAT:0.5 +sum(<=(surv,0.5)) +::STMT +MATRIX:G,authorities,hubs +-(/(%*%(G,authorities),max(%*%(G,authorities))),hubs) +::STMT +MATRIX:X,parsertemp555606 +LITERAL_FLOAT:1.0 +/(%*%(t(-(X,parsertemp555606)),-(X,parsertemp555606)),-(nrow(X),1.0)) +::STMT +MATRIX:parsertemp42200,F +LITERAL_FLOAT:2.0 +-(parsertemp42200,/(rowSums(F),2.0)) +::STMT +MATRIX:R,parsertemp500307 +FLOAT:int715 +LITERAL_FLOAT:1.0 +INT:int807,int466,parsertemp500306,parsertemp500303 ++(%*%(rowSums(^(R,int715)),rand(int466,parsertemp500303,1.0,1.0)),%*%(rand(parsertemp500306,int807,1.0,1.0),t(rowSums(parsertemp500307)))) +::STMT +MATRIX:parsertemp171117,is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116,float156 +LITERAL_FLOAT:1.0 +-(+(-(parsertemp171113,*(parsertemp171116,parsertemp171117)),/(is_one_y_corr,-(float156,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:parsertemp411207,parsertemp411209,W,parsertemp411198,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 ++(%*%(/(*(W,parsertemp411207),t(parsertemp411209)),/(*(H,parsertemp411198),t(parsertemp411200))),1.0E-8) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +-(subspace_idx,*(parsertemp72201,subvector_size)) +::STMT +MATRIX:p_CG,z +*(cast.FLOAT(z),sum(p_CG)) +::STMT +MATRIX:parsertemp459256 +LITERAL_FLOAT:5.0E-4 +*(5.0E-4,parsertemp459256) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +^(/(t(colSums(X)),nrow(X)),2.0) +::STMT +FLOAT:o +LITERAL_FLOAT:-2.0,-1.0 +*(-2.0,*(o,-1.0)) +::STMT +MATRIX:dout1,vb1 +FLOAT:192_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(192_beta2,vb1),*(-(1.0,192_beta2),^(colSums(dout1),2.0))) +::STMT +MATRIX:X,Y +abs(-(X,Y)) +::STMT +MATRIX:parsertemp10744,W,H +FLOAT:Eps ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),Eps) +::STMT +MATRIX:y_residual,parsertemp415351 +FLOAT:parsertemp415362,n,int152 +LITERAL_FLOAT:1.0 +-(1.0,/(sum(^(y_residual,int152)),-(sum(parsertemp415351),*(n,parsertemp415362)))) +::STMT +MATRIX:parsertemp10740,V,W,H +FLOAT:Eps +/(%*%(t(W),V),+(%*%(%*%(parsertemp10740,W),H),Eps)) +::STMT +MATRIX:in_m_data_target +LITERAL_FLOAT:100.0 +*(-(max(in_m_data_target),min(in_m_data_target)),100.0) +::STMT +MATRIX:parsertemp560919,parsertemp560920,elt,ones_ctg +LITERAL_FLOAT:1.0 +*(/(elt,%*%(rowSums(elt),t(ones_ctg))),%*%(/(elt,%*%(parsertemp560919,parsertemp560920)),-(1.0,diag(ones_ctg)))) +::STMT +MATRIX:termination_bitmap,parsertemp72096 +FLOAT:int497,worst_wcss +LITERAL_FLOAT:1.0,10.0 ++(*(parsertemp72096,termination_bitmap),*(+(*(int497,worst_wcss),10.0),-(1.0,termination_bitmap))) +::STMT +MATRIX:W1_rand,X,parsertemp401984,parsertemp401974 +FLOAT:float690 +LITERAL_FLOAT:0.06835859270246632 +%*%(*(0.06835859270246632,W1_rand),t(/(-(X,parsertemp401974),+(parsertemp401984,float690)))) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:0.0 +-(0.0,/(%*%(I,y2),sum(I))) +::STMT +MATRIX:M +exp(-(M,max(M))) +::STMT +MATRIX:entropy,parsertemp552397,resp,L +*(==(+(resp,t(parsertemp552397)),L),entropy) +::STMT +FLOAT:sd_X +sqrt(sd_X) +::STMT +FLOAT:j +LITERAL_FLOAT:1.0,4.0 ++(-(4.0,j),1.0) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:10000.0 +*(/(10000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +FLOAT:m2Y,sigmaX +LITERAL_FLOAT:1.0005 +*(sigmaX,sqrt(*(m2Y,1.0005))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1,1.0E-12 +*(1.0E-12,+(deviance_nodisp,0.1)) +::STMT +MATRIX:parsertemp410979,W,X,parsertemp410981,parsertemp410983 +FLOAT:eps +*(W,%*%(/(X,+(parsertemp410983,eps)),t(/(parsertemp410979,parsertemp410981)))) +::STMT +FLOAT:n_components,n_features +LITERAL_FLOAT:1.0,2.0 +/(*(*(n_components,n_features),+(n_features,1.0)),2.0) +::STMT +MATRIX:mu +LITERAL_FLOAT:4.0 +*(4.0,*(cast.FLOAT(mu),cast.FLOAT(mu))) +::STMT +MATRIX:p,r,parsertemp1597,lambda,parsertemp1590,parsertemp1589 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp1597)),+(%*%(parsertemp1589,parsertemp1590),*(lambda,p)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:500.0 ++(rowSums(classFeatureCounts),500.0) +::STMT +MATRIX:parsertemp13658,parsertemp13659,_sbcvar12 +FLOAT:44_meanX +LITERAL_FLOAT:999.0,0.5 +*(/(_sbcvar12,999.0),-(+(-(parsertemp13658,parsertemp13659),0.5),44_meanX)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,float674 +LITERAL_FLOAT:2.0 +/(^(linear_terms,/(-(float674,var_power),link_power)),-(2.0,var_power)) +::STMT +FLOAT:int435,int13 +INT:int92,int565 +rand(int565,int92,int435,int13) +::STMT +MATRIX:prec_chol,bc_matrix,parsertemp436690 +FLOAT:int898 +*(bc_matrix,t(*(rowSums(parsertemp436690),^(prec_chol,int898)))) +::STMT +MATRIX:X +FLOAT:q1,q2 +|(<(X,q1),>(X,q2)) +::STMT +FLOAT:ytest,int697,int876,parsertemp454072,parsertemp454076,int481,int619 +LITERAL_FLOAT:1.0 +-(1.0,/(-(^(ytest,int619),*(int481,parsertemp454072)),-(^(ytest,int697),*(int876,parsertemp454076)))) +::STMT +MATRIX:parsertemp477918,b +FLOAT:tolerance +LITERAL_FLOAT:2.0 +*(sum(^(%*%(parsertemp477918,b),2.0)),^(tolerance,2.0)) +::STMT +MATRIX:X +FLOAT:M +*(nrow(X),M) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +<(leaf_ids,+(boundary_left,step_size)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:2.0,3.0 +*(3.0,^(m2,2.0)) +::STMT +MATRIX:curr_prediction +FLOAT:int644 +LITERAL_FLOAT:0.0 ++(sum(*(curr_prediction,-(int644,curr_prediction))),0.0) +::STMT +MATRIX:A,scale_X,shift_X,parsertemp1656,parsertemp1655 ++(%*%(diag(scale_X),t(+(parsertemp1655,parsertemp1656))),%*%(shift_X,A)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0,1.0 ++(*(24.0,-(run_index,1.0)),1.0) +::STMT +FLOAT:acc +LITERAL_FLOAT:1.0,100.0 +-(1.0,/(acc,100.0)) +::STMT +FLOAT:log_ten,d_eee,x,float396 +*(x,exp(*(log_ten,-(float396,d_eee)))) +::STMT +FLOAT:int244,parsertemp459332,int646,parsertemp459334 +LITERAL_FLOAT:2.0 +sqrt(/(2.0,*(*(int244,parsertemp459332),/(parsertemp459334,int646)))) +::STMT +MATRIX:X +FLOAT:N +LITERAL_FLOAT:0.0 +-(0.0,/(t(colSums(X)),N)) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int56,float64 +LITERAL_FLOAT:1.0,1.5 +max(^(/(*(parsertemp410245,int56),*(float64,parsertemp410248)),/(1.0,1.5))) +::STMT +MATRIX:parsertemp429913,avg_X_cols +FLOAT:int179 +LITERAL_FLOAT:300.0,299.0 +/(-(t(colSums(parsertemp429913)),*(300.0,^(avg_X_cols,int179))),299.0) +::STMT +MATRIX:P_denom +LITERAL_FLOAT:0.0 +sum(<=(P_denom,0.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int842,int402 +cast.FLOAT(rand(int842,int402,0.0,1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,2.0 +*(2.0,>=(linear_terms,0.0)) +::STMT +MATRIX:p,q,r,parsertemp1947 +FLOAT:norm_r2,alpha +LITERAL_FLOAT:0.0 ++(-(0.0,+(r,*(alpha,q))),*(/(sum(parsertemp1947),norm_r2),p)) +::STMT +MATRIX:Y_prob,Y +*(*(rowSums(Y),Y_prob),Y_prob) +::STMT +MATRIX:g_new,g_old +/(sum(*(g_new,g_new)),sum(*(g_old,g_old))) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS ++(r_LS,*(/(norm_r2_LS,*(p_LS,p_LS)),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,centroids +FLOAT:int785 ++(X_samples_sq_norms,%*%(samples_vs_runs_map,rowSums(^(centroids,int785)))) +::STMT +FLOAT:e,epochs +LITERAL_FLOAT:1.0 +-(+(1.0,epochs),e) +::STMT +MATRIX:t,parsertemp171083 +FLOAT:float488,float22 +LITERAL_FLOAT:0.802853,2.515517 ++(2.515517,*(sqrt(*(float488,parsertemp171083)),+(0.802853,*(t,float22)))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.07808688094430302 +*(0.07808688094430302,W1_rand) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +>=(abs(-(output1,dataset)),0.16) +::STMT +MATRIX:X,parsertemp438796 +t(*(ncol(X),parsertemp438796)) +::STMT +MATRIX:t,tmp +FLOAT:parsertemp477715,int875,x,X,Y,K +*(cast.FLOAT(t),+(*(-(K,Y),-(int875,parsertemp477715)),*(cast.FLOAT(tmp),/(x,X)))) +::STMT +MATRIX:parsertemp12846,F +FLOAT:W +LITERAL_FLOAT:2.0 +/(^(-(F,/(parsertemp12846,W)),2.0),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:std,rad ++(cast.FLOAT(std),cast.FLOAT(rad)) +::STMT +LITERAL_FLOAT:1.0,2.0,150.0 +*(^(150.0,2.0),-(150.0,1.0)) +::STMT +MATRIX:meanDiff,parsertemp570372,parsertemp570375 +LITERAL_FLOAT:0.5,-0.5 +-(*(-0.5,parsertemp570372),*(0.5,%*%(%*%(meanDiff,parsertemp570375),t(meanDiff)))) +::STMT +MATRIX:parsertemp570372 +LITERAL_FLOAT:-0.5 +*(-0.5,parsertemp570372) +::STMT +MATRIX:parsertemp31912,I +FLOAT:eAvg +/(/(t(%*%(parsertemp31912,I)),t(colSums(I))),eAvg) +::STMT +MATRIX:node +LITERAL_FLOAT:1.0,2.0 ++(*(node,2.0),1.0) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:2.0 +^(+(g,*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:p,Z +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),%*%(Z,p)))) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,2000.0 +-(colSums(^(posSamples,2.0)),*(2000.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:parsertemp170665,residual_matrix,curr_prediction +LITERAL_FLOAT:0.0,2.0 +/(^(sum(residual_matrix),2.0),+(sum(*(curr_prediction,parsertemp170665)),0.0)) +::STMT +MATRIX:m_err_vars,m_err_mean +LITERAL_FLOAT:-0.001 +/(-(-0.001,cast.FLOAT(m_err_mean)),cast.FLOAT(m_err_vars)) +::STMT +MATRIX:S,V +FLOAT:int586,delta2 +LITERAL_FLOAT:2.0 +*(sum(^(V,2.0)),-(delta2,sum(^(S,int586)))) +::STMT +MATRIX:parsertemp389212,parsertemp389215 +LITERAL_FLOAT:2.0,1058.0 +-(parsertemp389215,^(/(parsertemp389212,1058.0),2.0)) +::STMT +MATRIX:avg_res_Y,means,Y_counts,Y +LITERAL_FLOAT:2.0 +colSums(^(-(-(Y,means),%*%(Y_counts,avg_res_Y)),2.0)) +::STMT +FLOAT:w_i +LITERAL_FLOAT:5.0 +-(w_i,5.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116004 +LITERAL_FLOAT:2.0 +^(+(*(scale_X,%*%(parsertemp116004,y)),*(cast.FLOAT(r),shift_X)),2.0) +::STMT +MATRIX:S,X +LITERAL_FLOAT:1.0,2.0 +/(^(diag(S),2.0),-(nrow(X),1.0)) +::STMT +MATRIX:2699_dscores,parsertemp459193,parsertemp459183,parsertemp459190,2703_X,2703_W +LITERAL_FLOAT:5.0E-4 ++(%*%(t(2703_X),*(*(parsertemp459193,parsertemp459190),%*%(2699_dscores,parsertemp459183))),*(5.0E-4,2703_W)) +::STMT +MATRIX:parsertemp285809,p_CG,z +FLOAT:parsertemp285799,2235_sq_root_d,parsertemp285814 +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285809))),*(parsertemp285814,/(+(parsertemp285799,2235_sq_root_d),cast.FLOAT(p_CG)))) +::STMT +FLOAT:obj +LITERAL_FLOAT:1.0E-10 +*(1.0E-10,obj) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +*(*(n_risk,n_event_stratum),-(n_risk_stratum,n_event_stratum)) +::STMT +MATRIX:y_hat,X +*(-(X,y_hat),-(X,y_hat)) +::STMT +FLOAT:n +LITERAL_FLOAT:4.0 +-(n,4.0) +::STMT +MATRIX:X,X_nonzero_ind +LITERAL_FLOAT:0.0 +-(nrow(X),sum(!=(rowSums(X_nonzero_ind),0.0))) +::STMT +MATRIX:X,permut +FLOAT:n +*(/(colSums(%*%(permut,X)),n),/(colSums(%*%(permut,X)),n)) +::STMT +MATRIX:W1_rand,stds,parsertemp401986 +LITERAL_FLOAT:0.06835859270246632 +t(%*%(*(0.06835859270246632,W1_rand),t(/(parsertemp401986,stds)))) +::STMT +MATRIX:X +FLOAT:int416 +LITERAL_FLOAT:1.0 +sqrt(/(colSums(^(X,int416)),-(nrow(X),1.0))) +::STMT +MATRIX:U_OE +rowSums(rowSums(U_OE)) +::STMT +MATRIX:Y,Xd,Xw +FLOAT:step_sz +LITERAL_FLOAT:1.0 +-(1.0,*(Y,+(Xw,*(step_sz,Xd)))) +::STMT +FLOAT:s +LITERAL_FLOAT:-1.0,3.0 +^(3.0,*(s,-1.0)) +::STMT +LITERAL_FLOAT:1.000100010001 +1.000100010001 +::STMT +MATRIX:252_Y,252_K +FLOAT:252_X,float532 +LITERAL_FLOAT:1.0 +*(-(*(cast.FLOAT(252_K),-(252_X,252_X)),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))),-(1.0,/(-(float532,252_X),-(252_X,252_X)))) +::STMT +FLOAT:window_size,parsertemp181047,parsertemp181040 +LITERAL_FLOAT:1.0,2.0 +sqrt(*(*(2.0,window_size),-(1.0,/(parsertemp181040,parsertemp181047)))) +::STMT +MATRIX:b_cumulant,Y,natural_parameters +sum(-(*(Y,natural_parameters),b_cumulant)) +::STMT +LITERAL_FLOAT:0.07808688094430302 +0.07808688094430302 +::STMT +MATRIX:y_corr,is_zero_y_corr +FLOAT:float599,float550,float570,int718 +LITERAL_FLOAT:1.0,0.5 ++(*(*(y_corr,-(float599,is_zero_y_corr)),-(1.0,>=(y_corr,float550))),*(0.5,+(<=(y_corr,int718),>=(y_corr,float570)))) +::STMT +MATRIX:2212_oY +!(2212_oY) +::STMT +MATRIX:parsertemp129475 +LITERAL_FLOAT:1.0,2.0 +-(+(*(max(parsertemp129475),2.0),1.0),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,1024.0 ++(-(1024.0,idx),1.0) +::STMT +MATRIX:resp,mean,X +t(*(-(X,mean),resp)) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +LITERAL_FLOAT:2.0 +^(+(r_CG,*(alpha_CG,cast.FLOAT(q_CG))),2.0) +::STMT +MATRIX:221_CFreqs,221_present_domain_vals_mat,parsertemp27770 +FLOAT:int792 +LITERAL_FLOAT:1000.0 +/(sum(*(-(221_CFreqs,int792),%*%(221_present_domain_vals_mat,parsertemp27770))),-(1000.0,nrow(221_present_domain_vals_mat))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0 +<(0.0,Xtest_dists) +::STMT +MATRIX:selCols,ncCnts,maxsc +FLOAT:parsertemp31781 +LITERAL_FLOAT:0.0 +&(selCols,|(>(ncCnts,0.0),>(maxsc,parsertemp31781))) +::STMT +MATRIX:b,X,sb +*(X,exp(%*%(X,+(b,sb)))) +::STMT +MATRIX:R,addedE,parsertemp40215 +FLOAT:level ++(R,rowSums(*(==(parsertemp40215,level),t(addedE)))) +::STMT +FLOAT:step +LITERAL_FLOAT:0.9 +*(step,0.9) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,2.0 +/(-(0.0,^(finite_linear_terms,2.0)),2.0) +::STMT +MATRIX:Q1,IQR +LITERAL_FLOAT:1.5 +-(Q1,*(1.5,IQR)) +::STMT +LITERAL_FLOAT:1.0E-6 +1.0E-6 +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0,2.0 +^(/(cast.FLOAT(ytest),1.0),2.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,-1.0 +-(1.0,exp(*(exp(finite_linear_terms),-1.0))) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:2.0 +^(/(sum(-(ytest,yhat)),nrow(ytest)),2.0) +::STMT +MATRIX:CVars,CFreqs +LITERAL_FLOAT:1.0 +sum(*(-(CFreqs,1.0),CVars)) +::STMT +FLOAT:window_size,i,k ++(+(i,k),window_size) +::STMT +MATRIX:ss,X2 +/(nrow(X2),ss) +::STMT +MATRIX:X +LITERAL_FLOAT:3.0 +*(3.0,ncol(X)) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,grad),2.0)) +::STMT +MATRIX:parsertemp129475,groupIndex +*(groupIndex,max(parsertemp129475)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,48.0 +*(48.0,-(run_index,1.0)) +::STMT +MATRIX:scale_X,shift_X,X +FLOAT:int959 +LITERAL_FLOAT:2.0 ++(%*%(^(X,2.0),^(scale_X,2.0)),%*%(X,*(*(int959,scale_X),shift_X))) +::STMT +MATRIX:parsertemp171083 +FLOAT:float680 +LITERAL_FLOAT:0.010328,0.802853 ++(0.802853,*(sqrt(*(float680,parsertemp171083)),0.010328)) +::STMT +MATRIX:g +LITERAL_FLOAT:0.01 +*(0.01,cast.FLOAT(%*%(t(g),g))) +::STMT +LITERAL_FLOAT:1.0,2.0,2001.0 +-(^(2001.0,2.0),1.0) +::STMT +MATRIX:parsertemp539203,T,event +FLOAT:int620 +LITERAL_FLOAT:1.0,2.0,1.5 +/(^(/(*(parsertemp539203,int620),2.0),/(1.0,1.5)),/(-(max(T),min(T)),sum(event))) +::STMT +FLOAT:int263 +LITERAL_FLOAT:0.0,2.0 +INT:int475,parsertemp282730 +>(rand(parsertemp282730,int475,int263,2.0),0.0) +::STMT +MATRIX:p,q,g,z +FLOAT:pq,float62,tau_1 ++(+(*(*(float62,tau_1),pq),sum(*(z,q))),sum(*(g,p))) +::STMT +MATRIX:prob,pred +FLOAT:threshold +*(pred,>(prob,threshold)) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int94,int771,int10,int37 +sum(*(*(>(out,int771),-(int94,parsertemp2798)),*(>(out,int37),-(int10,parsertemp2798)))) +::STMT +MATRIX:Q,R +FLOAT:int517 +LITERAL_FLOAT:2.0 ++(rowSums(^(R,2.0)),t(rowSums(^(Q,int517)))) +::STMT +FLOAT:float812,parsertemp382948,parsertemp382957,loss_init,parsertemp382950 +/(-(loss_init,+(*(float812,parsertemp382948),*(parsertemp382950,parsertemp382957))),loss_init) +::STMT +MATRIX:n_corr,Y +FLOAT:int495 +LITERAL_FLOAT:0.0,0.5 ++(/(Y,+(rowSums(Y),==(n_corr,int495))),*(-(0.5,Y),==(rowSums(Y),0.0))) +::STMT +FLOAT:level +LITERAL_FLOAT:2.0 +-(level,2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),-(1.0,var_power)) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:2000.0 +*(/(2000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +MATRIX:X +/(colSums(X),nrow(X)) +::STMT +MATRIX:parsertemp389300 +LITERAL_FLOAT:1.0,2.0 ++(exp(*(2.0,t(parsertemp389300))),1.0) +::STMT +LITERAL_FLOAT:5.0E-7 +5.0E-7 +::STMT +MATRIX:r +FLOAT:tolerance +LITERAL_FLOAT:2.0 +*(sum(^(r,2.0)),^(tolerance,2.0)) +::STMT +MATRIX:parsertemp42223,parsertemp42224,parsertemp42209 +FLOAT:parsertemp42210,meanY +sum(*(t(*(parsertemp42223,parsertemp42224)),-(+(parsertemp42209,parsertemp42210),meanY))) +::STMT +MATRIX:2134_left,2134_right +LITERAL_FLOAT:0.0,2.0 ++(/(^(sum(2134_left),2.0),+(nrow(2134_left),0.0)),/(^(sum(2134_right),2.0),+(nrow(2134_right),0.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 +-(i,1.0) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int178 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int178,parsertemp2798),0.0),-(1.0,*(Y,Xw))),Y) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(500.0,1.0))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +LITERAL_FLOAT:10.0 +*(10.0,max(*(parsertemp222665,termination_bitmap))) +::STMT +MATRIX:U,V +LITERAL_FLOAT:2.0 ++(sum(^(U,2.0)),sum(^(V,2.0))) +::STMT +MATRIX:instance,X,mask +*(-(X,instance),mask) +::STMT +MATRIX:X,parsertemp129018 +LITERAL_FLOAT:1.0 ++(*(max(parsertemp129018),-(ncol(X),1.0)),1.0) +::STMT +MATRIX:parsertemp220900,dY,parsertemp220899 +FLOAT:lr,momentum +LITERAL_FLOAT:2.0 +^(-(*(momentum,dY),*(lr,-(parsertemp220899,parsertemp220900))),2.0) +::STMT +MATRIX:parsertemp175066,scores,dprobs +*(dprobs,/(exp(-(scores,parsertemp175066)),rowSums(exp(scores)))) +::STMT +MATRIX:solution,X +sum(*(-(X,solution),-(X,solution))) +::STMT +MATRIX:Q,lambda,V,X,parsertemp149253 +*(V,+(%*%(t(X),-(Q,parsertemp149253)),*(lambda,V))) +::STMT +MATRIX:r,alpha,Hd +*(-(r,*(cast.FLOAT(alpha),Hd)),-(r,*(cast.FLOAT(alpha),Hd))) +::STMT +MATRIX:G,minDist +LITERAL_FLOAT:0.0 ++(G,*(!=(G,0.0),minDist)) +::STMT +MATRIX:parsertemp12846,F +FLOAT:W +LITERAL_FLOAT:2.0 +/(^(-(F,/(parsertemp12846,W)),2.0),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:parsertemp409532,ctab,parsertemp409528 +LITERAL_FLOAT:0.4 +*(parsertemp409532,>(/(parsertemp409528,rowSums(ctab)),0.4)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(1.0,-(1.0,^(linear_terms,2.0))) +::STMT +MATRIX:w,g +FLOAT:alpha +abs(-(w,/(g,alpha))) +::STMT +MATRIX:Xtrain,Xtest,X,Y +-(+(sum(X),sum(Y)),+(sum(Xtrain),sum(Xtest))) +::STMT +MATRIX:parsertemp42200,R +FLOAT:int137,meanX +LITERAL_FLOAT:0.5 +-(+(-(parsertemp42200,/(R,int137)),0.5),meanX) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +cast.FLOAT(%*%(t(lambda),^(newbeta,2.0))) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08681986202598489 +*(0.08681986202598489,W4_rand) +::STMT +MATRIX:scale_X +cast.FLOAT(diag(scale_X)) +::STMT +MATRIX:q,r +FLOAT:alpha +LITERAL_FLOAT:2.0 +sum(^(+(r,*(alpha,q)),2.0)) +::STMT +MATRIX:_sbcvar12,parsertemp13660 +FLOAT:float545,44_meanX +LITERAL_FLOAT:999.0 +t(*(/(_sbcvar12,999.0),-(+(parsertemp13660,float545),44_meanX))) +::STMT +MATRIX:2701_mask,2700_W,2726_dpred,parsertemp459177,2699_probs,2702_X +LITERAL_FLOAT:0.0,0.5 +*(*(>(2702_X,0.0),/(2701_mask,0.5)),%*%(-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177)),t(2700_W))) +::STMT +MATRIX:std,sts,rad +FLOAT:delta2 +/(-(delta2,sts),+(cast.FLOAT(std),cast.FLOAT(rad))) +::STMT +MATRIX:w,out +LITERAL_FLOAT:0.5,0.001 +*(0.001,+(*(0.5,cast.FLOAT(out)),*(0.5,cast.FLOAT(w)))) +::STMT +MATRIX:A +FLOAT:parsertemp12882 +LITERAL_FLOAT:1.0 +/(*(parsertemp12882,nrow(A)),-(nrow(A),1.0)) +::STMT +MATRIX:eVals,eVecs +FLOAT:int192 +%*%(%*%(eVecs,diag(^(eVals,int192))),t(eVecs)) +::STMT +MATRIX:log_det_chol +FLOAT:int840,int149 +INT:int669,parsertemp436708 +*(rand(int669,parsertemp436708,int149,int840),log_det_chol) +::STMT +MATRIX:b,H_inv +/(b,sqrt(diag(H_inv))) +::STMT +FLOAT:alpha +LITERAL_FLOAT:1.0 +-(1.0,alpha) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +t(rowSums(^(X,2.0))) +::STMT +MATRIX:X +FLOAT:lambda +LITERAL_FLOAT:2.0,0.5 +*(*(0.5,lambda),sum(^(X,2.0))) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +LITERAL_FLOAT:2.0 +-(/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev),/(*(X_exp_Xb_rev_agg,%*%(select,Xd_exp_Xb_rev_agg)),^(D_r_rev,2.0))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:0.0,1.0 +*(linear_terms,-(1.0,==(Y,0.0))) +::STMT +MATRIX:C,I +FLOAT:ss ++(%*%(t(C),C),*(I,ss)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.8378770664093453 +*(ncol(X),1.8378770664093453) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int180 +LITERAL_FLOAT:1999.0,2.0 +^(/(-(colSums(parsertemp31104),*(int180,parsertemp31106)),1999.0),2.0) +::STMT +LITERAL_FLOAT:-0.284496736 +-0.284496736 +::STMT +FLOAT:width +LITERAL_FLOAT:2.0 +*(2.0,^(width,2.0)) +::STMT +MATRIX:parsertemp560880,parsertemp560876,parsertemp560863,parsertemp560868 +FLOAT:float715,float721,int346,int38 +LITERAL_FLOAT:1.0,2.0 +*(*(*(/(float715,parsertemp560868),+(float721,parsertemp560876)),-(*(int346,parsertemp560863),1.0)),exp(/(*(parsertemp560880,int38),2.0))) +::STMT +LITERAL_FLOAT:0.45 +0.45 +::STMT +MATRIX:COMPONENTS,id +-(==(id,cast.FLOAT(id)),cast.FLOAT(diag(diag(COMPONENTS)))) +::STMT +MATRIX:parsertemp130875 +LITERAL_FLOAT:1.0,4.0 +-(+(*(max(parsertemp130875),4.0),1.0),1.0) +::STMT +FLOAT:int84,se_g1,int223,int467,int512,parsertemp113,wt +sqrt(/(*(*(int512,parsertemp113),^(se_g1,int84)),*(+(wt,int467),-(wt,int223)))) +::STMT +MATRIX:gs +LITERAL_FLOAT:-0.5 +*(-0.5,cast.FLOAT(gs)) +::STMT +MATRIX:s,parsertemp44016,d +cast.FLOAT(%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,parsertemp222444,is_row_in_samples,parsertemp222440 +LITERAL_FLOAT:2.0 +*(is_row_in_samples,-(+(X_samples_sq_norms,%*%(samples_vs_runs_map,parsertemp222440)),*(2.0,rowSums(parsertemp222444)))) +::STMT +MATRIX:X,parsertemp16892 +FLOAT:int275 +%*%(sqrt(rowSums(^(X,int275))),t(sqrt(rowSums(parsertemp16892)))) +::STMT +MATRIX:eVals +LITERAL_FLOAT:-1.0 +diag(^(eVals,-1.0)) +::STMT +MATRIX:s,w,wnew,parsertemp44079 +FLOAT:int330,C +LITERAL_FLOAT:0.5 ++(*(0.5,%*%(t(wnew),+(w,s))),*(C,sum(*(parsertemp44079,int330)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 ++(-(nrow(X),sum(>=(X,x))),1.0) +::STMT +MATRIX:parsertemp503368,B +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(%*%(parsertemp503368,B),-1.0),2.0)) +::STMT +LITERAL_FLOAT:0.9 +0.9 +::STMT +MATRIX:g0_2,g0_1 +FLOAT:tol +LITERAL_FLOAT:2.0 +*(sum(^(+(g0_1,g0_2),2.0)),^(tol,2.0)) +::STMT +FLOAT:wcss +LITERAL_FLOAT:1.0E-5 +*(1.0E-5,wcss) +::STMT +MATRIX:WM,CVars,parsertemp31290,CFreqs,parsertemp31285 +LITERAL_FLOAT:1.0 +/(/(sum(*(CFreqs,parsertemp31285)),-(nrow(CFreqs),1.0)),/(sum(*(parsertemp31290,CVars)),-(sum(WM),nrow(CFreqs)))) +::STMT +MATRIX:q,ssX_p,scale_X,shift_X,X ++(*(scale_X,%*%(t(X),%*%(X,ssX_p))),*(cast.FLOAT(q),shift_X)) +::STMT +MATRIX:mean,X,parsertemp437224,weight +/(-(%*%(t(X),X),%*%(*(parsertemp437224,weight),mean)),sum(weight)) +::STMT +MATRIX:parsertemp43635,w +sqrt(sum(*(+(w,parsertemp43635),+(w,parsertemp43635)))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170148,int378,z,int271 +LITERAL_FLOAT:0.5 +*(0.5,/(-(*(z,int378),sqrt(parsertemp170148)),sum(^(p_CG,int271)))) +::STMT +MATRIX:252_Y +FLOAT:252_X,int54,int127,parsertemp32925,int877,parsertemp32915,float189,parsertemp32934,float807 ++(+(*(-(int877,parsertemp32915),cast.FLOAT(252_Y)),*(/(float807,252_X),cast.FLOAT(252_Y))),*(*(/(float189,252_X),-(int54,parsertemp32915)),+(*(parsertemp32925,int127),*(parsertemp32934,parsertemp32915)))) +::STMT +LITERAL_FLOAT:2.29128784747792 +2.29128784747792 +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936,outr2 +%*%(t(outr2),-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp)))) +::STMT +MATRIX:252_Y,252_X +LITERAL_FLOAT:4.5 +*(/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),cast.FLOAT(252_Y)) +::STMT +MATRIX:parsertemp16755 +LITERAL_FLOAT:2.0 +^(2.0,cast.FLOAT(parsertemp16755)) +::STMT +MATRIX:WM,CFreqs +-(sum(WM),nrow(CFreqs)) +::STMT +MATRIX:IQR +LITERAL_FLOAT:1.5 +*(1.5,IQR) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:0.5 +*(0.5,sum(*(*(sv,out),*(sv,out)))) +::STMT +MATRIX:W1_rand,X,parsertemp394884,parsertemp394894 +FLOAT:float244 +LITERAL_FLOAT:0.08146881698903526 +%*%(*(0.08146881698903526,W1_rand),t(/(-(X,parsertemp394884),+(parsertemp394894,float244)))) +::STMT +MATRIX:u,parsertemp500604 +FLOAT:alpha,tau +LITERAL_FLOAT:0.0 +*(*(parsertemp500604,-(abs(u),/(tau,alpha))),>(-(abs(u),/(tau,alpha)),0.0)) +::STMT +MATRIX:V,W,H,parsertemp10749 +FLOAT:Eps +*(W,/(%*%(V,t(H)),+(%*%(W,parsertemp10749),Eps))) +::STMT +FLOAT:e,int622,mu,epochs +LITERAL_FLOAT:0.999 ++(mu,/(-(0.999,mu),-(+(int622,epochs),e))) +::STMT +MATRIX:hubs +FLOAT:parsertemp30953 +LITERAL_FLOAT:2.0 +sum(^(-(/(hubs,parsertemp30953),hubs),2.0)) +::STMT +MATRIX:_funvar2124,parsertemp437267,parsertemp437272 +exp(-(+(_funvar2124,parsertemp437267),parsertemp437272)) +::STMT +MATRIX:q_CG +FLOAT:alpha_CG +*(alpha_CG,cast.FLOAT(q_CG)) +::STMT +FLOAT:n_features +LITERAL_FLOAT:1.0,2.0 +/(*(n_features,+(n_features,1.0)),2.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int496,int127,int944,int391 +LITERAL_FLOAT:2.0,3352500.0,990000.0 +/(^(+(/(posSampleVariances,int496),/(negSampleVariances,int127)),2.0),+(/(^(posSampleVariances,int391),990000.0),/(^(negSampleVariances,int944),3352500.0))) +::STMT +MATRIX:avg_X_cols,parsertemp1513 +FLOAT:int956,n +LITERAL_FLOAT:1.0 +/(-(t(colSums(parsertemp1513)),*(n,^(avg_X_cols,int956))),-(n,1.0)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:1.0,3.0 +-(+(3.0,n_group_cols),1.0) +::STMT +FLOAT:float725,float58 +LITERAL_FLOAT:0.0,0.5 +INT:int612,int943,int668,int29 +*(0.5,%*%(t(rand(int943,int612,float725,float58)),rand(int29,int668,0.0,0.0))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1,1.0E-6 +*(1.0E-6,+(deviance_nodisp,0.1)) +::STMT +MATRIX:tmp_Xw,Y,parsertemp2773,Xw +LITERAL_FLOAT:0.0,1.0 +*(-(1.0,*(Y,+(Xw,parsertemp2773))),>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:parsertemp410976,W,H,X +/(*(H,%*%(t(W),/(X,parsertemp410976))),t(colSums(W))) +::STMT +MATRIX:surv +LITERAL_FLOAT:1.0 +sqrt(-(1.0,surv)) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,2.0,0.6666666666666666 +^(/(*(parsertemp539203,-1.0),2.0),0.6666666666666666) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,253.0 +-(+(i,253.0),1.0) +::STMT +MATRIX:r,c,F +LITERAL_FLOAT:2.0 +^(-(F,/(%*%(r,c),sum(F))),2.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:50.0 ++(rowSums(classFeatureCounts),50.0) +::STMT +MATRIX:parsertemp389219,X,permut +FLOAT:parsertemp389220,n +LITERAL_FLOAT:1.0E-17 +/(-(%*%(permut,X),/(colSums(X),n)),+(sqrt(/(parsertemp389219,parsertemp389220)),1.0E-17)) +::STMT +MATRIX:r,c,E,F +LITERAL_FLOAT:2.0 +sum(/(^(-(F,E),2.0),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:W +FLOAT:parsertemp112,int710,parsertemp91 +LITERAL_FLOAT:2.0,3.0,4.0,5.0 +/(*(*(4.0,-(parsertemp112,int710)),^(sqrt(parsertemp91),2.0)),*(+(sum(W),5.0),-(sum(W),3.0))) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:1000.0 +/(classCounts,1000.0) +::STMT +MATRIX:P,I +LITERAL_FLOAT:1.0 +&(I,<=(rowSums(P),1.0)) +::STMT +FLOAT:beg +LITERAL_FLOAT:1.0,256.0 +-(+(beg,256.0),1.0) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +t(/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:40.0 +/(/(se,ss),/(sum(e),40.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,/(t(colSums(X)),nrow(X))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +*(/(-(x,X),-(X,X)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0 ++(rowSums(classFeatureCounts),105.0) +::STMT +MATRIX:m_iter_err_sum,m_err ++(colSums(m_err),m_iter_err_sum) +::STMT +MATRIX:w,out +FLOAT:int362,int565 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(^(out,int362))),*(0.5,sum(^(w,int565)))) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +<(rowSums(M),2.0) +::STMT +LITERAL_FLOAT:1.0,100.0 ++(+(100.0,100.0),1.0) +::STMT +MATRIX:2663_X +LITERAL_FLOAT:1.0 +*(1.0,ncol(2663_X)) +::STMT +MATRIX:Ileft,_funvar2707 +FLOAT:numI +*(/(rowSums(Ileft),numI),_funvar2707) +::STMT +MATRIX:parsertemp31276,CVars +FLOAT:int850,parsertemp31269,W,parsertemp31270 +LITERAL_FLOAT:1.0 +-(1.0,/(sum(*(parsertemp31276,CVars)),*(-(W,int850),/(parsertemp31269,parsertemp31270)))) +::STMT +LITERAL_FLOAT:0.010328 +0.010328 +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,beta +FLOAT:int935 +LITERAL_FLOAT:2.0,1.0E20 +*(*(*(2.0,>=(Hdiff,int935)),==(+(parsertemp220863,parsertemp220864),1.0E20)),beta) +::STMT +MATRIX:d,sb +LITERAL_FLOAT:2.0 +*(2.0,sum(*(sb,d))) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:int867,int372 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 ++(/(^(/(parsertemp31190,int372),2.0),3.42951E11),/(^(/(parsertemp31197,int867),2.0),3.37275E9)) +::STMT +LITERAL_FLOAT:2.0,100.0 ++(+(100.0,100.0),2.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,-1.0 +-(0.0,^(linear_terms,/(-1.0,link_power))) +::STMT +LITERAL_FLOAT:0.0 +/(0.0,0.0) +::STMT +MATRIX:LT,Y,parsertemp149320 +*(Y,-(LT,parsertemp149320)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int113,int998 +LITERAL_FLOAT:99.0,100.0 +/(-(colSums(^(posSamples,int998)),*(100.0,^(posSampleMeans,int113))),99.0) +::STMT +MATRIX:cumLeftHist,parsertemp132494,leftHist,outBucket +%*%(==(outBucket,t(parsertemp132494)),-(cumLeftHist,leftHist)) +::STMT +MATRIX:U +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,U) +::STMT +MATRIX:D,ZERODIAG +FLOAT:int802 +LITERAL_FLOAT:1.0 +sum(*(/(1.0,+(D,int802)),ZERODIAG)) +::STMT +MATRIX:_sbcvar0 +LITERAL_FLOAT:2000.0 +/(_sbcvar0,2000.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-14 +>(abs(-(X,round(X))),1.0E-14) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:1.0 ++(+(ncol(X),ncol(Y)),1.0) +::STMT +MATRIX:lambda,beta +*(cast.FLOAT(lambda),cast.FLOAT(beta)) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:20.0 +/(/(se,ss),/(sum(e),20.0)) +::STMT +MATRIX:parsertemp43632,X,y +LITERAL_FLOAT:0.0,2.0 +INT:int440,int584 ++(rand(int440,int584,0.0,0.0),*(2.0,%*%(t(X),*(parsertemp43632,y)))) +::STMT +MATRIX:parsertemp477718,parsertemp477715,parsertemp477724,X,Y,parsertemp477733,K,parsertemp477730 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(*(K,parsertemp477724),-(Y,Y)),-(1.0,/(parsertemp477715,parsertemp477718))),*(+(*(parsertemp477730,parsertemp477733),-(Y,Y)),/(-(x,X),-(X,X)))) +::STMT +MATRIX:R,dssm +FLOAT:2_n +/(2_n,-(R,dssm)) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j +FLOAT:I_i1i2 +-(I_i1i2,/(n_risk_i2j,n_risk_stratum)) +::STMT +MATRIX:parsertemp410978,W,H +t(/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +FLOAT:parsertemp89,parsertemp88,parsertemp83,parsertemp84 +LITERAL_FLOAT:2.0 +^(sqrt(/(*(parsertemp83,parsertemp84),*(parsertemp88,parsertemp89))),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +<=(Y,0.0) +::STMT +MATRIX:parsertemp383173 +FLOAT:reg,parsertemp383181,loss_init +/(-(loss_init,+(sum(parsertemp383173),*(reg,parsertemp383181))),loss_init) +::STMT +MATRIX:parsertemp437549,pred,parsertemp437666 +t(colSums(==(*(parsertemp437666,parsertemp437549),pred))) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40225 +FLOAT:level +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),-(R,rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +max(*(parsertemp222665,termination_bitmap)) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 ++(sqrt(parsertemp176418),*(3.0,+(%*%(features,beta_unscaled),intercept))) +::STMT +MATRIX:id +diag(diag(==(id,cast.FLOAT(id)))) +::STMT +MATRIX:parsertemp145796,parsertemp145794,y +/(sum(rowSums(*(parsertemp145794,parsertemp145796))),nrow(y)) +::STMT +MATRIX:Xd,out +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(out)),+(dd,sum(Xd))) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +<=(X,4.0) +::STMT +MATRIX:X,y,logisticnew +LITERAL_FLOAT:1.0 +%*%(t(X),*(-(logisticnew,1.0),y)) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:8000.0 +/(classCounts,8000.0) +::STMT +MATRIX:parsertemp570381,parsertemp570372,parsertemp570376,parsertemp570377 +LITERAL_FLOAT:0.5,-0.5 ++(parsertemp570381,-(*(-0.5,parsertemp570372),*(0.5,%*%(parsertemp570376,parsertemp570377)))) +::STMT +MATRIX:parsertemp31762,X2 +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>=(t(colSums(X2)),minSup),>(t(%*%(parsertemp31762,X2)),0.0)) +::STMT +MATRIX:parsertemp220896,W,Y,Z +FLOAT:lr +*(lr,-(*(Y,rowSums(W)),%*%(*(parsertemp220896,Z),Y))) +::STMT +MATRIX:X +FLOAT:N +t(/(colSums(X),N)) +::STMT +MATRIX:classesUnBalanced,classesBalanced +cast.FLOAT(-(classesUnBalanced,classesBalanced)) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,7000.0 +*(7000.0,^(posSampleMeans,2.0)) +::STMT +MATRIX:r,c,E,F +LITERAL_FLOAT:2.0 +sum(/(^(-(F,E),2.0),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:scale_X,X,z,beta +*(*(cast.FLOAT(diag(scale_X)),+(cast.FLOAT(beta),cast.FLOAT(z))),X) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +-(0.0,%*%(-(0.0,t(X)),y)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +^(linear_terms,/(1.0,link_power)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0 +^(linear_terms,/(-1.0,link_power)) +::STMT +MATRIX:V1,parsertemp539081 +FLOAT:range,I_i1i2 +LITERAL_FLOAT:2.0 +/(sum(*(V1,-(I_i1i2,parsertemp539081))),^(range,2.0)) +::STMT +MATRIX:surv +LITERAL_FLOAT:0.5 +<=(surv,0.5) +::STMT +MATRIX:parsertemp410070,r +FLOAT:r2 +/(cast.FLOAT(%*%(t(r),+(r,parsertemp410070))),r2) +::STMT +MATRIX:W +FLOAT:parsertemp65,parsertemp66 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),^(sqrt(/(parsertemp65,parsertemp66)),3.0)) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +sum(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum)))) +::STMT +MATRIX:R,dsep,dssm +FLOAT:2_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),-(R,dssm)),2_eAvg),1.0) +::STMT +MATRIX:pred +LITERAL_FLOAT:1.0,1.0E-10 +/(1.0,+(pred,1.0E-10)) +::STMT +MATRIX:t_gp,parsertemp171332,pt_gp,parsertemp171331,Y,the_gauss_exp,parsertemp171327,parsertemp171316 +LITERAL_FLOAT:2.0,0.25,0.3989422804014327 +/(*(0.3989422804014327,+(-(Y,parsertemp171327),*(parsertemp171331,parsertemp171332))),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:p,r,parsertemp503395,Z +FLOAT:norm_r2 ++(r,*(/(norm_r2,cast.FLOAT(parsertemp503395)),%*%(Z,p))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:1.0 +<=(Xtest_dists,1.0) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +*(-1.0,sum(parsertemp43626)) +::STMT +MATRIX:parsertemp415524,y +FLOAT:intercept +LITERAL_FLOAT:2.0 +sum(^(-(y,+(parsertemp415524,intercept)),2.0)) +::STMT +MATRIX:parsertemp279509 +FLOAT:int374 +LITERAL_FLOAT:1000.0,100.0 +*(/(sum(==(parsertemp279509,int374)),1000.0),100.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +rowSums(!=(X,0.0)) +::STMT +MATRIX:col_nonzeros,U,parsertemp382849,V,parsertemp382852 +LITERAL_FLOAT:1.0E-6 ++(t(%*%(t(U),*(parsertemp382849,parsertemp382852))),*(*(1.0E-6,V),col_nonzeros)) +::STMT +MATRIX:R,parsertemp72406 +-(%*%(t(R),R),diag(parsertemp72406)) +::STMT +FLOAT:log_ten,parsertemp169812 +LITERAL_FLOAT:0.5 +round(-(/(parsertemp169812,log_ten),0.5)) +::STMT +MATRIX:W,X,H +FLOAT:eps +%*%(t(W),/(X,+(%*%(W,H),eps))) +::STMT +MATRIX:is_LT_infinite,Y_prob,Y,parsertemp171293,flip_pos +rowSums(*(*(Y,%*%(Y_prob,flip_pos)),+(*(Y_prob,parsertemp171293),is_LT_infinite))) +::STMT +MATRIX:prevTK2,X2 +colSums(==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2)))) +::STMT +MATRIX:lambda,p_CG,shift_X,parsertemp170070,temp_CG +*(p_CG,+(+(*(lambda,p_CG),%*%(parsertemp170070,temp_CG)),%*%(shift_X,temp_CG))) +::STMT +LITERAL_FLOAT:2001.0 +sqrt(2001.0) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08709382882250233 +*(0.08709382882250233,W4_rand) +::STMT +MATRIX:parsertemp414371,scale_X +LITERAL_FLOAT:0.0,200.0 +*(-(0.0,/(t(parsertemp414371),200.0)),scale_X) +::STMT +MATRIX:r,c,_sbcvar78 +LITERAL_FLOAT:2.0,10000.0 +^(-(_sbcvar78,/(%*%(r,c),10000.0)),2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power,float434 +LITERAL_FLOAT:1.0 +/(exp(*(linear_terms,-(float434,var_power))),-(1.0,var_power)) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +*(X_samples,%*%(samples_vs_runs_map,%*%(centroid_placer,X_samples))) +::STMT +FLOAT:int463,parsertemp40812,m2 +LITERAL_FLOAT:3.0 +^(sqrt(*(/(int463,parsertemp40812),m2)),3.0) +::STMT +MATRIX:X_nonzero_ind +LITERAL_FLOAT:0.0 +sum(!=(t(colSums(X_nonzero_ind)),0.0)) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:parsertemp31266,W +LITERAL_FLOAT:2.0 +*(CFreqs,^(-(CMeans,/(parsertemp31266,W)),2.0)) +::STMT +MATRIX:parsertemp386449,corePts +FLOAT:int440 +LITERAL_FLOAT:0.0,1.0 +&(==(t(corePts),0.0),>(colSums(>(parsertemp386449,int440)),1.0)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +>=(abs(-(output,output1)),0.1) +::STMT +MATRIX:codes,codebook +*(ncol(codes),ncol(codebook)) +::STMT +MATRIX:p_LS,parsertemp170551,X +FLOAT:lambda_LS +*(p_LS,+(%*%(%*%(parsertemp170551,X),p_LS),*(lambda_LS,p_LS))) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1.0,1000.0 +*(-(1000.0,1.0),/(*(parsertemp13703,1000.0),-(1000.0,1.0))) +::STMT +MATRIX:xs +FLOAT:252_x +LITERAL_FLOAT:1.0,10.0 ++(-(10.0,sum(>=(xs,252_x))),1.0) +::STMT +MATRIX:parsertemp1510,scale_X +FLOAT:n +LITERAL_FLOAT:-1.0 +*(*(/(t(parsertemp1510),n),-1.0),scale_X) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int842,int402,int569,int812 +LITERAL_FLOAT:2.0,3352500.0,990000.0 +/(^(+(/(posSampleVariances,int569),/(negSampleVariances,int842)),2.0),+(/(^(posSampleVariances,int402),990000.0),/(^(negSampleVariances,int812),3352500.0))) +::STMT +LITERAL_FLOAT:0.189269 +0.189269 +::STMT +FLOAT:k,kmax,start_stepsize +LITERAL_FLOAT:1.0 +*(-(1.0,/(k,kmax)),start_stepsize) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq,pp_CG +-(*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(%*%(p_CG,z))),*(pp_CG,-(cast.FLOAT(z),trust_delta_sq))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,1048.0 ++(-(1048.0,idx),1.0) +::STMT +LITERAL_FLOAT:0.0,0.025 +INT:parsertemp410939,rnk +rand(parsertemp410939,rnk,0.0,0.025) +::STMT +MATRIX:P,parsertemp220844,ZERODIAG,beta +LITERAL_FLOAT:1.0E-12 +/(*(exp(*(parsertemp220844,beta)),ZERODIAG),+(rowSums(*(P,ZERODIAG)),1.0E-12)) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +t(-(0.0,%*%(t(X),y))) +::STMT +MATRIX:g0_1,parsertemp410117 +t(+(g0_1,t(colSums(parsertemp410117)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(length(A),1.0) +::STMT +MATRIX:252_Y,252_X +FLOAT:252_X,252_K,int803 +LITERAL_FLOAT:4.5 +*(+(*(-(int803,252_K),-(252_X,252_X)),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))),/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X)))) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 +round(*(10.0,sqrt(approx_sample_size))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-0.0 +*(^(linear_terms,-0.0),-(Y,linear_terms)) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-2.0 +*(^(linear_terms,-2.0),-(Y,linear_terms)) +::STMT +MATRIX:parsertemp220863,parsertemp220864,H,betamax,Hneg,beta,Hpos +FLOAT:INF,logU +LITERAL_FLOAT:0.0 +*(*(>=(-(H,logU),0.0),!=(+(parsertemp220863,parsertemp220864),INF)),+(beta,+(*(Hpos,betamax),*(Hneg,beta)))) +::STMT +MATRIX:w,ssX_p_CG,X +%*%(t(X),*(w,%*%(X,ssX_p_CG))) +::STMT +FLOAT:j +LITERAL_FLOAT:1.0,3.0 ++(-(3.0,j),1.0) +::STMT +MATRIX:parsertemp400674,W4_rand,parsertemp400677 +LITERAL_FLOAT:0.08720414403938946 +t(%*%(*(0.08720414403938946,W4_rand),t(/(parsertemp400674,parsertemp400677)))) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +*(cast.FLOAT(parsertemp496901),std) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.000100010001 +sqrt(*(cmLabels,1.000100010001)) +::STMT +MATRIX:parsertemp31403,classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +%*%(+(rowSums(classFeatureCounts),*(105.0,1.0)),parsertemp31403) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +*(/(-(x,X),-(X,X)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X))),cast.FLOAT(Y)) +::STMT +MATRIX:lambda,parsertemp170067,parsertemp170065,p_CG,shift_X,w,parsertemp170066,X,parsertemp170060 ++(+(*(lambda,p_CG),*(cast.FLOAT(parsertemp170060),%*%(parsertemp170065,parsertemp170067))),*(cast.FLOAT(shift_X),%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp175076,parsertemp175080,R1 +abs(-(R1,/(exp(parsertemp175076),rowSums(parsertemp175080)))) +::STMT +MATRIX:parsertemp437190,resp,X,weight +LITERAL_FLOAT:2.22E-16 +/(*(/(%*%(parsertemp437190,X),t(weight)),%*%(t(resp),X)),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:upd_W1,X_batch,W1_grad +FLOAT:mu,step +-(*(mu,upd_W1),*(/(step,nrow(X_batch)),W1_grad)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,64.0 +-(n,-(+(i,64.0),1.0)) +::STMT +MATRIX:simplex +/(-(rowSums(simplex),simplex),nrow(simplex)) +::STMT +LITERAL_FLOAT:-0.05,0.05 +INT:parsertemp411077,rnk +rand(parsertemp411077,rnk,-0.05,0.05) +::STMT +MATRIX:oldE +FLOAT:parsertemp32107 +/(sum(oldE),parsertemp32107) +::STMT +FLOAT:norm_Grad_initial +LITERAL_FLOAT:1.0E-8 +*(1.0E-8,norm_Grad_initial) +::STMT +MATRIX:parsertemp414375,parsertemp414377 +FLOAT:int880 +LITERAL_FLOAT:0.0,199.0 +<=(/(-(t(parsertemp414375),*(int880,parsertemp414377)),199.0),0.0) +::STMT +MATRIX:R,parsertemp40216,parsertemp40226 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),+(R,rowSums(parsertemp40216))),eAvg) +::STMT +FLOAT:high,low +LITERAL_FLOAT:2.0 +/(+(low,high),2.0) +::STMT +MATRIX:45_CFreqs +LITERAL_FLOAT:1000.0 +-(1000.0,nrow(45_CFreqs)) +::STMT +LITERAL_FLOAT:0.128920512778062 +0.128920512778062 +::STMT +MATRIX:X +FLOAT:int242,n +-(/(colSums(^(X,int242)),n),*(/(colSums(X),n),/(colSums(X),n))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),50.0)) +::STMT +MATRIX:parsertemp43631,parsertemp43633,w +LITERAL_FLOAT:2.0 +*(+(w,*(2.0,%*%(parsertemp43631,parsertemp43633))),+(w,*(2.0,%*%(parsertemp43631,parsertemp43633)))) +::STMT +MATRIX:Y,Xd,out +FLOAT:dd,step_sz,wd +-(+(wd,*(step_sz,dd)),sum(*(*(out,Y),Xd))) +::STMT +MATRIX:Y_counts,means,parsertemp560512,parsertemp560516 +LITERAL_FLOAT:2.0 +*(Y_counts,-(rowSums(*(means,parsertemp560516)),^(rowSums(parsertemp560512),2.0))) +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605,w +FLOAT:lambda +LITERAL_FLOAT:0.0 +-(*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0)),w) +::STMT +MATRIX:X,y,logisticnew +FLOAT:C,int545 +*(C,%*%(t(X),*(-(logisticnew,int545),y))) +::STMT +MATRIX:rowSums_X_sq +max(sqrt(rowSums_X_sq)) +::STMT +MATRIX:Y,parsertemp171319 +FLOAT:float554 +LITERAL_FLOAT:0.15915494309189535 +*(*(exp(/(parsertemp171319,float554)),0.15915494309189535),rowSums(Y)) +::STMT +MATRIX:mn,mx +-(mx,mn) +::STMT +MATRIX:y_corr,parsertemp171089,parsertemp171084,parsertemp171095 +FLOAT:float558,float534,float130 +LITERAL_FLOAT:0.0,1.0,2.0 +*(+(-(0.0,sqrt(parsertemp171084)),/(+(float558,parsertemp171089),+(float130,parsertemp171095))),-(1.0,*(2.0,>(y_corr,float534)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:2.0 +^(rowSums(Y),2.0) +::STMT +MATRIX:parsertemp500607,w,parsertemp500610,wnew +cast.FLOAT(%*%(t(-(wnew,w)),-(*(parsertemp500607,parsertemp500610),w))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),105.0)) +::STMT +FLOAT:x +LITERAL_FLOAT:1.0,-1.0 ++(1.0,exp(*(x,-1.0))) +::STMT +LITERAL_FLOAT:1.0,2000.0 +/(2000.0,-(2000.0,1.0)) +::STMT +MATRIX:curr_rows_vector +LITERAL_FLOAT:0.0 +sum(>(curr_rows_vector,0.0)) +::STMT +MATRIX:parsertemp31189,parsertemp31187 +LITERAL_FLOAT:3.42951E11,2.0,6999.0 +/(^(/(-(parsertemp31187,parsertemp31189),6999.0),2.0),3.42951E11) +::STMT +MATRIX:R,dssp,dssm +FLOAT:5_n +LITERAL_FLOAT:1.0 +-(/(5_n,-(+(R,dssp),dssm)),1.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),500.0)) +::STMT +MATRIX:parsertemp555744,target +/(sum(rowSums(abs(parsertemp555744))),nrow(target)) +::STMT +MATRIX:parsertemp129125,groupIndex +-(*(groupIndex,max(parsertemp129125)),max(parsertemp129125)) +::STMT +LITERAL_FLOAT:1.0,6.0,2003.0 +*(*(6.0,2003.0),-(2003.0,1.0)) +::STMT +MATRIX:M2,parsertemp553121 +%*%(rowSums(*(M2,M2)),parsertemp553121) +::STMT +MATRIX:s,d +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),d) +::STMT +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(2.0,link_power) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0 +/(0.0,link_power) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +-(*(/(2.0,-(check_max,check_min)),Y),/(+(check_min,check_max),-(check_max,check_min))) +::STMT +LITERAL_FLOAT:1.8 +1.8 +::STMT +FLOAT:parsertemp40813,m2,mu +LITERAL_FLOAT:5.0 ++(mu,*(5.0,sqrt(*(parsertemp40813,m2)))) +::STMT +MATRIX:Y,2212_fp +/(2212_fp,-(nrow(Y),sum(Y))) +::STMT +MATRIX:R,dssp,dsep,dssm,dsem +FLOAT:5_eAvg +/(/(-(+(R,dsep),dsem),-(+(R,dssp),dssm)),5_eAvg) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int452 +LITERAL_FLOAT:2.0,6999.0 +^(/(-(colSums(parsertemp31186),*(int452,parsertemp31188)),6999.0),2.0) +::STMT +MATRIX:scale_X,shift_X,parsertemp274137,parsertemp274138,Grad +LITERAL_FLOAT:2.0 +^(+(%*%(diag(scale_X),%*%(parsertemp274137,parsertemp274138)),%*%(shift_X,Grad)),2.0) +::STMT +MATRIX:csgaps,csmask +>(csgaps,csmask) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +/(exp(linear_terms),+(1.0,exp(linear_terms))) +::STMT +MATRIX:ctab,parsertemp409528 +LITERAL_FLOAT:0.4 +>(/(parsertemp409528,rowSums(ctab)),0.4) +::STMT +MATRIX:y_hat,B,parsertemp503774 +LITERAL_FLOAT:2.0 +sum(^(-(-(B,parsertemp503774),y_hat),2.0)) +::STMT +MATRIX:P12,map +LITERAL_FLOAT:0.0 +!=(%*%(map,P12),0.0) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0,1.0 +*(24.0,-(run_index,1.0)) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int341 +LITERAL_FLOAT:1.0,150.0 +/(/(-(colSums(parsertemp31029),*(int341,parsertemp31031)),-(150.0,1.0)),150.0) +::STMT +MATRIX:B,parsertemp410245,X_t +LITERAL_FLOAT:-1.0,2.0 +/(*(parsertemp410245,-1.0),*(2.0,exp(%*%(X_t,B)))) +::STMT +MATRIX:surv,parsertemp538706 +*(sqrt(parsertemp538706),surv) +::STMT +MATRIX:LHSthreshold +LITERAL_FLOAT:1.0 +sum(>(LHSthreshold,1.0)) +::STMT +MATRIX:parsertemp477718,parsertemp477728,t,parsertemp477715,parsertemp477737,parsertemp477725,X,parsertemp477734 +FLOAT:int376,x +LITERAL_FLOAT:1.0 +*(*(/(-(x,X),-(X,X)),-(1.0,/(parsertemp477715,parsertemp477718))),+(*(-(parsertemp477725,parsertemp477728),-(int376,t)),*(+(parsertemp477734,parsertemp477737),/(parsertemp477715,parsertemp477718)))) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0,7.0 +*(*(-(s,1.0),num_groups),7.0) +::STMT +MATRIX:X +FLOAT:val +!=(X,val) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0,0.001308 +*(sqrt(*(-2.0,parsertemp171083)),0.001308) +::STMT +MATRIX:Y,linear_terms,is_y_0 +FLOAT:int410 +LITERAL_FLOAT:0.0 +/(+(Y,==(Y,0.0)),+(*(linear_terms,-(int410,is_y_0)),==(Y,0.0))) +::STMT +MATRIX:Y_counts,Y +-(Y,%*%(Y_counts,/(colSums(Y),sum(Y_counts)))) +::STMT +MATRIX:linear_terms,Y +FLOAT:int829,link_power,parsertemp286300 +/(*(^(linear_terms,-(parsertemp286300,int829)),-(Y,^(linear_terms,parsertemp286300))),link_power) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int733,int943 +LITERAL_FLOAT:2.0 +sum(^(*(>(out,int733),-(int943,parsertemp2798)),2.0)) +::STMT +MATRIX:R,dssm +FLOAT:2_n,2_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,2_alpha),-(/(2_n,-(R,dssm)),1.0)) +::STMT +MATRIX:parsertemp149248,parsertemp150463,P_1K +*(P_1K,%*%(rowSums(*(P_1K,parsertemp149248)),parsertemp150463)) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(-(g,1.0),2.0),2.0) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 ++(%*%(t(V),%*%(V,p)),*(1.0E-8,p)) +::STMT +MATRIX:posSamples +LITERAL_FLOAT:2.0 +colSums(^(posSamples,2.0)) +::STMT +MATRIX:parsertemp175066,scores,parsertemp175069,unnorm_probs,dprobs +*(/(exp(-(scores,parsertemp175066)),rowSums(exp(scores))),rowSums(*(dprobs,/(unnorm_probs,parsertemp175069)))) +::STMT +MATRIX:F +%*%(rowSums(F),colSums(F)) +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.0,1000.0 +*(42_m2X,/(1000.0,-(1000.0,1.0))) +::STMT +MATRIX:252_Y +FLOAT:252_X,float125,float67 +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(float67,252_X)),cast.FLOAT(252_Y)),*(/(-(float125,252_X),-(252_X,252_X)),cast.FLOAT(252_Y))) +::STMT +MATRIX:CVars,CFreqs +FLOAT:float426,int601,int956,parsertemp31330,int591 +LITERAL_FLOAT:1.0,10000.0 +/(sum(*(-(CFreqs,int601),CVars)),*(-(10000.0,1.0),/(*(parsertemp31330,int956),-(int591,float426)))) +::STMT +MATRIX:parsertemp171315,t_gp,parsertemp171320,parsertemp171307,parsertemp171316 +FLOAT:float678,float19 +LITERAL_FLOAT:2.0,0.25 +*(*(0.25,*(/(float678,parsertemp171307),+(float19,parsertemp171315))),-(2.0,*(exp(parsertemp171320),*(t_gp,parsertemp171316)))) +::STMT +MATRIX:X,parsertemp115855 +FLOAT:int61,n +LITERAL_FLOAT:2.0 +-(t(colSums(^(X,int61))),*(nrow(X),^(/(parsertemp115855,n),2.0))) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,maskd1,W2 +FLOAT:p,int969 +*(/(maskd1,p),%*%(*(>(out2,int969),%*%(184_dscores,parsertemp146942)),t(W2))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +exp(*(linear_terms,-(1.0,var_power))) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(-(g,1.0),2.0),1.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int593,int652,int575,int849 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int652),/(negSampleVariances,int849)),2.0),+(/(^(posSampleVariances,int593),7.996E9),/(^(negSampleVariances,int575),3.37275E9))) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int577,float40 +LITERAL_FLOAT:1.0,1.5 +min(^(/(*(parsertemp410245,int577),*(float40,parsertemp410248)),/(1.0,1.5))) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),min(round(parsertemp2832))) +::STMT +MATRIX:C,Xm,parsertemp265702 +sum(%*%(%*%(%*%(Xm,parsertemp265702),t(C)),t(Xm))) +::STMT +MATRIX:parsertemp386437,neighbors +FLOAT:eps +LITERAL_FLOAT:0.0 +*(<=(-(neighbors,diag(parsertemp386437)),eps),<(0.0,-(neighbors,diag(parsertemp386437)))) +::STMT +MATRIX:neighbors +FLOAT:eps,int625 +LITERAL_FLOAT:1.0 ++(rowSums(*(<=(neighbors,eps),<(int625,neighbors))),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,1.0 +-(1.0,exp(-(0.0,exp(finite_linear_terms)))) +::STMT +MATRIX:W +FLOAT:parsertemp65,parsertemp66 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),^(sqrt(/(parsertemp65,parsertemp66)),3.0)) +::STMT +MATRIX:parsertemp42200,R +FLOAT:int779,meanX +LITERAL_FLOAT:1.0,2.0 +-(+(-(parsertemp42200,/(R,int779)),/(1.0,2.0)),meanX) +::STMT +MATRIX:V,parsertemp10742,H,parsertemp10738 +FLOAT:Eps +t(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps)))) +::STMT +MATRIX:r_LS,parsertemp285848 +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(%*%(t(r_LS),t(parsertemp285848)))) +::STMT +MATRIX:X,parsertemp115854 +LITERAL_FLOAT:2.0 +*($1:nrow(X),^(/(t(parsertemp115854),$1),2.0)) +::STMT +MATRIX:W,X,parsertemp411199,parsertemp411201 +LITERAL_FLOAT:1.0E-8 +/(X,+(%*%(W,/(parsertemp411199,parsertemp411201)),1.0E-8)) +::STMT +FLOAT:parsertemp42302,parsertemp42306 +LITERAL_FLOAT:1.000100010001 +*(sqrt(*(parsertemp42302,1.000100010001)),sqrt(*(parsertemp42306,1.000100010001))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285794,parsertemp285796 +LITERAL_FLOAT:-1.0 +/(+(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285794,parsertemp285796))),cast.FLOAT(%*%(t(p_CG),p_CG))) +::STMT +MATRIX:cdf_min_distances,threshold_matrix +LITERAL_FLOAT:1.0 ++(t(colSums(<(cdf_min_distances,threshold_matrix))),1.0) +::STMT +FLOAT:dimensions +LITERAL_FLOAT:1.0,2.0 ++(^(2.0,dimensions),1.0) +::STMT +FLOAT:m2Y,sigmaX,W,parsertemp26583 +*(sigmaX,sqrt(*(m2Y,/(W,parsertemp26583)))) +::STMT +MATRIX:CVars,CFreqs +FLOAT:int381 +LITERAL_FLOAT:10000.0 +/(sum(*(-(CFreqs,int381),CVars)),-(10000.0,nrow(CFreqs))) +::STMT +MATRIX:R,dssp,dssm +FLOAT:5_n +/(5_n,-(+(R,dssp),dssm)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 ++(rowSums(classFeatureCounts),*(50.0,1.0)) +::STMT +MATRIX:parsertemp410978,W,H,parsertemp410980 +FLOAT:eps ++(%*%(W,/(*(H,parsertemp410978),t(parsertemp410980))),eps) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),-(2.0,var_power)) +::STMT +MATRIX:LT,Y,parsertemp149320,parsertemp150469 +*(Y,-(LT,%*%(parsertemp149320,parsertemp150469))) +::STMT +LITERAL_FLOAT:0.6 +0.6 +::STMT +MATRIX:C,Xm,parsertemp265701 +t(%*%(Xm,%*%(C,parsertemp265701))) +::STMT +MATRIX:parsertemp42190,X +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42190,/(X,2.0)),/(1.0,2.0)) +::STMT +LITERAL_FLOAT:-1.0E30 +-1.0E30 +::STMT +LITERAL_FLOAT:1.0E30 +1.0E30 +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(^(mu,2.0),^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.85 +0.85 +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,X),-(X,X)),Y) +::STMT +LITERAL_FLOAT:0.3 +0.3 +::STMT +MATRIX:p,V +%*%(V,p) +::STMT +MATRIX:dY,g +FLOAT:lr,momentum +LITERAL_FLOAT:2.0 +sum(^(-(*(momentum,dY),*(lr,g)),2.0)) +::STMT +MATRIX:ncCnts,maxsc +FLOAT:parsertemp31781 +LITERAL_FLOAT:0.0 +|(>(ncCnts,0.0),>(maxsc,parsertemp31781)) +::STMT +MATRIX:current_node +FLOAT:cur_node_index +LITERAL_FLOAT:1.0 ++(+(cur_node_index,cast.FLOAT(current_node)),1.0) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:45.0 ++(45.0,nrow(_sbcvar1708)) +::STMT +LITERAL_FLOAT:0.08146881698903526 +0.08146881698903526 +::STMT +MATRIX:cumLeftHist,parsertemp132495,parsertemp132506,leftHist,outBucket +LITERAL_FLOAT:1.0 ++(+(%*%(==(outBucket,parsertemp132495),-(cumLeftHist,leftHist)),parsertemp132506),1.0) +::STMT +LITERAL_FLOAT:0.30000000000000004 +0.30000000000000004 +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:C,Xm,tmp,parsertemp265701 +/(%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))),sum(tmp)) +::STMT +MATRIX:logistic,X,y +FLOAT:int215 +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(-(logistic,int215),y))) +::STMT +MATRIX:q_LS,p_LS,parsertemp170551,X +FLOAT:norm_r2_LS,lambda_LS +*(/(norm_r2_LS,sum(*(p_LS,q_LS))),+(%*%(%*%(parsertemp170551,X),p_LS),*(lambda_LS,p_LS))) +::STMT +MATRIX:shift_X,parsertemp116007 +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(+(parsertemp116007,shift_X),2.0)),9.999999999999998E-15) +::STMT +MATRIX:parsertemp10744,W,H +FLOAT:Eps ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),Eps) +::STMT +MATRIX:parsertemp170277 +LITERAL_FLOAT:3.141592653589793,0.5 ++(0.5,/(parsertemp170277,3.141592653589793)) +::STMT +MATRIX:ts +FLOAT:q ++(-(q,%*%(ts,ts)),%*%(ts,ts)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0,1.0E7 +*(==(+(1.0E7,exp(linear_terms)),1.0E7),-(1.0,/(exp(linear_terms),2.0))) +::STMT +MATRIX:dY,W,Y,sumW +FLOAT:lr,momentum +-(*(momentum,dY),*(lr,-(*(Y,sumW),%*%(W,Y)))) +::STMT +MATRIX:m_err +sum(colSums(m_err)) +::STMT +MATRIX:parsertemp409058,parsertemp409054,ctab +FLOAT:threshold +*(parsertemp409058,>(/(parsertemp409054,rowSums(ctab)),threshold)) +::STMT +MATRIX:means,parsertemp560515 +LITERAL_FLOAT:2.0 +rowSums(*(means,^(parsertemp560515,2.0))) +::STMT +MATRIX:P,minD,D +t(colSums(/(<=(D,minD),rowSums(P)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(cast.FLOAT(-(x,X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:tpr,fpr +*(-(fpr,fpr),+(tpr,tpr)) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 +*(Y_prob,-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp222327,is_row_in_samples +LITERAL_FLOAT:2001.0 +-(2001.0,*(is_row_in_samples,parsertemp222327)) +::STMT +MATRIX:surv +LITERAL_FLOAT:1.0 +*(surv,sqrt(-(1.0,surv))) +::STMT +MATRIX:ssX_p,scale_X,X +*(scale_X,%*%(t(X),%*%(X,ssX_p))) +::STMT +FLOAT:b,int247 +LITERAL_FLOAT:2.0 +sqrt(-(^(b,2.0),int247)) +::STMT +MATRIX:tab,catTotal +LITERAL_FLOAT:-1.0 +*(/(tab,catTotal),-1.0) +::STMT +MATRIX:col_nonzeros,parsertemp382954,parsertemp382951,row_nonzeros +LITERAL_FLOAT:5.0E-7 +*(5.0E-7,+(sum(*(parsertemp382951,row_nonzeros)),sum(*(parsertemp382954,col_nonzeros)))) +::STMT +FLOAT:int484,217_a22,parsertemp22450,parsertemp22451 +LITERAL_FLOAT:2.0 +*(2.0,sqrt(+(+(parsertemp22450,parsertemp22451),/(int484,217_a22)))) +::STMT +LITERAL_FLOAT:44.721359549995796 +44.721359549995796 +::STMT +MATRIX:X +FLOAT:int557,int228 +LITERAL_FLOAT:1.0 +/(-(exp(*(int557,X)),1.0),+(exp(*(int228,X)),1.0)) +::STMT +MATRIX:parsertemp31023,parsertemp31025 +FLOAT:int211,int718 +LITERAL_FLOAT:1.0,2.0,100.0 +/(^(/(-(parsertemp31023,parsertemp31025),-(int211,int718)),2.0),*(^(100.0,2.0),-(100.0,1.0))) +::STMT +MATRIX:B,X,y +LITERAL_FLOAT:2.0 +^(-(y,%*%(X,B)),2.0) +::STMT +MATRIX:prec_chol,mu +FLOAT:int510,int468 +t(rowSums(*(^(mu,int468),^(prec_chol,int510)))) +::STMT +LITERAL_FLOAT:1.0,2001.0 +-(2001.0,1.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0 +/(y_corr,-(1.0,y_corr)) +::STMT +MATRIX:ts +LITERAL_FLOAT:1.0,2.0,4.0 +-(+(-(length(ts),4.0),1.0),2.0) +::STMT +LITERAL_FLOAT:0.086386842558136 +0.086386842558136 +::STMT +MATRIX:P,scale_lambda,X,Y,parsertemp150455 +LITERAL_FLOAT:0.0,1.0E-5 ++(%*%(t(X),-(P,Y)),*(*(%*%(scale_lambda,parsertemp150455),1.0E-5),0.0)) +::STMT +MATRIX:parsertemp555613,parsertemp555615 +%*%(t(sqrt(parsertemp555613)),sqrt(parsertemp555615)) +::STMT +MATRIX:X,Y +/(abs(-(X,Y)),abs(X)) +::STMT +MATRIX:W1_rand,X,parsertemp393476,parsertemp393466 +FLOAT:float616 +LITERAL_FLOAT:0.07261134713572442 +%*%(*(0.07261134713572442,W1_rand),t(/(-(X,parsertemp393466),+(parsertemp393476,float616)))) +::STMT +MATRIX:colSD +LITERAL_FLOAT:3.0 +*(3.0,colSD) +::STMT +MATRIX:_funvar402 +LITERAL_FLOAT:1.0E-16 ++(_funvar402,1.0E-16) +::STMT +MATRIX:var_tot_Y +cast.FLOAT(sqrt(var_tot_Y)) +::STMT +MATRIX:select,d_r_rev,X_rev_agg +*(%*%(select,X_rev_agg),d_r_rev) +::STMT +FLOAT:n_features +LITERAL_FLOAT:1.0 +*(n_features,+(n_features,1.0)) +::STMT +MATRIX:r +LITERAL_FLOAT:9.999999999999998E-15 +*(cast.FLOAT(%*%(t(r),r)),9.999999999999998E-15) +::STMT +LITERAL_FLOAT:3.0,2001.0 +-(2001.0,3.0) +::STMT +MATRIX:X,Y,K +LITERAL_FLOAT:-1.0 ++(*(*(K,-1.0),-(X,X)),-(Y,Y)) +::STMT +MATRIX:sample_maps,X +LITERAL_FLOAT:2.0 +^(%*%(sample_maps,X),2.0) +::STMT +MATRIX:Y_prob,Y,parsertemp171380 +FLOAT:int58 +LITERAL_FLOAT:3.141592653589793,2.0 +*(*(*(rowSums(Y),Y_prob),Y_prob),^(*(+(int58,parsertemp171380),3.141592653589793),2.0)) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.5 +-(y_corr,0.5) +::STMT +LITERAL_FLOAT:0.15000000000000002 +0.15000000000000002 +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:10000.0 +*(parsertemp31330,10000.0) +::STMT +MATRIX:A +/(*(cast.FLOAT(A),cast.FLOAT(A)),*(cast.FLOAT(A),cast.FLOAT(A))) +::STMT +MATRIX:W +FLOAT:int461,parsertemp65,parsertemp66,int339,wt +LITERAL_FLOAT:3.0,4.0 +*(*(*(-(wt,int339),-(wt,int461)),-(sum(W),3.0)),^(sqrt(/(parsertemp65,parsertemp66)),4.0)) +::STMT +FLOAT:int495,x +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(x,int495)))) +::STMT +MATRIX:distances,ksmall,parsertemp557211 +LITERAL_FLOAT:0.0 +*(<=(distances,ksmall),==(diag(parsertemp557211),0.0)) +::STMT +MATRIX:parsertemp410979,W,X,parsertemp410981 +FLOAT:eps +/(X,+(%*%(W,/(parsertemp410979,parsertemp410981)),eps)) +::STMT +MATRIX:Xtest_dists +FLOAT:eps +LITERAL_FLOAT:0.0 +rowSums(*(<=(Xtest_dists,eps),<(0.0,Xtest_dists))) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0) +::STMT +MATRIX:resp,mean,X,weight +FLOAT:int164 +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),^(X,int164)),t(weight)),*(2.0,^(mean,2.0))) +::STMT +MATRIX:CFreqs +LITERAL_FLOAT:1.0 +-(CFreqs,1.0) +::STMT +MATRIX:parsertemp220853,parsertemp220854 +LITERAL_FLOAT:0.0,2.0,3.4011973816621555 +*(2.0,>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0)) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +/(sum(WM),-(sum(WM),1.0)) +::STMT +LITERAL_FLOAT:2.0,2001.0 +-(2001.0,2.0) +::STMT +MATRIX:X +FLOAT:index,int193,parsertemp129094 +LITERAL_FLOAT:2.0 ++(+(*(index,-(parsertemp129094,int193)),2.0),-(ncol(X),2.0)) +::STMT +MATRIX:t_gp,parsertemp560881,parsertemp560864,parsertemp560863,parsertemp560877 +FLOAT:int773,float843,int853 +LITERAL_FLOAT:1.0 +-(+(1.0,-(*(int773,parsertemp560863),1.0)),*(*(*(t_gp,parsertemp560877),-(parsertemp560864,int853)),exp(/(parsertemp560881,float843)))) +::STMT +FLOAT:parsertemp410218,parsertemp410219 +LITERAL_FLOAT:-1.0,50.0 +exp(/(*(-(parsertemp410218,parsertemp410219),-1.0),50.0)) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +round(*(10000.0,rho)) +::STMT +FLOAT:eta,s +^(eta,s) +::STMT +MATRIX:ss_res_Y,var_tot_Y +FLOAT:df_ss_res_Y +LITERAL_FLOAT:1.0 +-(1.0,/(/(ss_res_Y,df_ss_res_Y),var_tot_Y)) +::STMT +MATRIX:tmp,parsertemp260786,parsertemp260787,parsertemp260785 +cast.FLOAT(%*%(t(-(parsertemp260787,tmp)),-(%*%(parsertemp260785,parsertemp260786),tmp))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0E-4 ++(1.0E-4,abs(t(A))) +::STMT +MATRIX:is_LT_infinite,Y_prob,Y,parsertemp171294,parsertemp171292,flip_pos,parsertemp171290 +FLOAT:float465 +*(*(Y,%*%(+(parsertemp171294,is_LT_infinite),flip_pos)),+(*(/(Y_prob,parsertemp171290),-(float465,parsertemp171292)),is_LT_infinite)) +::STMT +MATRIX:P,Y,dP +sum(&(>(P,dP),Y)) +::STMT +FLOAT:a,b,c,int863 +LITERAL_FLOAT:2.0 +sqrt(-(^(b,2.0),*(*(int863,a),c))) +::STMT +MATRIX:y_corr +FLOAT:link_power,int319 +LITERAL_FLOAT:0.0 +-(^(+(y_corr,==(y_corr,int319)),link_power),==(y_corr,0.0)) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2003.0 +*(*(-(2003.0,2.0),+(2003.0,1.0)),+(2003.0,3.0)) +::STMT +MATRIX:g0_1,parsertemp410117 +LITERAL_FLOAT:2.0 +^(+(g0_1,t(colSums(parsertemp410117))),2.0) +::STMT +MATRIX:P,Y,dP +&(<=(P,dP),!(Y)) +::STMT +MATRIX:parsertemp274141,shift_X,Grad +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp274141,Grad),%*%(shift_X,Grad)),2.0)) +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:0.0 +*(!=(X,0.0),-(%*%(U,t(V)),X)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +/(-(col,min_val),bin_width) +::STMT +MATRIX:parsertemp260759,parsertemp260756,Xd +FLOAT:dd,parsertemp260753,wd +/(*(-(+(wd,parsertemp260753),sum(parsertemp260756)),-(+(wd,parsertemp260753),sum(parsertemp260756))),+(dd,sum(*(parsertemp260759,Xd)))) +::STMT +MATRIX:parsertemp254737 +FLOAT:2124_sq_root_d,parsertemp254772,parsertemp254751,float69 ++(float69,*(parsertemp254772,/(-(parsertemp254751,2124_sq_root_d),sum(parsertemp254737)))) +::STMT +MATRIX:X,Centering,ScaleFactor +FLOAT:N +LITERAL_FLOAT:1.0 +/(%*%(t(/(X,ScaleFactor)),/(-(X,Centering),ScaleFactor)),-(N,1.0)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-2.0,1.0 ++(-2.0,/(1.0,link_power)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int449,m +abs(rand(m,int449,0.0,1.0)) +::STMT +FLOAT:parsertemp22454,parsertemp22485 +LITERAL_FLOAT:2.0 +exp(+(parsertemp22485,*(2.0,sqrt(parsertemp22454)))) +::STMT +MATRIX:_sbcvar1007 +FLOAT:number_nans +/(number_nans,nrow(_sbcvar1007)) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +/(sum(*(-(r,parsertemp44050),-(r,parsertemp44050))),norm_r2) +::STMT +MATRIX:xs +LITERAL_FLOAT:1000.0,4.5 +-(1000.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp397720,W1_rand,parsertemp397730,X +FLOAT:float798 +LITERAL_FLOAT:0.086386842558136 +%*%(*(0.086386842558136,W1_rand),t(/(-(X,parsertemp397720),+(parsertemp397730,float798)))) +::STMT +MATRIX:I +*(nrow(I),ncol(I)) +::STMT +MATRIX:linear_terms +FLOAT:link_power,parsertemp171228 +LITERAL_FLOAT:2.0 +/(^(linear_terms,-(/(parsertemp171228,link_power),2.0)),^(link_power,2.0)) +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar98 +LITERAL_FLOAT:-1.0 +sum(*(+(%*%(_sbcvar95,_sbcvar96),-1.0),%*%(_sbcvar95,_sbcvar98))) +::STMT +MATRIX:parsertemp170136 +FLOAT:278_sq_root_d,parsertemp170150,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(-(parsertemp170150,278_sq_root_d),sum(parsertemp170136))),pq_CG) +::STMT +MATRIX:V,W,parsertemp10741,H +LITERAL_FLOAT:1.0E-8 +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),1.0E-8))) +::STMT +FLOAT:252_Y,252_X,252_K,float711,float512,parsertemp32930,int666,parsertemp32915,float790 +LITERAL_FLOAT:1.0 +*(*(/(-(float512,252_X),-(252_X,252_X)),-(1.0,/(float790,252_X))),+(*(-(252_K,252_Y),-(int666,parsertemp32915)),*(+(parsertemp32930,252_Y),/(float711,252_X)))) +::STMT +FLOAT:int684,191_t,191_lr,int4,191_beta1,parsertemp146979 +LITERAL_FLOAT:1.0 +/(*(191_lr,sqrt(-(int684,parsertemp146979))),-(1.0,^(191_beta1,+(191_t,int4)))) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +/(round(*(10000.0,rho)),10000.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0,2.0 +/(*(^(finite_linear_terms,2.0),-1.0),2.0) +::STMT +MATRIX:ssX_newbeta +LITERAL_FLOAT:0.0 +INT:int142,int272 ++(ssX_newbeta,cast.FLOAT(rand(int142,int272,0.0,0.0))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0,799.0 +/(^(diag(S),2.0),799.0) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171306 +FLOAT:float653 +LITERAL_FLOAT:1.0,0.25,0.254829592 +*(0.25,*(/(1.0,+(float653,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +FLOAT:num_hidden1,m +sqrt(+(m,num_hidden1)) +::STMT +MATRIX:parsertemp410988,parsertemp410979,parsertemp410990,parsertemp410981 +FLOAT:parsertemp410999 +-(sum(%*%(/(parsertemp410988,parsertemp410990),/(parsertemp410979,parsertemp410981))),parsertemp410999) +::STMT +MATRIX:d,parsertemp410054 +FLOAT:r2 +/(r2,sum(*(d,t(parsertemp410054)))) +::STMT +MATRIX:E,parsertemp22269 +FLOAT:int373,q +LITERAL_FLOAT:10000.0 +sqrt(/(sum(/(parsertemp22269,E)),*(10000.0,-(q,int373)))) +::STMT +FLOAT:beg +LITERAL_FLOAT:1.0,512.0 +-(+(beg,512.0),1.0) +::STMT +MATRIX:parsertemp220863,parsertemp220864,H,betamax,Hneg,beta,Hpos +FLOAT:float727 +LITERAL_FLOAT:0.0,1.0E20 +*(*(>=(-(H,float727),0.0),!=(+(parsertemp220863,parsertemp220864),1.0E20)),+(beta,+(*(Hpos,betamax),*(Hneg,beta)))) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5,0.001 +*(0.001,+(*(0.5,cast.FLOAT(out)),*(1.0,cast.FLOAT(w)))) +::STMT +MATRIX:F +-(F,/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:5.0 +/(5.0,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:Ileft,Iright,ig +FLOAT:min_leaf +*(&(>=(rowSums(Ileft),min_leaf),>=(rowSums(Iright),min_leaf)),ig) +::STMT +FLOAT:c +LITERAL_FLOAT:-1.0,2.0 +*(*(2.0,c),-1.0) +::STMT +MATRIX:maxscub +FLOAT:parsertemp31797 +LITERAL_FLOAT:-Infinity +|(>=(maxscub,parsertemp31797),==(maxscub,-Infinity)) +::STMT +MATRIX:vars +FLOAT:dispersion +*(dispersion,colSums(vars)) +::STMT +MATRIX:parsertemp410245,parsertemp410247 +LITERAL_FLOAT:-1.0,1.0,2.0,1.5 +^(/(*(parsertemp410245,-1.0),*(2.0,exp(parsertemp410247))),/(1.0,1.5)) +::STMT +FLOAT:e,mu +LITERAL_FLOAT:0.999,4.0 +/(-(0.999,mu),-(4.0,e)) +::STMT +LITERAL_FLOAT:105.0,1.0 +*(105.0,1.0) +::STMT +LITERAL_FLOAT:1.0,10000.0 +-(10000.0,1.0) +::STMT +MATRIX:parsertemp2781,Xd,parsertemp2785 +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(*(parsertemp2781,Xd))),+(dd,sum(*(parsertemp2785,Xd)))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.2656844656620286 +*(0.2656844656620286,W2_rand) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +^(linear_terms,/(2.0,link_power)) +::STMT +MATRIX:252_X,252_K +LITERAL_FLOAT:0.0 +*(-(0.0,cast.FLOAT(252_K)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:ytest,yhat +FLOAT:parsertemp115806,n +LITERAL_FLOAT:2.0 +-(sum(^(-(ytest,yhat),2.0)),*(nrow(ytest),^(/(parsertemp115806,n),2.0))) +::STMT +MATRIX:parsertemp31265,WM,CMeans +LITERAL_FLOAT:2.0 +^(-(CMeans,/(sum(parsertemp31265),sum(WM))),2.0) +::STMT +FLOAT:log_ten,float83,parsertemp169813 +LITERAL_FLOAT:4.0 +*(log_ten,-(4.0,round(-(parsertemp169813,float83)))) +::STMT +MATRIX:X +FLOAT:i ++(i,ncol(X)) +::STMT +MATRIX:parsertemp410978,H,parsertemp410980 +t(rowSums(/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:0.0 ++(nrow(residual_matrix),0.0) +::STMT +MATRIX:X_plane,parsertemp11251 +FLOAT:int665 +LITERAL_FLOAT:0.0 +rowSums(*(>(X_plane,0.0),t(^(int665,parsertemp11251)))) +::STMT +MATRIX:parsertemp178161,M +colSums(exp(-(M,parsertemp178161))) +::STMT +MATRIX:W +LITERAL_FLOAT:2.0 +-(sum(round(W)),2.0) +::STMT +MATRIX:r,d,Hd +FLOAT:r2,c +LITERAL_FLOAT:0.0 ++(-(0.0,+(r,*(c,Hd))),*(/(cast.FLOAT(r),r2),d)) +::STMT +LITERAL_FLOAT:2.0,0.5,-0.5 +INT:int121,int493 +^(rand(int121,int493,-0.5,0.5),2.0) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0 +-(sum(round(W)),3.0) +::STMT +MATRIX:trees_M_offset +LITERAL_FLOAT:1.0 +-(cast.FLOAT(trees_M_offset),1.0) +::STMT +MATRIX:dataFrame,constraintsFrame +*(nrow(dataFrame),nrow(constraintsFrame)) +::STMT +MATRIX:S,parsertemp382904,V,W,row_nonzeros +FLOAT:reg ++(%*%(*(W,%*%(S,parsertemp382904)),V),*(*(reg,S),row_nonzeros)) +::STMT +MATRIX:oldX +LITERAL_FLOAT:1.0 ++(nrow(oldX),1.0) +::STMT +MATRIX:parsertemp10964,C +LITERAL_FLOAT:100.0 +/(sum(==(parsertemp10964,C)),100.0) +::STMT +MATRIX:obj,gs,parsertemp44066 +FLOAT:float664,int191,parsertemp44077,int394 +LITERAL_FLOAT:-0.5 +/(-(cast.FLOAT(obj),+(*(float664,parsertemp44077),*(int191,int394))),*(-0.5,-(cast.FLOAT(gs),cast.FLOAT(parsertemp44066)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0 +-(sum(round(W)),1.0) +::STMT +MATRIX:_sbcvar179,_sbcvar182,237_CFreqs +FLOAT:int842 +LITERAL_FLOAT:10000.0 +/(sum(*(+(237_CFreqs,int842),%*%(_sbcvar179,_sbcvar182))),-(10000.0,nrow(_sbcvar179))) +::STMT +MATRIX:p,z +FLOAT:pp,parsertemp169870,pz +LITERAL_FLOAT:-1.0 ++(*(sum(*(p,z)),-1.0),sqrt(-(*(pz,pz),*(pp,parsertemp169870)))) +::STMT +MATRIX:parsertemp31782,err,parsertemp31769,parsertemp31768,cCnts,parsertemp31780 +FLOAT:minSup,int606 +-(sum(&(>=(cCnts,minSup),>(err,int606))),sum(&(&(parsertemp31768,parsertemp31769),|(parsertemp31780,parsertemp31782)))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:0.0 +-(0.0,%*%(t(V),y)) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +sum(==(-(predicted_Y,Y),0.0)) +::STMT +FLOAT:n_stratum_cols,n_group_cols +LITERAL_FLOAT:2.0 ++(+(2.0,n_group_cols),n_stratum_cols) +::STMT +FLOAT:sigma,alpha +LITERAL_FLOAT:0.5 +*(*(0.5,sigma),alpha) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:8.674675786448736 +/(8.674675786448736,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:G,authorities +max(%*%(t(G),%*%(G,authorities))) +::STMT +MATRIX:indexWithInGroups,selectedMatrix +rowSums(*(indexWithInGroups,selectedMatrix)) +::STMT +MATRIX:in_m_neighbor_value +FLOAT:in_i_k_min +LITERAL_FLOAT:1.0 ++(-(ncol(in_m_neighbor_value),in_i_k_min),1.0) +::STMT +MATRIX:parsertemp386440,parsertemp386441 +FLOAT:minPts +LITERAL_FLOAT:1.0 +>=(+(rowSums(*(parsertemp386440,parsertemp386441)),1.0),minPts) +::STMT +MATRIX:solution,X +*(-(X,solution),-(X,solution)) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:2.0 +<(rowSums(|(<(length,qLow),>(length,qUp))),2.0) +::STMT +MATRIX:C,parsertemp11014 +LITERAL_FLOAT:1000.0 +/(sum(==(parsertemp11014,C)),1000.0) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),max(round(parsertemp2832))) +::STMT +MATRIX:parsertemp410081,d_r_rev,parsertemp410090 +FLOAT:o +LITERAL_FLOAT:-1.0 +-(+(*(cast.FLOAT(parsertemp410081),-1.0),sum(*(d_r_rev,parsertemp410090))),o) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,CMeans +FLOAT:my +LITERAL_FLOAT:2.0 +sum(*(%*%(present_domain_vals_mat,CFreqs1),^(-(CMeans,my),2.0))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0 +*(^(linear_terms,/(-1.0,link_power)),-1.0) +::STMT +MATRIX:parsertemp437190,X,weight +LITERAL_FLOAT:2.0 +*(2.0,^(/(%*%(parsertemp437190,X),t(weight)),2.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,linear_terms),2.0) +::STMT +FLOAT:int252,int543 +INT:int84,int477 +diag(rand(int84,int477,int252,int543)) +::STMT +MATRIX:A,B,C,X +%*%(<=(%*%(X,A),B),C) +::STMT +MATRIX:r,d,parsertemp43999 +cast.FLOAT(/(sum(*(r,r)),%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:2814_K,2814_X +LITERAL_FLOAT:0.0 +*(cast.FLOAT(-(0.0,2814_K)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X))) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,7000.0 +-(colSums(^(posSamples,2.0)),*(7000.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:mu +FLOAT:q +LITERAL_FLOAT:4.0 +-(q,*(4.0,*(cast.FLOAT(mu),cast.FLOAT(mu)))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-1.0 +*(^(linear_terms,-1.0),-(Y,linear_terms)) +::STMT +MATRIX:U,X,parsertemp382850 +LITERAL_FLOAT:0.0 +%*%(t(U),*(!=(X,0.0),-(%*%(U,parsertemp382850),X))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int379,parsertemp12177 +rand(parsertemp12177,int379,0.0,1.0) +::STMT +MATRIX:parsertemp553122,missing +t(%*%(rowSums(*(missing,missing)),parsertemp553122)) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171318,parsertemp171306 +FLOAT:int866,float62 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(exp(/(*(parsertemp171318,int866),2.0)),*(/(1.0,+(float62,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +MATRIX:grad +FLOAT:psi +*(psi,sqrt(sum(*(grad,grad)))) +::STMT +MATRIX:dX,v,X +FLOAT:lr,mu ++(X,-(*(mu,v),*(lr,dX))) +::STMT +MATRIX:R,parsertemp40219 +FLOAT:numRows,level +/(numRows,-(R,rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:d_r,parsertemp409781 +cast.FLOAT(%*%(t(rev(d_r)),parsertemp409781)) +::STMT +MATRIX:287_x,287_y +LITERAL_FLOAT:2.0 +/(+(cast.FLOAT(287_x),cast.FLOAT(287_y)),2.0) +::STMT +MATRIX:aggr_best_index_vector +LITERAL_FLOAT:0.0,1.0 ++(sum(==(aggr_best_index_vector,0.0)),1.0) +::STMT +MATRIX:id +FLOAT:parsertemp22683 +cast.FLOAT(diag(diag(==(id,parsertemp22683)))) +::STMT +MATRIX:w,X,y +*(-(%*%(X,w),y),-(%*%(X,w),y)) +::STMT +LITERAL_FLOAT:2.0 +INT:int554,int716 +rand(int716,int554,2.0,2.0) +::STMT +LITERAL_FLOAT:0.0 +INT:int87,int416 +rand(int87,int416,0.0,0.0) +::STMT +FLOAT:window_size,parsertemp180776,n +LITERAL_FLOAT:1.0 +-(+(-(n,window_size),1.0),+(parsertemp180776,1.0)) +::STMT +MATRIX:outSize +LITERAL_FLOAT:0.0 +cast.FLOAT(>(outSize,0.0)) +::STMT +MATRIX:P,I,X2 +LITERAL_FLOAT:0.0 +==(*(t(%*%(X2,P)),I),0.0) +::STMT +MATRIX:P,I +LITERAL_FLOAT:0.0 +==(%*%(P,I),0.0) +::STMT +MATRIX:R,dssp +FLOAT:4_n +/(4_n,+(R,dssp)) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +>(X,4.0) +::STMT +MATRIX:parsertemp146930,184_unnorm_probs,184_probs,parsertemp146928,183_dpred,184_scores +FLOAT:int466,parsertemp146927 +-(*(*(*(parsertemp146927,parsertemp146928),/(int466,parsertemp146930)),/(exp(184_scores),rowSums(184_unnorm_probs))),*(/(exp(184_scores),rowSums(184_unnorm_probs)),rowSums(*(183_dpred,184_probs)))) +::STMT +MATRIX:Y +cast.MATRIX(max(Y)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,48.0 ++(*(48.0,-(run_index,1.0)),1.0) +::STMT +LITERAL_FLOAT:1.0E-7 +INT:int267,m +rand(m,int267,1.0E-7,1.0E-7) +::STMT +MATRIX:F +LITERAL_FLOAT:2.0 +/(t(colSums(F)),2.0) +::STMT +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +^(/(sum_y_test,n),2.0) +::STMT +MATRIX:x,y +LITERAL_FLOAT:2.0 +/(+(x,y),2.0) +::STMT +MATRIX:gXY +FLOAT:lambda,parsertemp171602,beta +LITERAL_FLOAT:2.0 +sum(^(+(*(parsertemp171602,gXY),*(lambda,beta)),2.0)) +::STMT +MATRIX:X_plane +LITERAL_FLOAT:0.0 +>(X_plane,0.0) +::STMT +MATRIX:cumLens +FLOAT:i +LITERAL_FLOAT:1.0 +/(-(i,1.0),cast.FLOAT(cumLens)) +::STMT +MATRIX:err,cCnts +FLOAT:minSup +LITERAL_FLOAT:0.0 +|(<(cCnts,minSup),==(err,0.0)) +::STMT +MATRIX:parsertemp220845,ZERODIAG +LITERAL_FLOAT:1.0E-12 ++(rowSums(*(exp(parsertemp220845),ZERODIAG)),1.0E-12) +::STMT +MATRIX:parsertemp11509 +LITERAL_FLOAT:2.0 +*(2.0,parsertemp11509) +::STMT +MATRIX:intercept +LITERAL_FLOAT:0.0 +*(0.0,intercept) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0 +*(-2.0,parsertemp171083) +::STMT +MATRIX:shift_X +FLOAT:lambda,p_CG,parsertemp170060,temp_CG +*(+(+(*(lambda,p_CG),*(parsertemp170060,temp_CG)),*(cast.FLOAT(shift_X),cast.FLOAT(temp_CG))),sum(p_CG)) +::STMT +MATRIX:cumHistMul,offset +cast.FLOAT(<=(offset,cumHistMul)) +::STMT +MATRIX:P,Y,Z,ZERODIAG,parsertemp220891 +FLOAT:int631,parsertemp220894 +%*%(*(-(P,/(Z,parsertemp220894)),*(/(int631,parsertemp220891),ZERODIAG)),Y) +::STMT +MATRIX:X,MSE +LITERAL_FLOAT:2.0 +/(^(max(X),2.0),MSE) +::STMT +MATRIX:parsertemp10744,V,W,H,parsertemp10748 +FLOAT:Eps +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10748)),Eps)) +::STMT +MATRIX:parsertemp460641 +LITERAL_FLOAT:0.282842712474619 +*(parsertemp460641,0.282842712474619) +::STMT +MATRIX:P,gradients,Phi_new,Theta +FLOAT:alpha ++(Phi_new,*(alpha,%*%(t(gradients),%*%(P,Theta)))) +::STMT +MATRIX:xs +FLOAT:252_x +LITERAL_FLOAT:10.0 +-(10.0,sum(>=(xs,252_x))) +::STMT +MATRIX:Yhat_prime,H3_prime,E,W4 +*(H3_prime,%*%(*(E,Yhat_prime),W4)) +::STMT +MATRIX:means,parsertemp560530 +LITERAL_FLOAT:5.0 +/(sum(<(*(means,parsertemp560530),5.0)),*(nrow(means),ncol(means))) +::STMT +MATRIX:79_77_X_row_norm,parsertemp17178,parsertemp17180,Y_block,parsertemp17170,79_77_Y_row_norm,X_block +LITERAL_FLOAT:0.9 +*(>(/(%*%(X_block,parsertemp17180),%*%(79_77_X_row_norm,parsertemp17178)),0.9),/(%*%(X_block,t(Y_block)),%*%(+(79_77_X_row_norm,parsertemp17170),t(79_77_Y_row_norm)))) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:1.0,0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(1.0,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:confusionM +min(rowSums(confusionM)) +::STMT +MATRIX:parsertemp175056,316_scores,X +-(/(exp(-(X,parsertemp175056)),rowSums(exp(316_scores))),/(exp(-(X,parsertemp175056)),rowSums(exp(316_scores)))) +::STMT +FLOAT:m2,float885,wt +LITERAL_FLOAT:5.0 +*(5.0,sqrt(/(*(m2,wt),-(wt,float885)))) +::STMT +MATRIX:validKeyMask +cast.FLOAT(colSums(validKeyMask)) +::STMT +MATRIX:classes +LITERAL_FLOAT:1.0,0.8 +*(cast.FLOAT(classes),-(1.0,0.8)) +::STMT +MATRIX:U,V,X +-(%*%(U,t(V)),X) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +LITERAL_FLOAT:10.0 ++(*(10.0,max(*(parsertemp222665,termination_bitmap))),10.0) +::STMT +MATRIX:sv,s,w,X,Y,out +FLOAT:lambda,step_sz +-(%*%(t(X),*(*(sv,out),Y)),*(lambda,+(w,*(step_sz,s)))) +::STMT +MATRIX:parsertemp195898 +FLOAT:parsertemp195895,factor_up +LITERAL_FLOAT:1.0 +-(1.0,abs(-(/(parsertemp195898,factor_up),/(parsertemp195895,factor_up)))) +::STMT +FLOAT:p_CG,parsertemp170088,z,pp_CG,parsertemp170090 +LITERAL_FLOAT:-1.0 +/(-(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170088,parsertemp170090))),pp_CG) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:parsertemp31116,parsertemp31109 +LITERAL_FLOAT:1500.0,2000.0 +sqrt(+(/(/(parsertemp31108,parsertemp31109),2000.0),/(/(parsertemp31115,parsertemp31116),1500.0))) +::STMT +MATRIX:t,parsertemp171083,parsertemp171092 +FLOAT:float141 +LITERAL_FLOAT:1.0,1.432788 ++(1.0,*(sqrt(*(float141,parsertemp171083)),+(1.432788,*(t,parsertemp171092)))) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0,100.0 +/(sum(^(-(beta,y),2.0)),100.0) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2,eps +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +/(*(z_alpha_2,se_surv),surv) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:3.5355339059327378 +/(3.5355339059327378,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:s,w,wnew +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(wnew),+(w,s)))) +::STMT +LITERAL_FLOAT:1.0E20 +INT:int563,n +rand(n,int563,1.0E20,1.0E20) +::STMT +FLOAT:prob_true,prob_false +LITERAL_FLOAT:2.0 ++(^(prob_true,2.0),^(prob_false,2.0)) +::STMT +MATRIX:R,dsep,dssm +FLOAT:2_eAvg +/(/(+(R,dsep),-(R,dssm)),2_eAvg) +::STMT +MATRIX:is_too_small,parsertemp171346,the_exp_exp,linear_terms,the_exp +FLOAT:int95,int146,int568,int902,int805 +LITERAL_FLOAT:1.0,1.0E7 ++(/(*(-(int805,is_too_small),-(int902,the_exp_exp)),+(exp(linear_terms),==(parsertemp171346,int568))),*(==(+(int146,the_exp),1.0E7),-(1.0,/(the_exp,int95)))) +::STMT +MATRIX:T_1,parsertemp410245,event,parsertemp410248 +FLOAT:int916,float628 +LITERAL_FLOAT:1.0,1.5 +/(^(/(*(parsertemp410245,int916),*(float628,parsertemp410248)),/(1.0,1.5)),/(-(max(T_1),min(T_1)),sum(event))) +::STMT +FLOAT:obj,objnew +/(abs(-(objnew,obj)),obj) +::STMT +FLOAT:padw,padh,Hin,Win +LITERAL_FLOAT:2.0 +*(+(Hin,*(2.0,padh)),+(Win,*(2.0,padw))) +::STMT +MATRIX:LHSthreshold +LITERAL_FLOAT:1.0 +>(LHSthreshold,1.0) +::STMT +MATRIX:2707_X,2706_dX +LITERAL_FLOAT:0.0 +colSums(*(>(2707_X,0.0),2706_dX)) +::STMT +MATRIX:parsertemp220853,parsertemp220854,Hneg,beta,betamin,Hpos +FLOAT:logU +LITERAL_FLOAT:0.0 +*(<(-(+(parsertemp220853,parsertemp220854),logU),0.0),+(beta,+(*(Hneg,betamin),*(Hpos,beta)))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:0.0 +*(^(exp(linear_terms),0.0),-(Y,exp(linear_terms))) +::STMT +FLOAT:R,eta,s +LITERAL_FLOAT:-1.0 +*(R,^(eta,*(s,-1.0))) +::STMT +FLOAT:sig,q,parsertemp181039,int284 +LITERAL_FLOAT:1.0,8.0 +*(8.0,-(1.0,/(-(q,parsertemp181039),*(int284,sig)))) +::STMT +MATRIX:Y,parsertemp283552 +-(sum(Y),parsertemp283552) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +%*%(t(lambda),^(newbeta,2.0)) +::STMT +LITERAL_FLOAT:10.0,1.5,-8.0 +*(1.5,^(10.0,-8.0)) +::STMT +MATRIX:Train,2342_m_colmax,2342_m_colmin +LITERAL_FLOAT:2.0 +/(*(2.0,-(Train,2342_m_colmin)),-(2342_m_colmax,2342_m_colmin)) +::STMT +MATRIX:parsertemp143446,parsertemp143445 +&(parsertemp143445,parsertemp143446) +::STMT +MATRIX:X_batch,dout1 +FLOAT:191_beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,191_beta1),%*%(t(X_batch),dout1)) +::STMT +MATRIX:std,rad +-(rad,cast.FLOAT(std)) +::STMT +MATRIX:parsertemp171315,parsertemp171307,parsertemp171319 +FLOAT:float489,float311,float639 +LITERAL_FLOAT:2.0 +-(2.0,*(exp(/(parsertemp171319,float489)),*(/(float311,parsertemp171307),+(float639,parsertemp171315)))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.5 +>(y_corr,0.5) +::STMT +MATRIX:s,sts,d,parsertemp44023 +FLOAT:delta2 +LITERAL_FLOAT:2.0 ++(^(%*%(t(s),d),2.0),*(cast.FLOAT(%*%(parsertemp44023,d)),-(delta2,cast.FLOAT(sts)))) +::STMT +MATRIX:t_gp,parsertemp560881,parsertemp560864,parsertemp560863,parsertemp560877 +FLOAT:int551,int310,float761 +LITERAL_FLOAT:1.0 ++(-(1.0,-(*(int310,parsertemp560863),1.0)),*(*(*(t_gp,parsertemp560877),-(parsertemp560864,int551)),exp(/(parsertemp560881,float761)))) +::STMT +MATRIX:parsertemp43620,y +FLOAT:float213 +LITERAL_FLOAT:1.0 +*(-(/(1.0,+(float213,parsertemp43620)),1.0),y) +::STMT +MATRIX:X_plane,parsertemp11251 +LITERAL_FLOAT:0.0,2.0 +*(>(X_plane,0.0),t(^(2.0,parsertemp11251))) +::STMT +MATRIX:p,parsertemp285529,g +FLOAT:pp,pq,int41,pz,parsertemp285521,parsertemp285537 +*(+(+(*(parsertemp285537,pq),sum(parsertemp285529)),sum(*(g,p))),/(+(*(pz,int41),sqrt(parsertemp285521)),pp)) +::STMT +MATRIX:W1_rand +FLOAT:num_hidden1,m +LITERAL_FLOAT:6.0 +*(/(sqrt(6.0),sqrt(+(m,num_hidden1))),W1_rand) +::STMT +FLOAT:int584,m2,float284 +LITERAL_FLOAT:2003.0 +sqrt(*(/(2003.0,-(int584,float284)),m2)) +::STMT +LITERAL_FLOAT:1.0E-7 +1.0E-7 +::STMT +MATRIX:parsertemp27746,parsertemp27872 +FLOAT:featureCorrection +LITERAL_FLOAT:0.0 ++(%*%(parsertemp27872,t(parsertemp27746)),-(0.0,featureCorrection)) +::STMT +MATRIX:scale_X,parsertemp429910 +LITERAL_FLOAT:300.0,0.0 +*(-(0.0,/(t(parsertemp429910),300.0)),scale_X) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:0.5,1270.0 +round(+(0.5,/(parsertemp79022,1270.0))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(^(X,2.0),t(^(prec_chol,2.0))) +::STMT +MATRIX:t_gp,pt_gp,parsertemp171320,Y,the_gauss_exp,parsertemp171316 +LITERAL_FLOAT:2.0,0.25,0.15915494309189535 +/(*(*(exp(parsertemp171320),0.15915494309189535),rowSums(Y)),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:cumHistMul,offset,parsertemp132495,histMul,outBucket +LITERAL_FLOAT:1.0 +-(-(offset,%*%(==(outBucket,parsertemp132495),-(cumHistMul,histMul))),1.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:-1.0 +sum(*(*(%*%(parsertemp1904,y),-1.0),*(%*%(parsertemp1904,y),-1.0))) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0,10.0 +/(sum(^(-(beta,y),2.0)),10.0) +::STMT +FLOAT:i,k +LITERAL_FLOAT:2.0,4.0 +-(+(+(i,k),4.0),2.0) +::STMT +MATRIX:X +FLOAT:M +/(ncol(X),M) +::STMT +MATRIX:X +LITERAL_FLOAT:200.0 +/(t(colSums(X)),200.0) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0 +*(-(s,1.0),-(num_groups,1.0)) +::STMT +MATRIX:id +==(id,cast.FLOAT(id)) +::STMT +MATRIX:R,svLowBnd +>(R,cast.FLOAT(svLowBnd)) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0 +/(t(colSums(X)),300.0) +::STMT +FLOAT:s +LITERAL_FLOAT:-1.0,50.0,3.0 +*(50.0,^(3.0,*(s,-1.0))) +::STMT +FLOAT:var,arch_coef,xt,var_coef,int838,a0 ++(+(a0,*(arch_coef,^(xt,int838))),*(var_coef,var)) +::STMT +MATRIX:parsertemp171318 +FLOAT:int267,one_over_sqrt_two_pi +LITERAL_FLOAT:2.0 +*(exp(/(*(parsertemp171318,int267),2.0)),^(one_over_sqrt_two_pi,2.0)) +::STMT +MATRIX:ssX_V,X,parsertemp150463,P_1K +%*%(rowSums(*(P_1K,%*%(X,ssX_V))),parsertemp150463) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(*(sv,out),2.0))) +::STMT +MATRIX:probs,y_batch +LITERAL_FLOAT:0.0,1.0,1.0E-10 +*(*(/(1.0,nrow(y_batch)),-(0.0,y_batch)),/(1.0,+(probs,1.0E-10))) +::STMT +FLOAT:i,cols,n +LITERAL_FLOAT:1.0 +-(n,-(+(i,cols),1.0)) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0,0.5 ++(0.5,/(parsertemp222331,200.0)) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +-(^(2000.0,2.0),1.0) +::STMT +MATRIX:parsertemp175083 +LITERAL_FLOAT:1.0E-6 +cast.MATRIX(sum(<(abs(parsertemp175083),1.0E-6))) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +-(Y,*(rowSums(Y),>=(linear_terms,0.0))) +::STMT +MATRIX:parsertemp44079 +FLOAT:C +LITERAL_FLOAT:-1.0 +*(C,sum(*(parsertemp44079,-1.0))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(y_corr,0.0)) +::STMT +FLOAT:qmle,var_t,int653,xq_t,parsertemp496694,n +LITERAL_FLOAT:1.0 +-(qmle,*(/(1.0,*(int653,n)),+(parsertemp496694,/(xq_t,var_t)))) +::STMT +MATRIX:b4,parsertemp389338 +LITERAL_FLOAT:2.0 +exp(*(2.0,t(+(parsertemp389338,b4)))) +::STMT +MATRIX:parsertemp397828,parsertemp397825,W3_rand +LITERAL_FLOAT:0.5107539184552492 +t(%*%(*(0.5107539184552492,W3_rand),t(/(parsertemp397825,parsertemp397828)))) +::STMT +MATRIX:wnew,parsertemp44111 +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(wnew,parsertemp44111),2.0))) +::STMT +MATRIX:_sbcvar2306 +LITERAL_FLOAT:1.0 ++(max(t(_sbcvar2306)),1.0) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0 +*(2.0,/(-(rowSums(simplex),simplex),nrow(simplex))) +::STMT +MATRIX:W1_rand,stds,parsertemp394896 +LITERAL_FLOAT:0.08146881698903526 +t(%*%(*(0.08146881698903526,W1_rand),t(/(parsertemp394896,stds)))) +::STMT +MATRIX:V,y +%*%(t(V),y) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +LITERAL_FLOAT:0.0,1.0 +-(1.0,*(>(Y,0.0),is_natural_parameter_log_zero)) +::STMT +FLOAT:int143,o_init,int524,o +LITERAL_FLOAT:-1.0,50.0 +/(*(-(*(int524,o_init),*(int143,o)),-1.0),50.0) +::STMT +MATRIX:U,V_sum +/(*(U,U),sum(V_sum)) +::STMT +FLOAT:parsertemp565893,h,y_offset +LITERAL_FLOAT:1.0 +-(+(+(parsertemp565893,y_offset),h),1.0) +::STMT +LITERAL_FLOAT:0.054717579189018505 +0.054717579189018505 +::STMT +MATRIX:X_batch,dout1,mW1 +FLOAT:191_beta1 +LITERAL_FLOAT:1.0 ++(*(191_beta1,mW1),*(-(1.0,191_beta1),%*%(t(X_batch),dout1))) +::STMT +MATRIX:X_batch,parsertemp389606,parsertemp389591,2364_2361_Y,parsertemp389588,W4 +FLOAT:int318 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(/(parsertemp389588,parsertemp389591),2.0)),%*%(*(-(2364_2361_Y,X_batch),-(int318,parsertemp389606)),W4)) +::STMT +MATRIX:d_r_rev,Hd_1,Hd_2 +t(colSums(*(-(Hd_1,Hd_2),d_r_rev))) +::STMT +MATRIX:I,parsertemp472360 +LITERAL_FLOAT:0.0 +*(I,==(!=(*(parsertemp472360,I),0.0),0.0)) +::STMT +LITERAL_FLOAT:1.0,0.8 +-(1.0,-(1.0,0.8)) +::STMT +MATRIX:parsertemp222700,parsertemp222697,parsertemp222694 +FLOAT:int857 +t(<=(+(*(int857,parsertemp222694),t(parsertemp222697)),parsertemp222700)) +::STMT +FLOAT:int227,429_C +LITERAL_FLOAT:1.0,2.0 +sqrt(/(2.0,*(*(429_C,int227),1.0))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0,2.0 +*(^(finite_linear_terms,2.0),-1.0) +::STMT +MATRIX:X,Y,K +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(*(K,-(X,X)),-(Y,Y)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:V +-(max(V),min(V)) +::STMT +FLOAT:2690_Hin +LITERAL_FLOAT:0.0,2.0 ++(2690_Hin,*(2.0,0.0)) +::STMT +MATRIX:parsertemp386457,parsertemp386459,neighbors,parsertemp386455 +LITERAL_FLOAT:0.0 +==(-(*(*(neighbors,parsertemp386455),parsertemp386457),parsertemp386459),0.0) +::STMT +MATRIX:grad +FLOAT:int396,int927 +sqrt(sum(*(*(grad,int927),*(grad,int396)))) +::STMT +MATRIX:residuals_vector +FLOAT:lambda +/(sum(residuals_vector),+(nrow(residuals_vector),lambda)) +::STMT +MATRIX:g0_2,g0_1,g0 +LITERAL_FLOAT:1.0E-12 +*(cast.FLOAT(%*%(t(g0),+(g0_1,g0_2))),1.0E-12) +::STMT +MATRIX:Yhat_prime,E +colSums(*(E,Yhat_prime)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626 +*(1.0005002501250626,m2) +::STMT +FLOAT:sim_score_left,sim_score_right,sim_score_parent +-(+(sim_score_left,sim_score_right),sim_score_parent) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,parsertemp222439,parsertemp222443,X_samples +LITERAL_FLOAT:2.0 +-(+(X_samples_sq_norms,%*%(samples_vs_runs_map,rowSums(parsertemp222439))),*(2.0,rowSums(*(X_samples,parsertemp222443)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604,X,y +FLOAT:int564 +-(%*%(X,*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int564))),y) +::STMT +FLOAT:window_size,i,k +LITERAL_FLOAT:2.0 +-(+(+(i,k),window_size),2.0) +::STMT +LITERAL_FLOAT:4.890349128221754 +4.890349128221754 +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int877,int492 +LITERAL_FLOAT:1.0,150.0 +/(-(colSums(^(negSamples,int492)),*(150.0,^(negSampleMeans,int877))),-(150.0,1.0)) +::STMT +MATRIX:y_val,preds +%*%(t(-(y_val,preds)),-(y_val,preds)) +::STMT +MATRIX:A +abs(t(A)) +::STMT +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +/(2.0,-(check_max,check_min)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,253.0 ++(-(253.0,idx),1.0) +::STMT +MATRIX:ZtZ,Xm,parsertemp265719,parsertemp265718,parsertemp265714 +LITERAL_FLOAT:2.0 +-(+(sum(*(Xm,Xm)),trace(*(ZtZ,parsertemp265714))),*(2.0,sum(%*%(parsertemp265718,parsertemp265719)))) +::STMT +MATRIX:W +FLOAT:int573,parsertemp97,int148,m4,int722,wt,int371 +LITERAL_FLOAT:1.0 +-(*(*(^(wt,int371),+(wt,int148)),m4),*(*(*(int722,parsertemp97),^(wt,int573)),-(sum(W),1.0))) +::STMT +MATRIX:r_CG,p_CG +FLOAT:rr_CG,old_rr_CG +LITERAL_FLOAT:-1.0 ++(*(r_CG,-1.0),*(/(rr_CG,old_rr_CG),p_CG)) +::STMT +FLOAT:int153,float879,float406,int53 +LITERAL_FLOAT:1.0,3.0,6.0,2003.0 +/(*(*(6.0,2003.0),-(2003.0,1.0)),*(*(-(int153,float879),+(int53,float406)),+(2003.0,3.0))) +::STMT +FLOAT:429_C +LITERAL_FLOAT:1.0,2.0 +/(2.0,*(*(429_C,1.0),1.0)) +::STMT +MATRIX:S,V,parsertemp149285 +FLOAT:int503,delta2 +LITERAL_FLOAT:2.0 ++(^(sum(*(S,V)),2.0),*(sum(^(V,int503)),-(delta2,sum(parsertemp149285)))) +::STMT +MATRIX:p,q,A +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),%*%(t(A),%*%(A,p))) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +sum(*(-(r,*(parsertemp44049,Hd)),-(r,*(parsertemp44049,Hd)))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int633,int424 ++(1.0,exp(rand(int633,int424,0.0,0.0))) +::STMT +MATRIX:s,d,tau ++(s,*(cast.FLOAT(tau),d)) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +&(>=(leaf_ids,boundary_left),<(leaf_ids,+(boundary_left,step_size))) +::STMT +MATRIX:P,Q,Y,Z,ZERODIAG +*(Y,rowSums(*(-(P,Q),*(Z,ZERODIAG)))) +::STMT +MATRIX:B,X,y +-(y,%*%(X,B)) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +%*%(t(+(s,*(norm_r2,d))),+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:parsertemp437192,parsertemp437191,parsertemp437237,mean,weight,avgMean +FLOAT:int874 +LITERAL_FLOAT:1.0E-9 ++(+(-(/(parsertemp437237,parsertemp437192),*(int874,avgMean)),/(*(mean,parsertemp437191),t(weight))),1.0E-9) +::STMT +MATRIX:W,X,H,parsertemp411105,parsertemp411107 +LITERAL_FLOAT:1.0E-8 +/(%*%(X,t(*(H,parsertemp411105))),+(%*%(W,%*%(H,parsertemp411107)),1.0E-8)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:5.0,1.0005 +*(5.0,sqrt(*(1.0005,m2))) +::STMT +MATRIX:parsertemp129186,parsertemp129185,key_unique,key +==(%*%(key_unique,parsertemp129185),%*%(parsertemp129186,t(key))) +::STMT +MATRIX:hubs +LITERAL_FLOAT:2.0 +abs(sum(^(-(hubs,hubs),2.0))) +::STMT +MATRIX:P,N_T,X,parsertemp230442 +<=(rowSums(*(X,parsertemp230442)),%*%(P,t(N_T))) +::STMT +MATRIX:R,parsertemp497774 +LITERAL_FLOAT:0.0 +-(ncol(R),sum(==(colSums(parsertemp497774),0.0))) +::STMT +MATRIX:A +FLOAT:parsertemp22359,a21,parsertemp22358,int923 +LITERAL_FLOAT:1.0 +sqrt(+(+(+(parsertemp22358,parsertemp22359),/(int923,a21)),/(1.0,cast.FLOAT(A)))) +::STMT +LITERAL_FLOAT:8.660254037844387 +8.660254037844387 +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0 +^(-(beta,y),2.0) +::STMT +MATRIX:D,parsertemp570375,classMeans +%*%(-(D,classMeans),parsertemp570375) +::STMT +FLOAT:481_Hf,481_Hin +LITERAL_FLOAT:0.0,2.0 +-(+(481_Hin,*(2.0,0.0)),481_Hf) +::STMT +MATRIX:parsertemp10964,C +sum(==(parsertemp10964,C)) +::STMT +MATRIX:parsertemp146940,184_dtemp,mW3,outr2 +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mW3),*(-(1.0,beta1),%*%(t(outr2),-(184_dtemp,parsertemp146940)))) +::STMT +MATRIX:G,authorities +max(%*%(G,authorities)) +::STMT +MATRIX:nI +LITERAL_FLOAT:0.25 +*(0.25,ncol(nI)) +::STMT +FLOAT:int455,int456,o_init,N,o +LITERAL_FLOAT:-1.0 +/(*(-(*(int456,o_init),*(int455,o)),-1.0),N) +::STMT +MATRIX:confusionM +min(colSums(confusionM)) +::STMT +MATRIX:parsertemp383011,X,X_nonzero_ind +LITERAL_FLOAT:2.0 +sum(*(X_nonzero_ind,^(-(X,parsertemp383011),2.0))) +::STMT +MATRIX:parsertemp498248,m_iter_err_sum,m_err +FLOAT:int526,i_process_item +LITERAL_FLOAT:2.0 +*(*(2.0,/(-(int526,parsertemp498248),i_process_item)),+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:std,sts,rad +FLOAT:delta2 +/(-(delta2,sts),+(std,rad)) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:105.0 ++(105.0,nrow(_sbcvar1708)) +::STMT +MATRIX:parsertemp414375,parsertemp414377,parsertemp414379 +FLOAT:int577,int293 +LITERAL_FLOAT:0.0,1.0,199.0 +*(/(-(t(parsertemp414375),*(int577,parsertemp414377)),199.0),-(1.0,<=(/(parsertemp414379,int293),0.0))) +::STMT +MATRIX:maskNAN +LITERAL_FLOAT:0.0 +!=(rowSums(maskNAN),0.0) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0,-1.0 +*(sqrt(*(-2.0,parsertemp171083)),-1.0) +::STMT +MATRIX:parsertemp170248,parsertemp170253,parsertemp170240,lt_pos_neg +FLOAT:float811,float257,float69 ++(lt_pos_neg,*(*(-(float257,lt_pos_neg),exp(parsertemp170253)),*(/(float811,parsertemp170240),+(float69,parsertemp170248)))) +::STMT +MATRIX:prec_chol,X,mu +FLOAT:int69 +%*%(X,t(*(mu,^(prec_chol,int69)))) +::STMT +MATRIX:parsertemp13624,_sbcvar11 +FLOAT:int171 +LITERAL_FLOAT:2.0,1000.0 +/(^(-(_sbcvar11,/(parsertemp13624,int171)),2.0),/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +LITERAL_FLOAT:2.0 +sum(^(-(r,*(parsertemp44049,Hd)),2.0)) +::STMT +MATRIX:tmp_Xw,parsertemp260747,Y,Xw +LITERAL_FLOAT:0.0,1.0 +*(-(1.0,*(Y,+(Xw,parsertemp260747))),>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,maskd1,out1,W2 +FLOAT:p,int336 +LITERAL_FLOAT:0.0 +*(*(>(out1,0.0),/(maskd1,p)),%*%(*(>(out2,int336),%*%(184_dscores,parsertemp146942)),t(W2))) +::STMT +MATRIX:is_LT_infinite,parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793,1.0,0.5 +*(+(0.5,/(%*%(parsertemp171366,p_one_m_one),3.141592653589793)),-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp231012 +FLOAT:parsertemp231013 +LITERAL_FLOAT:1.0,2.0 +-(1.0,sum(^(/(parsertemp231012,parsertemp231013),2.0))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,%*%(t(V),y)),2.0) +::STMT +MATRIX:c,x_r +LITERAL_FLOAT:2.0 +-(*(2.0,x_r),c) +::STMT +MATRIX:X +FLOAT:int758 +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(X,int758)))) +::STMT +MATRIX:vW1,W1,dW1 +FLOAT:2727_mu,2727_lr +LITERAL_FLOAT:1.0 ++(-(W1,*(2727_mu,vW1)),*(+(1.0,2727_mu),-(*(2727_mu,vW1),*(2727_lr,dW1)))) +::STMT +MATRIX:W +FLOAT:m2,wt,float491 +/(sqrt(/(*(m2,wt),-(wt,float491))),sqrt(sum(round(W)))) +::STMT +MATRIX:P,Q +LITERAL_FLOAT:-2.0 ++(*(-2.0,%*%(P,t(Q))),P) +::STMT +MATRIX:X,y +FLOAT:float984,float563 +LITERAL_FLOAT:-1.0 +INT:int154,int667 +exp(*(*(y,-1.0),%*%(X,rand(int154,int667,float984,float563)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,1024.0 +-(+(i,1024.0),1.0) +::STMT +MATRIX:y +LITERAL_FLOAT:1.0 +/(1.0,nrow(y)) +::STMT +MATRIX:X +*(nrow(X),ncol(X)) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +-(_sbcvar78,/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:parsertemp43619 +LITERAL_FLOAT:1.0 +-(1.0,/(1.0,+(1.0,exp(parsertemp43619)))) +::STMT +MATRIX:parsertemp383012,parsertemp383020,parsertemp383017,X_nonzero_ind +FLOAT:reg,int800 ++(sum(*(X_nonzero_ind,^(parsertemp383012,int800))),*(reg,+(sum(parsertemp383017),sum(parsertemp383020)))) +::STMT +MATRIX:parsertemp400673,W4_rand +FLOAT:int116,int619 +LITERAL_FLOAT:0.08720414403938946 +%*%(*(0.08720414403938946,W4_rand),t(/(-(parsertemp400673,int619),+(parsertemp400673,int116)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,1048.0 +-(+(i,1048.0),1.0) +::STMT +MATRIX:parsertemp570381,parsertemp570372,parsertemp570376,parsertemp570377 +FLOAT:int431,int433,int633,int645 ++(parsertemp570381,-(*(/(int433,int431),parsertemp570372),*(/(int633,int645),%*%(parsertemp570376,parsertemp570377)))) +::STMT +MATRIX:parsertemp389580,parsertemp389562,parsertemp389565,2365_delta3,W2,W3 +FLOAT:int926 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(/(parsertemp389562,parsertemp389565),2.0)),%*%(*(-(int926,parsertemp389580),%*%(2365_delta3,W3)),W2)) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int92 +LITERAL_FLOAT:0.0,1.0,2.0 +^(*(>(-(int92,parsertemp2798),0.0),-(1.0,*(Y,Xw))),2.0) +::STMT +MATRIX:s,d,alpha +t(-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp31189,parsertemp31187 +FLOAT:int226,int613 +LITERAL_FLOAT:1.0,2.0,7000.0 +/(^(/(-(parsertemp31187,parsertemp31189),-(int226,int613)),2.0),*(^(7000.0,2.0),-(7000.0,1.0))) +::STMT +MATRIX:col,less_than_lb,parsertemp24102,parsertemp24103 +FLOAT:int760,num_bins,int226 +LITERAL_FLOAT:1.0 ++(*(-(-(int226,less_than_lb),>(col,num_bins)),+(round(parsertemp24102),1.0)),*(>(+(parsertemp24103,int760),num_bins),num_bins)) +::STMT +FLOAT:m2Y,sigmaX,covXY,parsertemp26584 +/(covXY,*(sigmaX,sqrt(*(m2Y,parsertemp26584)))) +::STMT +MATRIX:g,parsertemp169907 +sqrt(sum(*(+(g,parsertemp169907),+(g,parsertemp169907)))) +::STMT +MATRIX:2814_K,2814_X,2814_Y +FLOAT:int302 ++(*(cast.FLOAT(-(int302,2814_K)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X))),-(cast.FLOAT(2814_Y),cast.FLOAT(2814_Y))) +::STMT +MATRIX:Y +cast.MATRIX(min(Y)) +::STMT +MATRIX:tmp_Xw,parsertemp260749,Y +FLOAT:int438 +LITERAL_FLOAT:0.0,1.0 +*(*(-(1.0,*(Y,tmp_Xw)),>(-(int438,parsertemp260749),0.0)),Y) +::STMT +MATRIX:parsertemp31732,parsertemp31734,dssm,dsem +FLOAT:5_eAvg +LITERAL_FLOAT:1.0 +-(/(/(-(parsertemp31734,dsem),-(parsertemp31732,dssm)),5_eAvg),1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,133.0 +*(133.0,-(i,1.0)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,outd1 +FLOAT:beta1,int17 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(outd1),*(>(out2,int17),%*%(184_dscores,parsertemp146942)))) +::STMT +MATRIX:_sbcvar92 +LITERAL_FLOAT:0.0 +==(/(%*%(rowSums(_sbcvar92),colSums(_sbcvar92)),sum(_sbcvar92)),0.0) +::STMT +MATRIX:parsertemp382672,parsertemp382681,parsertemp382668,parsertemp382678 +FLOAT:reg +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(parsertemp382668,parsertemp382672))),*(*(0.5,reg),+(sum(parsertemp382678),sum(parsertemp382681)))) +::STMT +MATRIX:intercept +FLOAT:int172,int470 +INT:num_records,int150 +%*%(rand(num_records,int150,int172,int470),intercept) +::STMT +MATRIX:neighbors +FLOAT:eps +<=(-(neighbors,diag(diag(neighbors))),eps) +::STMT +MATRIX:R,w +INT:parsertemp31673,int63 ++(R,diag(rand(parsertemp31673,int63,cast.FLOAT(w),cast.FLOAT(w)))) +::STMT +MATRIX:240_elt,240_ones_ctg +%*%(rowSums(240_elt),t(240_ones_ctg)) +::STMT +MATRIX:p +FLOAT:eps +*(eps,p) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 +*(+(num_records,1.0),-(1.0,<=(sample_rec_ids,num_records))) +::STMT +MATRIX:s,parsertemp44005,d +FLOAT:parsertemp44004 +cast.FLOAT(%*%(t(+(s,parsertemp44005)),+(s,*(parsertemp44004,d)))) +::STMT +MATRIX:X_batch,2365_delta2,W2,parsertemp389567 +FLOAT:int376 +%*%(t(*(-(int376,parsertemp389567),%*%(2365_delta2,W2))),X_batch) +::STMT +MATRIX:A,b +LITERAL_FLOAT:-1.0 +*(%*%(*(t(A),-1.0),b),-1.0) +::STMT +MATRIX:X,mu,precisions +LITERAL_FLOAT:2.0 +*(2.0,%*%(X,t(*(mu,precisions)))) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +-(/(-(2.0,var_power),link_power),2.0) +::STMT +MATRIX:Y,the_exp +FLOAT:int14 +-(*(rowSums(Y),exp(-(int14,the_exp))),Y) +::STMT +MATRIX:cumHistMul,offset +<=(offset,cumHistMul) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,33.0 +-(33.0,+(current_hash_value,1.0)) +::STMT +MATRIX:F,parsertemp27458 +FLOAT:W +LITERAL_FLOAT:0.0,1.0E-4 ++(*(==(/(parsertemp27458,W),0.0),1.0E-4),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:D,parsertemp570375,classMeans +LITERAL_FLOAT:0.5 +*(0.5,%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans)))) +::STMT +MATRIX:parsertemp393571,W3_rand,parsertemp393574 +LITERAL_FLOAT:0.128920512778062 +t(%*%(*(0.128920512778062,W3_rand),t(/(parsertemp393571,parsertemp393574)))) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:120.0 ++(120.0,nrow(_sbcvar1716)) +::STMT +MATRIX:negSampleMeans,negSamples +LITERAL_FLOAT:2.0,150.0 +-(colSums(^(negSamples,2.0)),*(150.0,^(negSampleMeans,2.0))) +::STMT +MATRIX:Mask1 +LITERAL_FLOAT:0.0 +>(colSums(Mask1),0.0) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,1.0 +INT:int942,m +-(%*%(X,rand(m,int942,0.0,1.0)),y) +::STMT +MATRIX:MDx,MUx,MLx ++(+(MUx,MDx),MLx) +::STMT +FLOAT:ssPrev,parsertemp265727,parsertemp265726 +LITERAL_FLOAT:1.0 +abs(-(1.0,/(/(parsertemp265726,parsertemp265727),ssPrev))) +::STMT +MATRIX:ytest +LITERAL_FLOAT:2.0 +^(/(sum(ytest),nrow(ytest)),2.0) +::STMT +MATRIX:means,Y_counts,parsertemp560529 +LITERAL_FLOAT:1.0 +sum(<(*(means,%*%(Y_counts,parsertemp560529)),1.0)) +::STMT +MATRIX:t,parsertemp171088,parsertemp171083,parsertemp171094 +FLOAT:float707 +LITERAL_FLOAT:0.0,1.0,2.515517 ++(-(0.0,sqrt(*(float707,parsertemp171083))),/(+(2.515517,*(t,parsertemp171088)),+(1.0,*(t,parsertemp171094)))) +::STMT +LITERAL_FLOAT:2000.0 +sqrt(2000.0) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int777 +LITERAL_FLOAT:1.0,100.0 +/(/(-(colSums(parsertemp31022),*(int777,parsertemp31024)),-(100.0,1.0)),100.0) +::STMT +MATRIX:r,d,parsertemp43999 +cast.FLOAT(/(sum(*(r,r)),%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:C,parsertemp265706,parsertemp265704,Z,XtZ +FLOAT:ss,ZtZ_sum +trace(*(+(%*%(parsertemp265704,Z),*(parsertemp265706,ss)),%*%(t(C),/(XtZ,ZtZ_sum)))) +::STMT +FLOAT:sample_frac +LITERAL_FLOAT:0.0,1.0 +INT:parsertemp553005,int999 +<=(rand(parsertemp553005,int999,0.0,1.0),sample_frac) +::STMT +MATRIX:classFeatureCounts +FLOAT:laplaceCorrection ++(classFeatureCounts,laplaceCorrection) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:2.0 +*(^(U,2.0),row_nonzeros) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939,outr2 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(outr2),-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)))) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,-1.0 +INT:int578,int705 +*(*(y,-1.0),%*%(X,rand(int578,int705,0.0,0.0))) +::STMT +FLOAT:int625 +LITERAL_FLOAT:-1.0 +INT:int426,int191 ++(diag(rand(int426,int191,-1.0,-1.0)),int625) +::STMT +MATRIX:Bxu,Bxd +LITERAL_FLOAT:2.0 +diag(*(2.0,+(Bxd,Bxu))) +::STMT +MATRIX:45_CVars,45_CFreqs +FLOAT:float192,int474,parsertemp13703,int43,int766 +LITERAL_FLOAT:1.0,1000.0 +/(sum(*(-(45_CFreqs,int43),45_CVars)),*(-(1000.0,1.0),/(*(parsertemp13703,int766),-(int474,float192)))) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +FLOAT:C +%*%(t(d),+(d,*(C,%*%(parsertemp43996,parsertemp43997)))) +::STMT +MATRIX:p,parsertemp1936,parsertemp1937 +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:parsertemp11741 +LITERAL_FLOAT:1.0 ++(1.0,parsertemp11741) +::STMT +MATRIX:distance_matrix,parsertemp447763,upper_triangle ++(+(distance_matrix,t(upper_triangle)),diag(parsertemp447763)) +::STMT +MATRIX:s,parsertemp44016 +FLOAT:delta2 +-(delta2,cast.FLOAT(%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:2.225E-307 +2.225E-307 +::STMT +MATRIX:col,less_than_lb,parsertemp24102,parsertemp24103 +FLOAT:int918,num_bins,int391 +LITERAL_FLOAT:1.0 ++(*(-(-(int391,less_than_lb),>(col,num_bins)),+(round(parsertemp24102),1.0)),*(>(+(parsertemp24103,int918),num_bins),num_bins)) +::STMT +MATRIX:A,B ++(ncol(A),ncol(B)) +::STMT +FLOAT:log_l,new_log_l +LITERAL_FLOAT:1.0E-14 +*(+(abs(log_l),abs(new_log_l)),1.0E-14) +::STMT +LITERAL_FLOAT:1.0,50.0 +*(50.0,1.0) +::STMT +MATRIX:X +abs(X) +::STMT +FLOAT:step +LITERAL_FLOAT:0.95 +*(step,0.95) +::STMT +MATRIX:parsertemp415351,ytest +FLOAT:parsertemp415362,n +LITERAL_FLOAT:1.0 +sqrt(/(-(sum(parsertemp415351),*(n,parsertemp415362)),-(nrow(ytest),1.0))) +::STMT +MATRIX:t_gp,parsertemp170245,parsertemp170239 +FLOAT:float726 +LITERAL_FLOAT:1.0,-0.284496736,0.254829592 ++(0.254829592,*(/(1.0,+(float726,parsertemp170239)),+(-0.284496736,*(t_gp,parsertemp170245)))) +::STMT +FLOAT:parsertemp557354,prob_true +LITERAL_FLOAT:0.6931471805599453 +/(*(prob_true,parsertemp557354),0.6931471805599453) +::STMT +FLOAT:num_records,i +LITERAL_FLOAT:1.0 +*(num_records,-(i,1.0)) +::STMT +FLOAT:num_min,num_max ++(num_min,num_max) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,253.0 +-(n,-(+(i,253.0),1.0)) +::STMT +MATRIX:tmp,Y +LITERAL_FLOAT:0.0 +>(1-*(Y,tmp),0.0) +::STMT +MATRIX:b,X,sb +exp(%*%(X,+(b,sb))) +::STMT +MATRIX:parsertemp436668,X,parsertemp436672 +LITERAL_FLOAT:1.0,2.0 +INT:int254,parsertemp436666 +-(*(rand(int254,parsertemp436666,1.0,1.0),t(rowSums(parsertemp436668))),*(2.0,%*%(X,t(parsertemp436672)))) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,256.0 +-(n,-(+(i,256.0),1.0)) +::STMT +LITERAL_FLOAT:4.0 +INT:int785,int18 +rand(int18,int785,4.0,4.0) +::STMT +MATRIX:parsertemp115858,X,parsertemp115862,parsertemp115860 +FLOAT:parsertemp115863,n +LITERAL_FLOAT:0.0,1.0 +*(/(-(t(parsertemp115858),*(n,parsertemp115860)),-(nrow(X),1.0)),-(1.0,<=(/(parsertemp115862,parsertemp115863),0.0))) +::STMT +MATRIX:obj,objnew,gs +cast.FLOAT(-(-(objnew,obj),gs)) +::STMT +MATRIX:determinants +FLOAT:nFeats +LITERAL_FLOAT:6.283185307179586 +*(^(6.283185307179586,nFeats),determinants) +::STMT +MATRIX:R +LITERAL_FLOAT:1.0,2.0 +INT:parsertemp500303,int480 +%*%(rowSums(^(R,2.0)),rand(int480,parsertemp500303,1.0,1.0)) +::STMT +MATRIX:lambda,g,beta +*(+(g,*(lambda,beta)),+(g,*(lambda,beta))) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44004 +%*%(t(+(s,*(parsertemp44004,d))),+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp31023,parsertemp31025,parsertemp31030,parsertemp31032 +FLOAT:int254,int53,int315,int955 +LITERAL_FLOAT:150.0,100.0 ++(/(/(-(parsertemp31023,parsertemp31025),-(int955,int254)),100.0),/(/(-(parsertemp31030,parsertemp31032),-(int53,int315)),150.0)) +::STMT +MATRIX:parsertemp146940,184_dtemp,outr2 +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(%*%(t(outr2),-(184_dtemp,parsertemp146940)),2.0)) +::STMT +MATRIX:parsertemp40086,addedE,addedX2 +/(t(%*%(t(addedE),addedX2)),t(parsertemp40086)) +::STMT +MATRIX:lambda,parsertemp170067,scale_X,parsertemp170065,p_CG ++(*(cast.FLOAT(lambda),cast.FLOAT(p_CG)),*(cast.FLOAT(diag(scale_X)),cast.FLOAT(%*%(parsertemp170065,parsertemp170067)))) +::STMT +MATRIX:e +LITERAL_FLOAT:4.0 +*(4.0,e) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0002795638803466 +sqrt(*(m2X,1.0002795638803466)) +::STMT +FLOAT:parsertemp170147,parsertemp170145,p_CG,z +LITERAL_FLOAT:-1.0,2.0 +/(+(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170145,parsertemp170147))),sum(^(p_CG,2.0))) +::STMT +MATRIX:P,minD,D,X +%*%(t(/(<=(D,minD),rowSums(P))),X) +::STMT +MATRIX:present_domain_vals_mat,parsertemp27485 +FLOAT:my +LITERAL_FLOAT:2.0 +^(-(%*%(present_domain_vals_mat,parsertemp27485),my),2.0) +::STMT +MATRIX:D +LITERAL_FLOAT:1.0 +/(1.0,+(D,1.0)) +::STMT +MATRIX:Y +FLOAT:bernoulli_No_label +LITERAL_FLOAT:1.0 +-(1.0,==(Y,bernoulli_No_label)) +::STMT +FLOAT:window_size,k,n +LITERAL_FLOAT:2.0 +-(+(-(n,window_size),2.0),k) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(0.5,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y,parsertemp171294 +rowSums(*(Y,%*%(+(parsertemp171294,is_LT_infinite),flip_neg))) +::STMT +MATRIX:lambda,B,S +LITERAL_FLOAT:2.0 +*(lambda,^(+(B,S),2.0)) +::STMT +FLOAT:n_components,cov_param,n_features +LITERAL_FLOAT:1.0 +-(+(+(cov_param,*(n_features,n_components)),n_components),1.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +*(-(0.0,sum(X)),-(0.0,sum(X))) +::STMT +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 +-(*(sample_block_size,num_samples),1.0) +::STMT +MATRIX:R,parsertemp500359 +LITERAL_FLOAT:2.0 +%*%(rowSums(^(R,2.0)),parsertemp500359) +::STMT +MATRIX:intercept,X,beta +FLOAT:int198,int797 +INT:num_records,int979 ++(%*%(X,beta),%*%(rand(num_records,int979,int797,int198),intercept)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +exp(-(0.0,exp(finite_linear_terms))) +::STMT +MATRIX:output_values +FLOAT:log_odds,learning_rate +LITERAL_FLOAT:2.7182818284 +^(2.7182818284,+(log_odds,*(learning_rate,sum(output_values)))) +::STMT +MATRIX:log_prob,X +FLOAT:parsertemp436712 +LITERAL_FLOAT:-0.5 +*(-0.5,+(*(ncol(X),parsertemp436712),log_prob)) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +sum(abs(==(parsertemp174552,0.0))) +::STMT +MATRIX:dl_matrix +FLOAT:cost ++(cast.FLOAT(dl_matrix),cost) +::STMT +MATRIX:C,Xm,parsertemp265701 +%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))) +::STMT +LITERAL_FLOAT:6.0 +sqrt(6.0) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +exp(*(X,-1.0)) +::STMT +MATRIX:parsertemp72202,subspace_idx +LITERAL_FLOAT:1.0 +diag(/(1.0,<(-(subspace_idx,parsertemp72202),1.0))) +::STMT +MATRIX:y_corr +FLOAT:link_power +LITERAL_FLOAT:0.0 +^(+(y_corr,==(y_corr,0.0)),link_power) +::STMT +MATRIX:prec_chol,X,parsertemp436696,bc_matrix,parsertemp436692 +FLOAT:int149 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp436692)),*(2.0,%*%(X,parsertemp436696))),%*%(rowSums(*(X,X)),t(^(prec_chol,int149)))) +::STMT +MATRIX:mean,X,weight,parsertemp437211,parsertemp437629 ++(/(%*%(t(parsertemp437211),-(X,mean)),cast.FLOAT(weight)),diag(parsertemp437629)) +::STMT +MATRIX:parsertemp170158,parsertemp170136 +FLOAT:r_CG,g_reg,278_sq_root_d,z,parsertemp170171,parsertemp170150 +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170171,z),sum(parsertemp170158)),/(-(parsertemp170150,278_sq_root_d),sum(parsertemp170136)))) +::STMT +MATRIX:E,O +/(*(sum(-(O,E)),sum(-(O,E))),sum(E)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +exp(*(linear_terms,2.0)) +::STMT +MATRIX:S,addedX2 +FLOAT:level +rowSums(==(%*%(S,t(addedX2)),level)) +::STMT +MATRIX:parsertemp410080,d_r_rev,parsertemp410079,parsertemp410090 +LITERAL_FLOAT:-1.0 ++(*(cast.FLOAT(%*%(parsertemp410079,parsertemp410080)),-1.0),cast.FLOAT(%*%(t(d_r_rev),parsertemp410090))) +::STMT +FLOAT:i,subvector_size +LITERAL_FLOAT:1.0 ++(*(-(i,1.0),subvector_size),1.0) +::STMT +MATRIX:Y_prob +/(Y_prob,rowSums(Y_prob)) +::STMT +MATRIX:scale_X,p_CG +*(cast.FLOAT(diag(scale_X)),p_CG) +::STMT +MATRIX:prob +FLOAT:threshold +LITERAL_FLOAT:0.0 +==(>(prob,threshold),0.0) +::STMT +MATRIX:288_left,291_d,288_right +LITERAL_FLOAT:0.0,2.0 ++(/(^(sum(288_left),2.0),+(sum(291_d),0.0)),/(^(sum(288_right),2.0),+(sum(291_d),0.0))) +::STMT +LITERAL_FLOAT:1.0E9 +1.0E9 +::STMT +MATRIX:y_corr +FLOAT:link_power +^(y_corr,link_power) +::STMT +MATRIX:X_batch,186_dX,parsertemp146949,parsertemp146957,parsertemp146955 +LITERAL_FLOAT:2.0 +^(%*%(t(X_batch),*(*(parsertemp146957,parsertemp146955),%*%(186_dX,parsertemp146949))),2.0) +::STMT +LITERAL_FLOAT:2.0 +sqrt(2.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.086386842558136 +*(0.086386842558136,W1_rand) +::STMT +MATRIX:S +FLOAT:level +LITERAL_FLOAT:2.0 +==(%*%(S,t(S)),-(level,2.0)) +::STMT +MATRIX:D,parsertemp220844,ZERODIAG,beta +*(*(exp(*(parsertemp220844,beta)),ZERODIAG),D) +::STMT +MATRIX:resp,mean,X +%*%(t(*(-(X,mean),resp)),-(X,mean)) +::STMT +MATRIX:K_inv,scores,Ks +cast.FLOAT(%*%(%*%(t(Ks),K_inv),scores)) +::STMT +MATRIX:parsertemp43993,s,d,alpha_deno ++(s,*(/(sum(parsertemp43993),cast.FLOAT(alpha_deno)),d)) +::STMT +MATRIX:sb +FLOAT:delta +LITERAL_FLOAT:2.0 +-(cast.FLOAT(%*%(t(sb),sb)),^(delta,2.0)) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +exp(/(*(z_alpha_2,se_surv),surv)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005 +*(1.0004995004995005,m2) +::STMT +MATRIX:X +FLOAT:val +<(X,val) +::STMT +MATRIX:parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793 +/(%*%(parsertemp171366,p_one_m_one),3.141592653589793) +::STMT +MATRIX:W +FLOAT:int98,parsertemp97,int798,m4,int820,int239,wt +LITERAL_FLOAT:1.0 +-(*(*(^(wt,int98),+(wt,int798)),m4),*(*(*(int820,parsertemp97),^(wt,int239)),-(sum(W),1.0))) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939 +LITERAL_FLOAT:2.0 +^(colSums(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939))),2.0) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +*(sum(^(p_CG,2.0)),-(^(cast.FLOAT(z),2.0),trust_delta_sq)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,96.0 +*(96.0,-(run_index,1.0)) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +/(abs(-(X,Y)),/(+(abs(X),abs(Y)),2.0)) +::STMT +FLOAT:padh,Hin,Hf +LITERAL_FLOAT:2.0 +-(+(Hin,*(2.0,padh)),Hf) +::STMT +MATRIX:Grad +LITERAL_FLOAT:-1.0,2.0 +^(*(Grad,-1.0),2.0) +::STMT +MATRIX:parsertemp389604,X_batch,2364_2361_Y,W4,parsertemp389601 +FLOAT:int996 +LITERAL_FLOAT:1.0 +%*%(*(-(/(parsertemp389601,parsertemp389604),X_batch),-(1.0,^(2364_2361_Y,int996))),W4) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +*(exp(-(0.0,exp(linear_terms))),exp(linear_terms)) +::STMT +MATRIX:parsertemp436667,precisions,bc_matrix +*(bc_matrix,t(rowSums(*(parsertemp436667,precisions)))) +::STMT +MATRIX:_sbcvar96,_sbcvar95 +LITERAL_FLOAT:-1.0 ++(%*%(_sbcvar95,_sbcvar96),-1.0) +::STMT +FLOAT:D +LITERAL_FLOAT:0.5 +*(0.5,sqrt(D)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),-(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp496901 +FLOAT:std,arch_coef +LITERAL_FLOAT:2.0 +*(arch_coef,^(*(cast.FLOAT(parsertemp496901),std),2.0)) +::STMT +MATRIX:parsertemp183431,X,mu +FLOAT:int754,N +LITERAL_FLOAT:1.0 +-(/(%*%(t(X),X),-(N,1.0)),*(/(N,-(N,int754)),%*%(t(mu),/(parsertemp183431,N)))) +::STMT +MATRIX:ss,X2 +LITERAL_FLOAT:1.0 +-(/(nrow(X2),ss),1.0) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +sum(*(+(r_CG,*(alpha_CG,q_CG)),+(r_CG,*(alpha_CG,q_CG)))) +::STMT +FLOAT:int528 +LITERAL_FLOAT:0.0 +INT:m,int848 +sum(abs(rand(m,int848,0.0,int528))) +::STMT +FLOAT:ytest,yhat +LITERAL_FLOAT:1.0,2.0 +*(1.0,^(/(-(ytest,yhat),1.0),2.0)) +::STMT +MATRIX:z +sqrt(cast.FLOAT(%*%(t(z),z))) +::STMT +MATRIX:X_batch,W_1 +LITERAL_FLOAT:0.0 ++(%*%(X_batch,W_1),0.0) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046,0.5 +-(/(parsertemp169812,2.302585092994046),0.5) +::STMT +MATRIX:finite_linear_terms +FLOAT:int375 +LITERAL_FLOAT:-1.0,2.0 +exp(/(*(^(finite_linear_terms,int375),-1.0),2.0)) +::STMT +MATRIX:ones_ctg +LITERAL_FLOAT:1.0 +-(1.0,diag(ones_ctg)) +::STMT +MATRIX:parsertemp11251 +LITERAL_FLOAT:2.0 +t(^(2.0,parsertemp11251)) +::STMT +MATRIX:means,Y_counts,parsertemp560529 +LITERAL_FLOAT:5.0 +sum(<(*(means,%*%(Y_counts,parsertemp560529)),5.0)) +::STMT +FLOAT:parsertemp83 +-(cast.MATRIX(parsertemp83),parsertemp83) +::STMT +MATRIX:ts +LITERAL_FLOAT:1.0,4.0 ++(-(length(ts),4.0),1.0) +::STMT +MATRIX:parsertemp389604,X_batch,2364_2361_Y,parsertemp389601 +FLOAT:int323 +LITERAL_FLOAT:1.0 +t(*(-(/(parsertemp389601,parsertemp389604),X_batch),-(1.0,^(2364_2361_Y,int323)))) +::STMT +MATRIX:parsertemp397824,W3_rand +FLOAT:int796,int27 +LITERAL_FLOAT:0.5107539184552492 +%*%(*(0.5107539184552492,W3_rand),t(/(-(parsertemp397824,int27),+(parsertemp397824,int796)))) +::STMT +MATRIX:X +FLOAT:x +/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:parsertemp27718,_sbcvar92 +FLOAT:220_W +LITERAL_FLOAT:0.0,1.0E-4 ++(*(==(/(parsertemp27718,220_W),0.0),1.0E-4),/(%*%(rowSums(_sbcvar92),colSums(_sbcvar92)),sum(_sbcvar92))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0,1000.0 +/(sum(==(-(predicted_Y,Y),0.0)),1000.0) +::STMT +FLOAT:x,parsertemp169817 +LITERAL_FLOAT:10000.0 +/(round(*(x,exp(parsertemp169817))),10000.0) +::STMT +FLOAT:e,decay +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,*(decay,e))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(cast.FLOAT(%*%(t(z),z)),trust_delta_sq) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +/(*(m2,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:scale_X,beta +*(cast.FLOAT(diag(scale_X)),cast.FLOAT(beta)) +::STMT +MATRIX:X,Y,K +LITERAL_FLOAT:-1.0 ++(*(*(K,-1.0),-(X,X)),-(Y,Y)) +::STMT +MATRIX:yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +sum(^(-(yhat,mean_y_test),2.0)) +::STMT +MATRIX:parsertemp436669,prec_chol,X,parsertemp436673,bc_matrix +FLOAT:int367 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp436669)),*(2.0,%*%(X,parsertemp436673))),%*%(^(X,2.0),t(^(prec_chol,int367)))) +::STMT +FLOAT:sd_Y,sd_X +abs(-(sqrt(sd_Y),sqrt(sd_X))) +::STMT +MATRIX:parsertemp231464 +FLOAT:feature_frac +t(<=(parsertemp231464,feature_frac)) +::STMT +MATRIX:m_correct +FLOAT:i,in_i_k_min +LITERAL_FLOAT:1.0 +/(rowSums(m_correct),-(+(in_i_k_min,i),1.0)) +::STMT +MATRIX:R,parsertemp503780 +FLOAT:int440,int936 +INT:int175,parsertemp503363 +%*%(t(+(R,diag(parsertemp503780))),+(R,diag(rand(parsertemp503363,int175,int440,int936)))) +::STMT +MATRIX:Q,R +LITERAL_FLOAT:2.0 +*(2.0,%*%(R,t(Q))) +::STMT +MATRIX:C,X +LITERAL_FLOAT:-2.0 +*(-2.0,%*%(X,t(C))) +::STMT +MATRIX:the_exp +FLOAT:int76,int968 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int76,the_exp),1.0E7)),-(1.0,exp(*(the_exp,int968)))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0,1.0E-6 +*(1.0E-6,sum(^(X,2.0))) +::STMT +MATRIX:2701_mask,2702_X +LITERAL_FLOAT:0.0,0.5 +*(>(2702_X,0.0),/(2701_mask,0.5)) +::STMT +MATRIX:svUpBnd,R,svLowBnd +diag(*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd)))) +::STMT +MATRIX:X +rev(X) +::STMT +MATRIX:obj,parsertemp44077 +FLOAT:C,float763,parsertemp44081 +cast.FLOAT(-(obj,+(*(float763,parsertemp44077),*(C,parsertemp44081)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,1.0E7 +-(1.0,==(+(1.0E7,exp(finite_linear_terms)),1.0E7)) +::STMT +MATRIX:ot2 +FLOAT:int160 +LITERAL_FLOAT:25.0,100.0 +/(*(sum(>(ot2,int160)),100.0),25.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.5 +^(0.5,link_power) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +t(*(mu,^(prec_chol,2.0))) +::STMT +LITERAL_FLOAT:192.0 +INT:int522,int415 +rand(int522,int415,192.0,192.0) +::STMT +LITERAL_FLOAT:10000.0 +10000.0 +::STMT +MATRIX:U,V_sum +rowSums(rowSums(/(*(U,U),sum(V_sum)))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1920.0,1.0 +-(1.0,/(1920.0,num_records)) +::STMT +MATRIX:Q3,IQR +LITERAL_FLOAT:1.5 ++(Q3,*(1.5,IQR)) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +*(-(nrow(F),1.0),-(ncol(F),1.0)) +::STMT +MATRIX:ytest,yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +/(sum(^(-(yhat,mean_y_test),2.0)),sum(^(-(ytest,mean_y_test),2.0))) +::STMT +FLOAT:approx_sample_size,num_records +LITERAL_FLOAT:1.0 +-(1.0,/(approx_sample_size,num_records)) +::STMT +MATRIX:valueCount,parsertemp552531,resp,Y +rowSums(*(==(+(resp,parsertemp552531),Y),valueCount)) +::STMT +MATRIX:pearson_residual_sq +FLOAT:num_records +LITERAL_FLOAT:1.0 +/(sum(pearson_residual_sq),-(num_records,1.0)) +::STMT +MATRIX:z,parsertemp285752 +FLOAT:2234_sq_root_d,parsertemp285742,parsertemp285763,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285752))),*(parsertemp285763,/(-(parsertemp285742,2234_sq_root_d),pp_CG))) +::STMT +MATRIX:parsertemp539203,T,event +FLOAT:int631 +LITERAL_FLOAT:2.0,0.6666666666666666 +/(^(/(*(parsertemp539203,int631),2.0),0.6666666666666666),/(-(max(T),min(T)),sum(event))) +::STMT +MATRIX:prec_chol,parsertemp438810,X,bc_matrix,parsertemp438806 +FLOAT:int230,int476 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp438806)),*(2.0,%*%(X,parsertemp438810))),%*%(rowSums(^(X,int230)),t(^(prec_chol,int476)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.231641888 +*(abs(finite_linear_terms),0.231641888) +::STMT +MATRIX:r_LS +FLOAT:alpha_LS,norm_r2_LS,p_LS +LITERAL_FLOAT:2.0 +/(^(+(cast.FLOAT(r_LS),*(alpha_LS,p_LS)),2.0),norm_r2_LS) +::STMT +MATRIX:Y_counts,parsertemp560507,Y,parsertemp560512 +-(sum(rowSums(*(Y,parsertemp560507))),sum(*(Y_counts,rowSums(parsertemp560512)))) +::STMT +MATRIX:maskd1,out1,186_dX,parsertemp146949 +FLOAT:p +LITERAL_FLOAT:0.0 +colSums(*(>(out1,0.0),*(/(maskd1,p),%*%(186_dX,parsertemp146949)))) +::STMT +LITERAL_FLOAT:6.0,2003.0 +*(6.0,2003.0) +::STMT +MATRIX:col_nonzeros,U,parsertemp382849,V,parsertemp382852 +FLOAT:reg ++(t(%*%(t(U),*(parsertemp382849,parsertemp382852))),*(*(reg,V),col_nonzeros)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 ++(rowSums(classFeatureCounts),*(500.0,1.0)) +::STMT +MATRIX:X,parsertemp471907 +LITERAL_FLOAT:1.0E-14 +sum(>(abs(-(X,parsertemp471907)),1.0E-14)) +::STMT +FLOAT:42_m2Y,42_m2X +LITERAL_FLOAT:1.001001001001001 +*(sqrt(*(42_m2X,1.001001001001001)),sqrt(*(42_m2Y,1.001001001001001))) +::STMT +MATRIX:posSampleVariances,negSampleMeans,posSampleMeans,negSampleVariances +FLOAT:int673,int18 +/(-(posSampleMeans,negSampleMeans),sqrt(+(/(posSampleVariances,int673),/(negSampleVariances,int18)))) +::STMT +MATRIX:parsertemp31908,e +FLOAT:l +/(t(%*%(t(e),==(parsertemp31908,l))),t(colSums(==(parsertemp31908,l)))) +::STMT +MATRIX:scale_X,p_CG,shift_X ++(*(cast.FLOAT(diag(scale_X)),p_CG),*(cast.FLOAT(shift_X),p_CG)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int970,int726 +LITERAL_FLOAT:1.0,100.0 +/(-(colSums(^(posSamples,int726)),*(100.0,^(posSampleMeans,int970))),-(100.0,1.0)) +::STMT +MATRIX:X,parsertemp129018 +LITERAL_FLOAT:1.0 +*(max(parsertemp129018),-(ncol(X),1.0)) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg +LITERAL_FLOAT:1.0 +-(/(/(se,ss),130_eAvg),1.0) +::STMT +MATRIX:X,parsertemp222929 +*(cast.FLOAT(parsertemp222929),-(X,X)) +::STMT +MATRIX:yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +^(-(yhat,mean_y_test),2.0) +::STMT +LITERAL_FLOAT:1.0E20 +1.0E20 +::STMT +MATRIX:Yhat_prime,E,H3 +%*%(t(*(E,Yhat_prime)),H3) +::STMT +MATRIX:ssX_p,X +%*%(t(X),%*%(X,ssX_p)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +%*%(t(+(0.0,*(lambda,beta))),+(0.0,*(lambda,beta))) +::STMT +MATRIX:Xm,parsertemp265718 +abs(/(-(sum(parsertemp265718),sum(Xm)),sum(Xm))) +::STMT +MATRIX:feature +FLOAT:n_bins +/(-(max(feature),min(feature)),n_bins) +::STMT +MATRIX:2700_X,2700_W,parsertemp459178,2699_dtemp +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(%*%(t(2700_X),-(2699_dtemp,parsertemp459178)),*(5.0E-4,2700_W))) +::STMT +MATRIX:foffb,foffe +LITERAL_FLOAT:1.0 +-(cast.FLOAT(foffe),+(cast.FLOAT(foffb),1.0)) +::STMT +MATRIX:p_CG +FLOAT:int351,trust_delta_sq,z +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(sum(^(p_CG,int351)),-(*(z,z),trust_delta_sq))) +::STMT +MATRIX:b,H +%*%(t(b),-(+(H,t(H)),diag(diag(H)))) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +*(cast.FLOAT(%*%(t(p_CG),p_CG)),-(cast.FLOAT(%*%(z,z)),trust_delta_sq)) +::STMT +FLOAT:eps +LITERAL_FLOAT:0.5 ++(0.5,eps) +::STMT +MATRIX:z +FLOAT:pp,trust_delta_sq +*(pp,-(sum(*(z,z)),trust_delta_sq)) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j,V1 +FLOAT:I_i1i2 +*(V1,-(I_i1i2,/(n_risk_i2j,n_risk_stratum))) +::STMT +MATRIX:col,missing_indicator_mat +FLOAT:global_mean ++(col,*(missing_indicator_mat,global_mean)) +::STMT +FLOAT:parsertemp539092,parsertemp539091,num_groups +LITERAL_FLOAT:1.0,2.0 +-(+(+(*(parsertemp539091,parsertemp539092),1.0),num_groups),2.0) +::STMT +MATRIX:I,parsertemp472299 +LITERAL_FLOAT:0.0 +*(==(!=(*(parsertemp472299,I),0.0),0.0),I) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +*(rowSums(Y),exp(-(0.0,exp(linear_terms)))) +::STMT +MATRIX:W +FLOAT:m4 +LITERAL_FLOAT:1.0,2.0 +*(*(^(sum(W),2.0),+(sum(W),1.0)),m4) +::STMT +MATRIX:X_batch,dout1 +LITERAL_FLOAT:2.0 +^(%*%(t(X_batch),dout1),2.0) +::STMT +MATRIX:m_err +rowSums(colSums(m_err)) +::STMT +FLOAT:int453,se_g1,int711,int305,int506,parsertemp113,wt +sqrt(/(*(*(int711,parsertemp113),^(se_g1,int453)),*(+(wt,int506),-(wt,int305)))) +::STMT +MATRIX:X +FLOAT:N +LITERAL_FLOAT:1.0 +/(%*%(t(X),X),-(N,1.0)) +::STMT +FLOAT:277_sq_root_d,parsertemp170093,pp_CG,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(+(parsertemp170093,277_sq_root_d),pp_CG)),pq_CG) +::STMT +FLOAT:parsertemp191170,Wf +LITERAL_FLOAT:0.0,1.0,2.0 +INT:parsertemp191169,F +*(rand(F,parsertemp191169,0.0,1.0),sqrt(/(2.0,*(parsertemp191170,Wf)))) +::STMT +MATRIX:b,W,X ++(%*%(X,W),b) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 ++(<=(y_corr,0.0),>=(y_corr,1.0)) +::STMT +MATRIX:parsertemp231461 +LITERAL_FLOAT:0.1 +<=(parsertemp231461,0.1) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +%*%(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum))),t(Xm)) +::STMT +MATRIX:X_batch,parsertemp146957,187_dX +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(%*%(t(X_batch),*(parsertemp146957,187_dX)),2.0)) +::STMT +MATRIX:parsertemp437305,_funvar2125,parsertemp437277,parsertemp437272 +exp(-(+(_funvar2125,parsertemp437305),+(parsertemp437272,parsertemp437277))) +::STMT +MATRIX:W +FLOAT:m3 +LITERAL_FLOAT:2.0 +*(^(sum(round(W)),2.0),m3) +::STMT +MATRIX:pearson_residual_sq +FLOAT:num_features,num_records +/(sum(pearson_residual_sq),-(num_records,num_features)) +::STMT +MATRIX:parsertemp12846,F,parsertemp12848 +FLOAT:q,int265,W +LITERAL_FLOAT:1.0 +/(sum(/(^(parsertemp12848,int265),/(parsertemp12846,W))),*(sum(F),-(q,1.0))) +::STMT +FLOAT:o_init,N +LITERAL_FLOAT:-2.0 +/(*(-2.0,o_init),N) +::STMT +LITERAL_FLOAT:1.0 ++(+(+(1.0,1.0),1.0),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),1.0),exp(linear_terms)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(^(exp(linear_terms),-1.0),exp(linear_terms)) +::STMT +MATRIX:X +FLOAT:2917_split +-($1:nrow(X),round(*($1,2917_split))) +::STMT +LITERAL_FLOAT:9999.0 +9999.0 +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(rowSums(^(mu,2.0)),^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.01 +0.01 +::STMT +MATRIX:cm +FLOAT:n +==(t(cm),n) +::STMT +MATRIX:parsertemp389341,X,parsertemp389344 +LITERAL_FLOAT:1.0 +-(/(-(exp(parsertemp389341),1.0),+(exp(parsertemp389344),1.0)),X) +::STMT +LITERAL_FLOAT:100.0 ++(100.0,100.0) +::STMT +MATRIX:b_cumulant,is_natural_parameter_log_zero,parsertemp560392,Y,natural_parameters +FLOAT:int562 +LITERAL_FLOAT:1.0 +-(-(*(Y,natural_parameters),b_cumulant),/(*(>(Y,int562),is_natural_parameter_log_zero),-(1.0,*(parsertemp560392,is_natural_parameter_log_zero)))) +::STMT +MATRIX:X,Y +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),Y) +::STMT +MATRIX:g_new,s,g_old +*(/(sum(*(g_new,g_new)),sum(*(g_old,g_old))),s) +::STMT +MATRIX:parsertemp265720,parsertemp265715,parsertemp265722 +FLOAT:m,n +LITERAL_FLOAT:2.0 +/(-(+(sum(parsertemp265722),trace(parsertemp265715)),*(2.0,sum(parsertemp265720))),*(n,m)) +::STMT +MATRIX:I,y2 +/(%*%(I,y2),rowSums(I)) +::STMT +MATRIX:A,lambda ++(A,diag(lambda)) +::STMT +MATRIX:X +abs(-(X,round(X))) +::STMT +MATRIX:C,Xm,parsertemp265702 +sum(-(%*%(%*%(Xm,parsertemp265702),t(C)),Xm)) +::STMT +MATRIX:objvals +LITERAL_FLOAT:10.0,1.5,-8.0 +*(*(1.5,^(10.0,-8.0)),cast.FLOAT(objvals)) +::STMT +FLOAT:parsertemp496689,parsertemp496690,parsertemp496694,int69,parsertemp496686,n +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,*(2.0,n)),+(parsertemp496694,/(^(parsertemp496686,int69),+(parsertemp496689,parsertemp496690)))) +::STMT +MATRIX:p,lambda,X ++(%*%(t(X),%*%(X,p)),*(lambda,p)) +::STMT +MATRIX:ot2 +FLOAT:int689 +LITERAL_FLOAT:200.0,100.0 +/(*(sum(>(ot2,int689)),100.0),200.0) +::STMT +LITERAL_FLOAT:1.0,8.0 +-(8.0,1.0) +::STMT +MATRIX:parsertemp171083 +FLOAT:float666 +LITERAL_FLOAT:0.001308,0.189269 ++(0.189269,*(sqrt(*(float666,parsertemp171083)),0.001308)) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +*(+(r_CG,*(alpha_CG,q_CG)),+(r_CG,*(alpha_CG,q_CG))) +::STMT +MATRIX:parsertemp389329,parsertemp389332,W4 +FLOAT:int14,int822 +%*%(W4,t(/(-(parsertemp389329,int14),+(parsertemp389332,int822)))) +::STMT +MATRIX:linear_terms +FLOAT:int6 +LITERAL_FLOAT:1.0 +/(1.0,-(exp(*(linear_terms,int6)),1.0)) +::STMT +MATRIX:p_LS +FLOAT:alpha_LS,r_LS,norm_r2_LS +LITERAL_FLOAT:2.0 +*(/(^(+(r_LS,alpha_LS),2.0),norm_r2_LS),cast.FLOAT(p_LS)) +::STMT +MATRIX:simplex,parsertemp503570 +LITERAL_FLOAT:2.0 +-(*(2.0,/(-(parsertemp503570,simplex),nrow(simplex))),simplex) +::STMT +MATRIX:X_cluster,_funvar62 +|(X_cluster,_funvar62) +::STMT +MATRIX:Y_counts,parsertemp560517,ent1_vec +FLOAT:int324 +sum(*(Y_counts,-(rowSums(parsertemp560517),^(ent1_vec,int324)))) +::STMT +MATRIX:parsertemp410246,parsertemp410249 +LITERAL_FLOAT:0.6666666666666666 +-(max(^(/(parsertemp410246,parsertemp410249),0.6666666666666666)),min(^(/(parsertemp410246,parsertemp410249),0.6666666666666666))) +::STMT +MATRIX:means,Y_counts,Y +/(colSums(-(Y,means)),sum(Y_counts)) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:parsertemp171245,Y +LITERAL_FLOAT:1.0 +*(Y,/(1.0,-(exp(parsertemp171245),1.0))) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +%*%(W,/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:int225,int394 +LITERAL_FLOAT:2.0,3352500.0,990000.0 ++(/(^(/(parsertemp31026,int394),2.0),990000.0),/(^(/(parsertemp31033,int225),2.0),3352500.0)) +::STMT +MATRIX:CFreqs1,parsertemp27492,present_domain_vals_mat +FLOAT:int634 +LITERAL_FLOAT:1.0 +/(sum(*(%*%(present_domain_vals_mat,CFreqs1),^(parsertemp27492,int634))),-(nrow(present_domain_vals_mat),1.0)) +::STMT +MATRIX:log_prob,log_det_chol +FLOAT:parsertemp443052,float150 +LITERAL_FLOAT:-0.5 ++(*(-0.5,+(*(parsertemp443052,float150),log_prob)),cast.FLOAT(log_det_chol)) +::STMT +FLOAT:s,i2,n +-(n,*(s,i2)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:100.0 +/(100.0,num_records) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:int207,int915 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 ++(/(^(/(parsertemp31108,int915),2.0),7.996E9),/(^(/(parsertemp31115,int207),2.0),3.37275E9)) +::STMT +MATRIX:uniqueValues,X +cast.FLOAT(==(X,uniqueValues)) +::STMT +MATRIX:resp,X +LITERAL_FLOAT:2.22E-16 +/(%*%(t(resp),X),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:id +diag(diag(==(id,cast.FLOAT(id)))) +::STMT +MATRIX:p_LS,X +*(cast.FLOAT(%*%(t(X),X)),p_LS) +::STMT +FLOAT:iter +LITERAL_FLOAT:5.0 +/(iter,5.0) +::STMT +FLOAT:b,rad +LITERAL_FLOAT:-1.0 +*(-(b,rad),-1.0) +::STMT +LITERAL_FLOAT:1.0,4.0 +-(4.0,1.0) +::STMT +MATRIX:X +FLOAT:val +==(X,val) +::STMT +MATRIX:W,X,H +LITERAL_FLOAT:1.0E-8 +%*%(t(W),/(X,+(%*%(W,H),1.0E-8))) +::STMT +FLOAT:sum_y_test,sum_sq_y_test,n +LITERAL_FLOAT:2.0 +-(sum_sq_y_test,*(n,^(/(sum_y_test,n),2.0))) +::STMT +FLOAT:window_size +LITERAL_FLOAT:4.0 +/(window_size,4.0) +::STMT +MATRIX:2696_mask,outr3 +LITERAL_FLOAT:0.5 +/(*(outr3,2696_mask),0.5) +::STMT +MATRIX:p,q,lambda,X +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:scale_X,X,beta +*(*(cast.FLOAT(diag(scale_X)),cast.FLOAT(beta)),X) +::STMT +MATRIX:Train,2342_m_colmax,2342_m_colmin +LITERAL_FLOAT:1.0,2.0 +-(/(*(2.0,-(Train,2342_m_colmin)),-(2342_m_colmax,2342_m_colmin)),1.0) +::STMT +MATRIX:p,A +sum(*(p,%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:r,d,X,Hd,parsertemp44001 +FLOAT:int656 +*(/(sum(^(r,int656)),cast.FLOAT(%*%(parsertemp44001,Hd))),%*%(X,d)) +::STMT +MATRIX:parsertemp393591,W4 +LITERAL_FLOAT:2.0 +exp(*(2.0,t(%*%(W4,parsertemp393591)))) +::STMT +MATRIX:2701_mask,2700_W,2726_dpred,parsertemp459177,2699_probs +LITERAL_FLOAT:0.5 +*(/(2701_mask,0.5),%*%(-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177)),t(2700_W))) +::STMT +MATRIX:X,parsertemp386474 +LITERAL_FLOAT:-2.0 ++(+(*(-2.0,%*%(X,parsertemp386474)),X),t(X)) +::STMT +FLOAT:strideh,Hin,Hf +LITERAL_FLOAT:1.0 ++(/(-(Hin,Hf),strideh),1.0) +::STMT +MATRIX:P,D,Z,ZERODIAG +FLOAT:int934 +LITERAL_FLOAT:1.0 +*(-(P,/(*(Z,ZERODIAG),sum(Z))),*(/(1.0,+(D,int934)),ZERODIAG)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.16823164622761327 +*(0.16823164622761327,W2_rand) +::STMT +MATRIX:parsertemp265709,Xm,tmp,Z,parsertemp265702 +%*%(t(/(%*%(parsertemp265709,Z),sum(tmp))),/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(tmp))) +::STMT +MATRIX:curr_prediction +FLOAT:int567,282_lambda ++(sum(*(curr_prediction,-(int567,curr_prediction))),282_lambda) +::STMT +MATRIX:X2 +max(colSums(X2)) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:parsertemp31116,parsertemp31109 +LITERAL_FLOAT:2.0,1500.0,2000.0 +^(+(/(/(parsertemp31108,parsertemp31109),2000.0),/(/(parsertemp31115,parsertemp31116),1500.0)),2.0) +::STMT +MATRIX:parsertemp106 +LITERAL_FLOAT:10.0 +*(10.0,parsertemp106) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +sqrt(/(*(m2,sum(W)),-(sum(W),1.0))) +::STMT +LITERAL_FLOAT:128.0 +INT:int502,int800 +rand(int800,int502,128.0,128.0) +::STMT +MATRIX:parsertemp386440,parsertemp386441 +LITERAL_FLOAT:1.0,5.0 +>=(+(rowSums(*(parsertemp386440,parsertemp386441)),1.0),5.0) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0 +*(2.0,/(-(rowSums(simplex),simplex),nrow(simplex))) +::STMT +MATRIX:parsertemp169867 +FLOAT:pp,zz,trust_delta_sq +sqrt(-(*(sum(parsertemp169867),sum(parsertemp169867)),*(pp,-(zz,trust_delta_sq)))) +::STMT +MATRIX:X,permut +LITERAL_FLOAT:2.0 +colSums(^(%*%(permut,X),2.0)) +::STMT +MATRIX:2697_b,parsertemp459149,2697_W,outd3 +exp(-(+(%*%(outd3,2697_W),2697_b),parsertemp459149)) +::STMT +MATRIX:b,scale_X,shift_X,X,y ++(%*%(diag(scale_X),%*%(t(X),y)),*(cast.FLOAT(b),shift_X)) +::STMT +MATRIX:parsertemp395001,W4_rand +FLOAT:int764,int842 +LITERAL_FLOAT:0.08692913816996169 +%*%(*(0.08692913816996169,W4_rand),t(/(-(parsertemp395001,int842),+(parsertemp395001,int764)))) +::STMT +MATRIX:S,addedE,parsertemp31676 +FLOAT:level +rowSums(*(==(%*%(S,parsertemp31676),level),t(addedE))) +::STMT +MATRIX:parsertemp421322 +LITERAL_FLOAT:1.0,11.0 +*(11.0,-(max(round(parsertemp421322)),1.0)) +::STMT +MATRIX:dW,parsertemp459256 +LITERAL_FLOAT:5.0E-4 ++(dW,*(5.0E-4,parsertemp459256)) +::STMT +MATRIX:R,HS +FLOAT:alpha +LITERAL_FLOAT:2.0 +^(-(R,*(alpha,HS)),2.0) +::STMT +LITERAL_FLOAT:1.0,100.0 ++(100.0,1.0) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40226,parsertemp40220,parsertemp40231 +FLOAT:level +/(-(+(R,rowSums(parsertemp40226)),rowSums(*(parsertemp40220,parsertemp40231))),-(+(R,rowSums(parsertemp40216)),rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:V +-(max(V),min(V)) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +*(rowSums(Y),>=(linear_terms,0.0)) +::STMT +MATRIX:parsertemp285809,p_CG,z +FLOAT:parsertemp285799,parsertemp285820,2235_sq_root_d +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285809))),*(parsertemp285820,/(-(parsertemp285799,2235_sq_root_d),cast.FLOAT(p_CG)))) +::STMT +MATRIX:Y +==(Y,min(Y)) +::STMT +FLOAT:i,num_runs,num_centroids +LITERAL_FLOAT:1.0 ++(*(num_centroids,-(num_runs,1.0)),i) +::STMT +MATRIX:E,X +LITERAL_FLOAT:0.0 +-(0.0,t(colSums(*(X,E)))) +::STMT +MATRIX:p,e,u +LITERAL_FLOAT:0.15000000000000002 +*(0.15000000000000002,%*%(%*%(e,u),p)) +::STMT +FLOAT:iter +LITERAL_FLOAT:3.0 +/(iter,3.0) +::STMT +MATRIX:t,parsertemp32854,parsertemp32848,Y,parsertemp32857,parsertemp32858 +cast.FLOAT(+(+(*(parsertemp32848,Y),*(t,Y)),*(*(t,parsertemp32854),+(parsertemp32857,parsertemp32858)))) +::STMT +MATRIX:parsertemp149335,LT,Y +LITERAL_FLOAT:-1.0 ++(*(sum(*(Y,LT)),-1.0),sum(parsertemp149335)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +LITERAL_FLOAT:0.5 +round(-(/(-(col,min_val),bin_width),0.5)) +::STMT +FLOAT:lambda +LITERAL_FLOAT:2.0 +/(lambda,2.0) +::STMT +MATRIX:diff +LITERAL_FLOAT:2.0 +sqrt(rowSums(^(diff,2.0))) +::STMT +MATRIX:F,parsertemp12916,parsertemp12915 +FLOAT:int496,int64,meanX +LITERAL_FLOAT:1.0 +*(/(F,-(sum(F),1.0)),-(+(-(parsertemp12915,parsertemp12916),/(int496,int64)),meanX)) +::STMT +MATRIX:W,X,H +FLOAT:eps +/(X,+(%*%(W,H),eps)) +::STMT +MATRIX:lambda,parsertemp149338,parsertemp149335,parsertemp149331 +LITERAL_FLOAT:-1.0,0.5 ++(+(*(sum(parsertemp149331),-1.0),sum(parsertemp149335)),*(0.5,sum(*(lambda,parsertemp149338)))) +::STMT +MATRIX:R,3_ss,dsep +FLOAT:3_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),3_ss),3_eAvg),1.0) +::STMT +MATRIX:S,V,W +%*%(*(W,%*%(S,t(V))),V) +::STMT +MATRIX:parsertemp389339 +LITERAL_FLOAT:1.0,2.0 +-(exp(*(2.0,t(parsertemp389339))),1.0) +::STMT +LITERAL_FLOAT:1.0,1500.0 +-(1500.0,1.0) +::STMT +LITERAL_FLOAT:-1.0,1.0 +INT:int633,n ++(diag(rand(n,int633,-1.0,-1.0)),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,0.5 +-(>=(linear_terms,0.0),0.5) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,2000.0 +*(2000.0,^(posSampleMeans,2.0)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,96.0 ++(*(96.0,-(run_index,1.0)),1.0) +::STMT +MATRIX:dout,mask +FLOAT:p +*(/(mask,p),dout) +::STMT +MATRIX:parsertemp13725,parsertemp13720,45_CVars,45_CFreqs +LITERAL_FLOAT:1.0,1000.0 +/(/(sum(*(45_CFreqs,parsertemp13720)),-(nrow(45_CFreqs),1.0)),/(sum(*(parsertemp13725,45_CVars)),-(1000.0,nrow(45_CFreqs)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +>=(finite_linear_terms,0.0) +::STMT +MATRIX:parsertemp170246,parsertemp170240,parsertemp170238 +FLOAT:float336,float335,float235 +LITERAL_FLOAT:1.0,0.254829592 +*(/(1.0,+(1.0,*(parsertemp170238,float235))),+(0.254829592,*(/(float336,parsertemp170240),+(float335,parsertemp170246)))) +::STMT +MATRIX:m_iter_err_sum,parsertemp379567,m_err +FLOAT:i_process_item +LITERAL_FLOAT:2.0 +-(*(^(/(parsertemp379567,i_process_item),2.0),i_process_item),*(*(2.0,/(parsertemp379567,i_process_item)),+(colSums(m_err),m_iter_err_sum))) +::STMT +MATRIX:parsertemp539204 +FLOAT:float280,float688,int423,float881,float839,int969 +-(max(^(/(parsertemp539204,float688),/(int423,float839))),min(^(/(parsertemp539204,float881),/(int969,float280)))) +::STMT +MATRIX:parsertemp171367,is_LT_infinite +FLOAT:float643 +LITERAL_FLOAT:1.0,0.5 ++(*(+(0.5,/(parsertemp171367,float643)),-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:w,yt,Xt +LITERAL_FLOAT:0.0 +>(*(yt,%*%(Xt,w)),0.0) +::STMT +MATRIX:A,parsertemp12899,CVars,CFreqs,parsertemp12904 +LITERAL_FLOAT:1.0 +/(/(sum(*(CFreqs,parsertemp12899)),-(nrow(CFreqs),1.0)),/(sum(*(parsertemp12904,CVars)),-(nrow(A),nrow(CFreqs)))) +::STMT +MATRIX:parsertemp456742,r,y +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(%*%(t(r),%*%(parsertemp456742,y)))) +::STMT +MATRIX:W,H +FLOAT:eps ++(%*%(%*%(t(W),W),H),eps) +::STMT +FLOAT:F1 +LITERAL_FLOAT:2.0 +*(*(*(F1,2.0),2.0),2.0) +::STMT +LITERAL_FLOAT:10000.0,0.8 +*(10000.0,0.8) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0 +-(0.0,grad) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int50 +LITERAL_FLOAT:1.0,2.0,1500.0 +^(/(-(colSums(parsertemp31111),*(int50,parsertemp31113)),-(1500.0,1.0)),2.0) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +>=(X,4.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 +-(+(i,11.0),1.0) +::STMT +MATRIX:xs +LITERAL_FLOAT:4.5 +>=(xs,4.5) +::STMT +MATRIX:elt,ones_ctg +LITERAL_FLOAT:1.0 +%*%(/(elt,%*%(rowSums(elt),t(ones_ctg))),-(1.0,diag(ones_ctg))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 +*(3.0,-(i,1.0)) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +sum(*(*(grad,-1.0),*(grad,-1.0))) +::STMT +MATRIX:m_err_for_order,m_active_flag +LITERAL_FLOAT:0.0 +*(m_err_for_order,t(==(m_active_flag,0.0))) +::STMT +LITERAL_FLOAT:3.141592653589793 +3.141592653589793 +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int73 +LITERAL_FLOAT:1.0,1500.0 +/(/(-(colSums(parsertemp31111),*(int73,parsertemp31113)),-(1500.0,1.0)),1500.0) +::STMT +MATRIX:R,B,parsertemp503364 +LITERAL_FLOAT:0.0 +-(0.0,%*%(t(+(R,parsertemp503364)),B)) +::STMT +MATRIX:ss,map +LITERAL_FLOAT:1.0 +*(map,/(1.0,t(ss))) +::STMT +MATRIX:w_X,z_LS,X +*(/(nrow(X),*(cast.FLOAT(w_X),cast.FLOAT(z_LS))),z_LS) +::STMT +MATRIX:parsertemp220853,W,sum_Pi,beta +LITERAL_FLOAT:3.4011973816621555 +-(+(parsertemp220853,*(beta,/(W,sum_Pi))),3.4011973816621555) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +>=(X,1.0) +::STMT +LITERAL_FLOAT:6.283185307179586 +6.283185307179586 +::STMT +LITERAL_FLOAT:0.15915494309189535 +0.15915494309189535 +::STMT +FLOAT:arch_coef,var_coef +LITERAL_FLOAT:1.0 +-(-(1.0,arch_coef),var_coef) +::STMT +MATRIX:log_l_part_saturated,log_l_part +LITERAL_FLOAT:2.0 +-(*(2.0,sum(log_l_part_saturated)),*(2.0,sum(log_l_part))) +::STMT +MATRIX:X +FLOAT:x +-(x,cast.FLOAT(X)) +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,betamax,beta +FLOAT:INF,int45 +LITERAL_FLOAT:2.0 +/(*(*(>=(Hdiff,int45),!=(betamax,INF)),+(beta,+(parsertemp220863,parsertemp220864))),2.0) +::STMT +MATRIX:parsertemp222331 +FLOAT:sample_block_size +LITERAL_FLOAT:0.5 +round(+(0.5,/(parsertemp222331,sample_block_size))) +::STMT +MATRIX:parsertemp75086 +LITERAL_FLOAT:1.0,32.0 ++(*(parsertemp75086,32.0),1.0) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +LITERAL_FLOAT:2.0 +^(*(cast.FLOAT(parsertemp496901),std),2.0) +::STMT +LITERAL_FLOAT:512.0,0.8 +*(512.0,0.8) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0,20.0 +-(/(20.0,ss),1.0) +::STMT +MATRIX:R +LITERAL_FLOAT:32.0 +>=(R,32.0) +::STMT +LITERAL_FLOAT:6.144102863722254 +6.144102863722254 +::STMT +MATRIX:y_prob,parsertemp560892,linear_terms,elt +FLOAT:int566,int338,int507 +LITERAL_FLOAT:1.0 ++(*(-(1.0,==(parsertemp560892,int566)),-(1.0,y_prob)),*(*(==(parsertemp560892,int338),exp(linear_terms)),-(1.0,/(elt,int507)))) +::STMT +FLOAT:float982,parsertemp169812 +LITERAL_FLOAT:4.0,0.5 +-(4.0,round(-(/(parsertemp169812,float982),0.5))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:W,H,X,parsertemp411102 +FLOAT:eps +*(H,/(%*%(t(W),X),+(%*%(parsertemp411102,H),eps))) +::STMT +MATRIX:parsertemp170101 +FLOAT:r_CG,g_reg,z,277_sq_root_d,parsertemp170108,parsertemp170093,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170108,z),sum(parsertemp170101)),/(+(parsertemp170093,277_sq_root_d),pp_CG))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 +-(+(i,12.0),1.0) +::STMT +FLOAT:max_iter +LITERAL_FLOAT:100.0 +/(max_iter,100.0) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 +-(*(sample_block_size,3.0),1.0) +::STMT +MATRIX:p,A +sum(*(p,%*%(t(A),%*%(A,p)))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,256.0 ++(-(256.0,idx),1.0) +::STMT +MATRIX:Xm,parsertemp265717,Z +LITERAL_FLOAT:2.0 +*(2.0,sum(%*%(%*%(Z,parsertemp265717),t(Xm)))) +::STMT +MATRIX:X +FLOAT:x +-(nrow(X),sum(>=(X,x))) +::STMT +MATRIX:W +sqrt(sum(round(W))) +::STMT +MATRIX:linear_terms,Y +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),-(1.0,var_power)),-(Y,exp(linear_terms))) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2,int420 +exp(/(*(*(z_alpha_2,int420),se_surv),surv)) +::STMT +MATRIX:WM +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(sum(WM),-(sum(WM),1.0))) +::STMT +MATRIX:X2,85_s +LITERAL_FLOAT:1.0 +*(/(1.0,85_s),nrow(X2)) +::STMT +MATRIX:X +FLOAT:eps +*(eps,nrow(X)) +::STMT +MATRIX:W +FLOAT:int246,parsertemp65,int96,parsertemp66,wt +LITERAL_FLOAT:3.0,4.0 +*(*(*(-(wt,int96),-(wt,int246)),-(sum(W),3.0)),^(sqrt(/(parsertemp65,parsertemp66)),4.0)) +::STMT +LITERAL_FLOAT:2.0,100.0 +^(100.0,2.0) +::STMT +FLOAT:parsertemp65,parsertemp66,mu +LITERAL_FLOAT:5.0 +-(mu,*(5.0,sqrt(/(parsertemp65,parsertemp66)))) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +/(F,-(sum(F),1.0)) +::STMT +MATRIX:mat_chol +/(nrow(mat_chol),ncol(mat_chol)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170148,parsertemp170164,q_CG,z,int13,pq_CG,int470 +*(+(+(*(parsertemp170164,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(+(*(z,int470),sqrt(parsertemp170148)),sum(^(p_CG,int13)))) +::STMT +FLOAT:num_records,i +LITERAL_FLOAT:1.0 ++(*(num_records,-(i,1.0)),1.0) +::STMT +MATRIX:R +FLOAT:int595,int353 +INT:parsertemp503361,int790 ++(R,diag(rand(parsertemp503361,int790,int595,int353))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(!=(X,0.0)) +::STMT +FLOAT:offset_x +round(offset_x) +::STMT +MATRIX:X,tS +FLOAT:l +colSums(==(%*%(X,tS),l)) +::STMT +FLOAT:C,Hf,Wf +*(*(C,Hf),Wf) +::STMT +MATRIX:f,parsertemp472177,parsertemp472179 +-(%*%(f,parsertemp472177),t(parsertemp472179)) +::STMT +MATRIX:obj,objnew +-(cast.FLOAT(objnew),cast.FLOAT(obj)) +::STMT +MATRIX:lambda,g,beta +t(+(g,*(lambda,beta))) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:int548 +/(sum(*(-(CFreqs,int548),CVars)),-(sum(WM),nrow(CFreqs))) +::STMT +MATRIX:X,W1,b1 ++(%*%(W1,t(X)),b1) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +^(exp(linear_terms),-(1.0,var_power)) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,9.0 +-(9.0,+(current_hash_value,1.0)) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +*(pp_CG,-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,6.0 +*(*(6.0,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:S +FLOAT:delta2 +LITERAL_FLOAT:2.0 +-(delta2,sum(^(S,2.0))) +::STMT +MATRIX:X +FLOAT:parsertemp72162,M +LITERAL_FLOAT:1.0 +-(*(+(parsertemp72162,1.0),M),ncol(X)) +::STMT +FLOAT:s,g,int170,num_groups +LITERAL_FLOAT:1.0,7.0 ++(*(*(-(s,int170),num_groups),7.0),*(-(g,1.0),7.0)) +::STMT +MATRIX:r,Hd +FLOAT:c +%*%(t(+(r,*(c,Hd))),+(r,*(c,Hd))) +::STMT +LITERAL_FLOAT:0.0,1.0,0.282842712474619 +INT:int945,int604 +*(rand(int945,int604,0.0,1.0),0.282842712474619) +::STMT +LITERAL_FLOAT:0.08333333333333333 +0.08333333333333333 +::STMT +MATRIX:resp,mean,X,weight +/(*(mean,%*%(t(resp),X)),t(weight)) +::STMT +LITERAL_FLOAT:-1.0E30 +INT:int924,M +rand(M,int924,-1.0E30,-1.0E30) +::STMT +FLOAT:x1,x2 +LITERAL_FLOAT:2.0 +^(-(x1,x2),2.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116004 +LITERAL_FLOAT:0.0 +-(0.0,+(*(scale_X,%*%(parsertemp116004,y)),*(cast.FLOAT(r),shift_X))) +::STMT +MATRIX:R,dssp +FLOAT:4_n,4_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,4_alpha),-(/(4_n,+(R,dssp)),1.0)) +::STMT +LITERAL_FLOAT:0.6666666666666666 +0.6666666666666666 +::STMT +MATRIX:xs +FLOAT:254_x +LITERAL_FLOAT:1.0,100.0 ++(-(100.0,sum(>=(xs,254_x))),1.0) +::STMT +MATRIX:parsertemp109934 +LITERAL_FLOAT:1.0,42.0 ++(*(parsertemp109934,42.0),1.0) +::STMT +MATRIX:r +FLOAT:int435,tolerance +LITERAL_FLOAT:2.0 +sqrt(*(sum(^(r,int435)),^(tolerance,2.0))) +::STMT +MATRIX:Y +-(length(Y),sum(Y)) +::STMT +MATRIX:R,parsertemp40226,parsertemp40220 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),-(R,rowSums(parsertemp40220))),eAvg) +::STMT +MATRIX:P,Y,parsertemp221025,Z,ZERODIAG +FLOAT:int525 +LITERAL_FLOAT:1.0,4.0 +*(-(*(P,4.0),/(*(Z,ZERODIAG),sum(Z))),*(/(1.0,+(Y,int525)),+(diag(parsertemp221025),1.0))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +/(sum(*(+(r,parsertemp1945),+(r,parsertemp1945))),norm_r2) +::STMT +MATRIX:p,q,lambda ++(q,*(lambda,p)) +::STMT +MATRIX:r,g,z +*(z,+(r,g)) +::STMT +MATRIX:parsertemp72333 +FLOAT:int203,rows +/(colSums(rowSums(^(parsertemp72333,int203))),rows) +::STMT +FLOAT:parsertemp40813,m2,m3 +LITERAL_FLOAT:3.0 +/(m3,^(sqrt(*(parsertemp40813,m2)),3.0)) +::STMT +MATRIX:s,w +FLOAT:lambda +*(lambda,sum(*(w,s))) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int771,parsertemp31048,int464,parsertemp31047,int798,int713,parsertemp31053,parsertemp31052 +LITERAL_FLOAT:2.0 +/(^(+(/(posSampleVariances,int464),/(negSampleVariances,int713)),2.0),+(/(^(posSampleVariances,int771),*(parsertemp31047,parsertemp31048)),/(^(negSampleVariances,int798),*(parsertemp31052,parsertemp31053)))) +::STMT +MATRIX:y_hat,b,parsertemp31748 +sum(*(-(-(b,parsertemp31748),y_hat),-(-(b,parsertemp31748),y_hat))) +::STMT +FLOAT:parsertemp40813,m2,m4 +LITERAL_FLOAT:4.0 +/(m4,^(sqrt(*(parsertemp40813,m2)),4.0)) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:2.0 +sum(^(/(%*%(I,y2),sum(I)),2.0)) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610,wnew +%*%(t(-(%*%(X,wnew),y)),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +LITERAL_FLOAT:0.1093262138242341 +0.1093262138242341 +::STMT +MATRIX:linear_terms +FLOAT:int267 +LITERAL_FLOAT:1.0,2.0 +-(1.0,-(*(2.0,>=(linear_terms,int267)),1.0)) +::STMT +MATRIX:parsertemp27546 +FLOAT:labelCorrection +t(-(parsertemp27546,labelCorrection)) +::STMT +MATRIX:parsertemp16959,id +-(==(id,t(id)),diag(diag(==(id,parsertemp16959)))) +::STMT +MATRIX:A,scale_X,shift_X,X ++(%*%(diag(scale_X),%*%(t(X),X)),%*%(shift_X,A)) +::STMT +FLOAT:191_beta2,191_t +LITERAL_FLOAT:1.0 +-(1.0,^(191_beta2,+(191_t,1.0))) +::STMT +FLOAT:parsertemp557354,weight,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:-1.0,0.6931471805599453 +*(*(-1.0,weight),+(/(*(prob_true,parsertemp557354),0.6931471805599453),/(*(prob_false,parsertemp557358),0.6931471805599453))) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int441 +LITERAL_FLOAT:6999.0,7000.0 +/(/(-(colSums(parsertemp31186),*(int441,parsertemp31188)),6999.0),7000.0) +::STMT +FLOAT:parsertemp40936,parsertemp40941,int194 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2001.0 +/(*(*(4.0,-(parsertemp40941,int194)),^(sqrt(parsertemp40936),2.0)),*(+(2001.0,5.0),-(2001.0,3.0))) +::STMT +MATRIX:prevTK2,X2 +==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2))) +::STMT +MATRIX:sv,Xd +FLOAT:dd ++(dd,sum(*(*(Xd,sv),Xd))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,3840.0 +-(1.0,/(3840.0,num_records)) +::STMT +MATRIX:parsertemp389217,parsertemp389216 +FLOAT:n +LITERAL_FLOAT:1.0 +sqrt(/(*(-(parsertemp389216,parsertemp389217),n),-(n,1.0))) +::STMT +MATRIX:parsertemp171346,parsertemp171344,linear_terms,the_exp +FLOAT:int422,int41 +LITERAL_FLOAT:1.0,1.0E7 +/(*(-(1.0,==(parsertemp171346,int422)),-(1.0,exp(parsertemp171344))),+(exp(linear_terms),==(+(int41,the_exp),1.0E7))) +::STMT +FLOAT:parsertemp166531 +LITERAL_FLOAT:2.0,10.0 ++(2.0,*(10.0,parsertemp166531)) +::STMT +FLOAT:parsertemp40837,parsertemp40832,int270 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2000.0 +/(*(*(4.0,-(parsertemp40837,int270)),^(sqrt(parsertemp40832),2.0)),*(+(2000.0,5.0),-(2000.0,3.0))) +::STMT +FLOAT:num_strata,num_groups +LITERAL_FLOAT:7.0 +*(*(num_groups,num_strata),7.0) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,2.0 +*(2.0,-(run_index,1.0)) +::STMT +MATRIX:output,mask +LITERAL_FLOAT:0.0,1.0 +&(==(output,0.0),==(mask,1.0)) +::STMT +MATRIX:p,G +LITERAL_FLOAT:0.85 +*(0.85,%*%(G,p)) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:1.0,10000.0 +*(-(10000.0,1.0),/(*(parsertemp31330,10000.0),-(10000.0,1.0))) +::STMT +MATRIX:parsertemp220867,parsertemp220866,Hdiff,parsertemp220871,parsertemp220872,beta,betamin +FLOAT:int591 +LITERAL_FLOAT:2.0 ++(+(*(*(parsertemp220866,parsertemp220867),beta),/(*(parsertemp220871,parsertemp220872),2.0)),/(*(<(Hdiff,int591),+(beta,betamin)),2.0)) +::STMT +MATRIX:parsertemp382671,X +FLOAT:int751,int347 +LITERAL_FLOAT:0.5 +*(0.5,sum(*(!=(X,int347),^(parsertemp382671,int751)))) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 ++(*(/(Y_prob,rowSums(Y_prob)),-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:M2 +LITERAL_FLOAT:0.0 +&(!(!=(M2,0.0)),!=(M2,0.0)) +::STMT +FLOAT:parsertemp41040,int116,parsertemp41045 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2003.0 +/(*(*(4.0,-(parsertemp41045,int116)),^(sqrt(parsertemp41040),2.0)),*(+(2003.0,5.0),-(2003.0,3.0))) +::STMT +MATRIX:S,col_nonzeros,parsertemp382922,parsertemp382920 +sum(*(S,+(t(parsertemp382920),*(parsertemp382922,col_nonzeros)))) +::STMT +MATRIX:r,s,grad +-(cast.FLOAT(%*%(t(s),grad)),cast.FLOAT(%*%(t(s),r))) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0 ++(*(-(s,1.0),-(num_groups,1.0)),1.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,%*%(parsertemp1904,y)),2.0)) +::STMT +MATRIX:A +*(cast.FLOAT(A),cast.FLOAT(A)) +::STMT +MATRIX:parsertemp42200,F +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42200,/(rowSums(F),2.0)),/(1.0,2.0)) +::STMT +MATRIX:tmp +FLOAT:N +LITERAL_FLOAT:1.0 +/(tmp,-(N,1.0)) +::STMT +MATRIX:C,Xm,parsertemp265702 +-(sum(%*%(%*%(Xm,parsertemp265702),t(C))),sum(Xm)) +::STMT +MATRIX:sig,parsertemp181037 +FLOAT:window_size,q +/(-(q,*(window_size,cast.FLOAT(parsertemp181037))),*(window_size,cast.FLOAT(*(sig,sig)))) +::STMT +MATRIX:parsertemp163760 +FLOAT:bin_length +/(rowSums(parsertemp163760),bin_length) +::STMT +MATRIX:X +FLOAT:value +!(<(X,value)) +::STMT +MATRIX:cumHistMul,offset,parsertemp132494,histMul,outBucket +-(offset,%*%(==(outBucket,t(parsertemp132494)),-(cumHistMul,histMul))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +/(nrow(X),-(nrow(X),1.0)) +::STMT +MATRIX:y_hat,A,B +-(-(B,%*%(A,y_hat)),y_hat) +::STMT +FLOAT:int395,Hin,Win +LITERAL_FLOAT:2.0,64.0 +*(*(64.0,/(/(Hin,int395),2.0)),/(/(Win,2.0),2.0)) +::STMT +MATRIX:R +FLOAT:s,i8 +-(nrow(R),*(s,i8)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,1.0E-6 +/(*(1.0E-6,cast.FLOAT(%*%(X,X))),1.0) +::STMT +MATRIX:_sbcvar12 +LITERAL_FLOAT:999.0 +/(_sbcvar12,999.0) +::STMT +MATRIX:std,rad,dtd +/(-(rad,cast.FLOAT(std)),cast.FLOAT(dtd)) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:1270.0 +/(parsertemp79022,1270.0) +::STMT +MATRIX:p,V +FLOAT:eps +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:prec,X,mu +rowSums(*(-(%*%(X,prec),%*%(mu,prec)),-(%*%(X,prec),%*%(mu,prec)))) +::STMT +MATRIX:w +FLOAT:tau +*(tau,sum(abs(w))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170148,int652,z,int229 +LITERAL_FLOAT:0.5 +*(0.5,/(+(*(z,int652),sqrt(parsertemp170148)),sum(^(p_CG,int229)))) +::STMT +FLOAT:window_size,k +LITERAL_FLOAT:1.0 +-(+(k,window_size),1.0) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0004995004995005 +/(sqrt(*(1.0004995004995005,m2)),mu) +::STMT +MATRIX:tmp,X,Y,out +-(%*%(t(X),*(out,Y)),tmp) +::STMT +MATRIX:_sbcvar92,parsertemp27718,parsertemp27720 +FLOAT:220_W,float581 +LITERAL_FLOAT:2.0 +^(-(_sbcvar92,+(*(parsertemp27720,float581),/(parsertemp27718,220_W))),2.0) +::STMT +MATRIX:f +LITERAL_FLOAT:1.0,2.0 +-(1.0,rowSums(^(f,2.0))) +::STMT +MATRIX:Xm,Z,parsertemp265713 +/(-(sum(%*%(Z,parsertemp265713)),sum(Xm)),sum(Xm)) +::STMT +MATRIX:B,X,y +LITERAL_FLOAT:2.0 +sum(^(-(y,%*%(X,B)),2.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +INT:int354,int611 +%*%(+(rowSums(classFeatureCounts),*(750.0,1.0)),rand(int354,int611,1.0,1.0)) +::STMT +MATRIX:curr_prediction +LITERAL_FLOAT:1.0 +sum(*(curr_prediction,-(1.0,curr_prediction))) +::STMT +LITERAL_FLOAT:1.001001001001001 +1.001001001001001 +::STMT +MATRIX:scale_X,X,y +LITERAL_FLOAT:0.0 +*(scale_X,%*%(-(0.0,t(X)),y)) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0 +rowSums(^(-(vectors,pq_result),2.0)) +::STMT +MATRIX:R,parsertemp72406 +LITERAL_FLOAT:2.0 +^(-(%*%(t(R),R),diag(parsertemp72406)),2.0) +::STMT +MATRIX:A,scale_X,shift_X ++(%*%(diag(scale_X),A),%*%(shift_X,A)) +::STMT +FLOAT:i7 +LITERAL_FLOAT:1.0 ++(1.0,i7) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,outd1,W3 +LITERAL_FLOAT:0.0 +%*%(t(outd1),*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3)))) +::STMT +MATRIX:parsertemp42200,_sbcvar330 +LITERAL_FLOAT:2.0,0.5 ++(-(parsertemp42200,/(rowSums(_sbcvar330),2.0)),0.5) +::STMT +LITERAL_FLOAT:0.07261134713572442 +0.07261134713572442 +::STMT +FLOAT:int671,int784,parsertemp86,parsertemp87,int369,wt +sqrt(/(*(*(int369,wt),-(wt,int784)),*(*(parsertemp86,parsertemp87),+(wt,int671)))) +::STMT +MATRIX:U,V,X,parsertemp382840 +LITERAL_FLOAT:0.0 +%*%(*(!=(X,0.0),-(%*%(U,parsertemp382840),X)),V) +::STMT +MATRIX:P,D,beta +LITERAL_FLOAT:1.0E-12 +*(beta,/(rowSums(*(P,D)),+(rowSums(P),1.0E-12))) +::STMT +MATRIX:B,_sbcvar887 ++(%*%(_sbcvar887,B),cast.FLOAT(B)) +::STMT +MATRIX:R2,R1 +LITERAL_FLOAT:1.0E-6 +sum(<(abs(-(R1,R2)),1.0E-6)) +::STMT +MATRIX:resp,X +LITERAL_FLOAT:2.0,2.22E-16 +/(%*%(t(resp),^(X,2.0)),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:linear_terms +FLOAT:int323 +LITERAL_FLOAT:3.141592653589793,1.0,2.0 +^(*(+(1.0,^(linear_terms,int323)),3.141592653589793),2.0) +::STMT +MATRIX:parsertemp409054,ctab +LITERAL_FLOAT:0.6 +>(/(parsertemp409054,rowSums(ctab)),0.6) +::STMT +MATRIX:parsertemp1654,A,scale_X,shift_X +%*%(diag(scale_X),t(+(%*%(parsertemp1654,A),%*%(shift_X,A)))) +::STMT +MATRIX:131_s,parsertemp115723 +FLOAT:eAvg +LITERAL_FLOAT:1.0,0.95 +*(0.95,-(/(/(parsertemp115723,131_s),eAvg),1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:6.0 +*(6.0,sum(round(W))) +::STMT +MATRIX:minD,D +t(/(<=(D,minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +FLOAT:num_successful_runs +/(sum(*(parsertemp222665,termination_bitmap)),num_successful_runs) +::STMT +MATRIX:tpr,fpr +LITERAL_FLOAT:2.0 +sum(/(*(-(fpr,fpr),+(tpr,tpr)),2.0)) +::STMT +MATRIX:d,parsertemp43998 +FLOAT:int458 +cast.FLOAT(%*%(t(d),+(d,*(int458,parsertemp43998)))) +::STMT +FLOAT:i8 +LITERAL_FLOAT:1.0,24.0 ++(1.0,*(24.0,i8)) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int484,int328 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int328,parsertemp2798),0.0),-(1.0,*(Y,Xw))),*(>(-(int484,parsertemp2798),0.0),-(1.0,*(Y,Xw)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,44.721359549995796 +/(sqrt(*(1.0005002501250626,m2)),44.721359549995796) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,133.0 ++(*(133.0,-(i,1.0)),1.0) +::STMT +MATRIX:X2,85_s +LITERAL_FLOAT:1.0 +-(*(/(1.0,85_s),nrow(X2)),1.0) +::STMT +FLOAT:df +LITERAL_FLOAT:4.890349128221754 +*(df,4.890349128221754) +::STMT +FLOAT:P,pIn,qIn,i8 ++(+(+(P,pIn),qIn),i8) +::STMT +FLOAT:i,k +LITERAL_FLOAT:2.0 +-(+(i,k),2.0) +::STMT +MATRIX:Y_row_norm,parsertemp16881 +FLOAT:epsilon +t(+(sqrt(rowSums(parsertemp16881)),*(<(Y_row_norm,epsilon),epsilon))) +::STMT +MATRIX:parsertemp387154,y +LITERAL_FLOAT:2.0 +cast.MATRIX(sum(^(-(y,parsertemp387154),2.0))) +::STMT +FLOAT:o_init,o +LITERAL_FLOAT:2.0 +-(*(2.0,o_init),*(2.0,o)) +::STMT +MATRIX:parsertemp149248,V,X,P_1K +-(*(P_1K,%*%(X,V)),*(P_1K,rowSums(*(P_1K,parsertemp149248)))) +::STMT +MATRIX:xs +FLOAT:254_x +LITERAL_FLOAT:100.0 +-(100.0,sum(>=(xs,254_x))) +::STMT +MATRIX:s,d,parsertemp44021 +FLOAT:delta2 +*(cast.FLOAT(%*%(t(d),d)),-(delta2,cast.FLOAT(%*%(parsertemp44021,s)))) +::STMT +MATRIX:w,ones_ns +*(ones_ns,cast.FLOAT(w)) +::STMT +MATRIX:parsertemp1511,X +FLOAT:int967,n +LITERAL_FLOAT:2.0 +-(t(colSums(^(X,int967))),*(n,^(/(parsertemp1511,n),2.0))) +::STMT +MATRIX:r,s,grad +-(%*%(t(s),grad),%*%(t(s),r)) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:p,w,parsertemp1940 +FLOAT:norm_r2 ++(w,*(/(norm_r2,cast.FLOAT(parsertemp1940)),p)) +::STMT +LITERAL_FLOAT:1.0,2.0 +-(2.0,1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +sum(<(linear_terms,0.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:parsertemp557199,int866 +==(diag(rand(parsertemp557199,int866,1.0,1.0)),0.0) +::STMT +MATRIX:X_row_norm,parsertemp16875,parsertemp16884,parsertemp16882 +FLOAT:epsilon +%*%(+(sqrt(rowSums(parsertemp16875)),*(<(X_row_norm,epsilon),epsilon)),t(+(sqrt(parsertemp16882),*(parsertemp16884,epsilon)))) +::STMT +MATRIX:parsertemp437548,pred,parsertemp437666 +colSums(==(*(parsertemp437666,t(parsertemp437548)),pred)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,1048.0 +-(n,-(+(i,1048.0),1.0)) +::STMT +LITERAL_FLOAT:2003.0 +sqrt(2003.0) +::STMT +LITERAL_FLOAT:0.0,1.0,0.05 +INT:int670,int414 +*(rand(int670,int414,0.0,1.0),0.05) +::STMT +MATRIX:selCols2 +FLOAT:n +-(n,sum(selCols2)) +::STMT +MATRIX:ytest,yhat +FLOAT:int551,mean_y_test,int687 +LITERAL_FLOAT:2.0 +/(sum(^(-(ytest,yhat),2.0)),-(sum(^(ytest,int551)),*(nrow(ytest),^(mean_y_test,int687)))) +::STMT +MATRIX:B +FLOAT:M +*(ncol(B),M) +::STMT +MATRIX:s,w +cast.FLOAT(%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:int99,arch_coef,var_coef,int481,a0 +INT:int876,int329 +rand(int876,int329,/(a0,-(-(int99,arch_coef),var_coef)),/(a0,-(-(int481,arch_coef),var_coef))) +::STMT +MATRIX:X +FLOAT:int675 +LITERAL_FLOAT:0.0 +sum(!=(rowSums(!=(X,int675)),0.0)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,1024.0 +-(n,-(+(i,1024.0),1.0)) +::STMT +MATRIX:Y_counts +FLOAT:num_features +-(sum(Y_counts),num_features) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg +/(/(se,ss),130_eAvg) +::STMT +MATRIX:adjacency +LITERAL_FLOAT:0.0 +>(rowSums(adjacency),0.0) +::STMT +MATRIX:parsertemp477718,parsertemp477715,X,Y +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(parsertemp477715,parsertemp477718)),Y),*(/(-(x,X),-(X,X)),Y)) +::STMT +MATRIX:parsertemp11509 +LITERAL_FLOAT:1.0,2.0 ++(1.0,*(2.0,parsertemp11509)) +::STMT +MATRIX:finite_linear_terms,the_exp +FLOAT:int960 +LITERAL_FLOAT:1.0,2.0,1.0E7 +*(*(==(+(int960,the_exp),1.0E7),exp(finite_linear_terms)),-(1.0,/(exp(finite_linear_terms),2.0))) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 ++(%*%(t(V),%*%(V,p)),*(1.0E-8,p)) +::STMT +MATRIX:linear_terms,Y +FLOAT:parsertemp171225,link_power,float353 +LITERAL_FLOAT:1.0 +*(^(linear_terms,-(/(parsertemp171225,link_power),1.0)),-(Y,^(linear_terms,/(float353,link_power)))) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +t(+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:X +FLOAT:x +-(nrow(X),sum(>=(X,x))) +::STMT +MATRIX:F,parsertemp42207 +LITERAL_FLOAT:2.0 +-(parsertemp42207,/(t(colSums(F)),2.0)) +::STMT +MATRIX:parsertemp389218 +FLOAT:int263,n +LITERAL_FLOAT:1.0E-17 ++(sqrt(/(*(parsertemp389218,n),-(n,int263))),1.0E-17) +::STMT +FLOAT:parsertemp170472,parsertemp170473,log_odds,learning_rate +LITERAL_FLOAT:1.0,2.7182818284 +/(^(2.7182818284,+(log_odds,*(learning_rate,parsertemp170472))),+(1.0,^(2.7182818284,+(log_odds,parsertemp170473)))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0,0.5 +*(0.5,+(<=(y_corr,0.0),>=(y_corr,1.0))) +::STMT +FLOAT:parsertemp169814 +LITERAL_FLOAT:2.302585092994046,4.0 +exp(*(2.302585092994046,-(4.0,round(parsertemp169814)))) +::STMT +FLOAT:s +LITERAL_FLOAT:81.0,-1.0,3.0 +*(81.0,^(3.0,*(s,-1.0))) +::STMT +MATRIX:F +/(%*%(rowSums(F),colSums(F)),sum(F)) +::STMT +MATRIX:Yhat_prime,E,W4 +%*%(*(E,Yhat_prime),W4) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +LITERAL_FLOAT:2.0 +%*%(t(d),+(d,*(2.0,%*%(parsertemp43996,parsertemp43997)))) +::STMT +LITERAL_FLOAT:1.0,10000.0 +/(10000.0,-(10000.0,1.0)) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +/(+(abs(X),abs(Y)),2.0) +::STMT +FLOAT:end_stepsize,k,kmax,start_stepsize +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(k,kmax)),start_stepsize),*(/(k,kmax),end_stepsize)) +::STMT +FLOAT:int764,tau,int900 +INT:int902,m +*(tau,sum(abs(rand(m,int902,int764,int900)))) +::STMT +MATRIX:parsertemp13627,43_E +FLOAT:int232,43_q +LITERAL_FLOAT:1000.0 +sqrt(/(sum(/(parsertemp13627,43_E)),*(1000.0,-(43_q,int232)))) +::STMT +MATRIX:parsertemp31112,parsertemp31114,parsertemp31105,parsertemp31107 +LITERAL_FLOAT:1499.0,1999.0,1500.0,2000.0 ++(/(/(-(parsertemp31105,parsertemp31107),1999.0),2000.0),/(/(-(parsertemp31112,parsertemp31114),1499.0),1500.0)) +::STMT +MATRIX:l1,l2 +cast.FLOAT(<(l1,l2)) +::STMT +MATRIX:D,ZERODIAG,beta +FLOAT:int333 +*(exp(*(*(D,int333),beta)),ZERODIAG) +::STMT +MATRIX:y_hat,A,B +LITERAL_FLOAT:2.0 +^(-(-(B,%*%(A,y_hat)),y_hat),2.0) +::STMT +MATRIX:missing_indicator_mat +FLOAT:global_mean +*(missing_indicator_mat,global_mean) +::STMT +MATRIX:surv,se_surv +FLOAT:parsertemp538723 +*(surv,exp(/(*(parsertemp538723,se_surv),surv))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:parsertemp31034,parsertemp31027 +LITERAL_FLOAT:2.0,150.0,100.0 +^(+(/(/(parsertemp31026,parsertemp31027),100.0),/(/(parsertemp31033,parsertemp31034),150.0)),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:1.0 ++(-(ncol(Y),1.0),1.0) +::STMT +MATRIX:parsertemp396410,parsertemp396407,W3_rand +LITERAL_FLOAT:0.16823164622761327 +t(%*%(*(0.16823164622761327,W3_rand),t(/(parsertemp396407,parsertemp396410)))) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:0.0 +*(cast.FLOAT(diag(scale_X)),%*%(-(0.0,t(X)),g_Y)) +::STMT +MATRIX:parsertemp429917,parsertemp429915 +LITERAL_FLOAT:0.0,1.0,299.0 +-(1.0,<=(/(-(parsertemp429915,parsertemp429917),299.0),0.0)) +::STMT +MATRIX:r,g,z +sum(*(z,+(r,g))) +::STMT +FLOAT:delta +LITERAL_FLOAT:0.25 +*(0.25,delta) +::STMT +FLOAT:arch_coef,int306,var_coef,a0 +sqrt(/(a0,-(-(int306,arch_coef),var_coef))) +::STMT +MATRIX:parsertemp149323,LT,Y +LITERAL_FLOAT:-1.0 +*(sum(*(Y,-(LT,parsertemp149323))),-1.0) +::STMT +MATRIX:p_CG,z +LITERAL_FLOAT:-1.0 +*(*(cast.FLOAT(z),sum(p_CG)),-1.0) +::STMT +MATRIX:parsertemp24101 +FLOAT:num_bins,float936 +LITERAL_FLOAT:1.0 +>(+(round(-(parsertemp24101,float936)),1.0),num_bins) +::STMT +MATRIX:parsertemp459193,2700_dX,2703_X,2703_W +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(%*%(t(2703_X),*(parsertemp459193,2700_dX)),*(5.0E-4,2703_W))) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +>(-(cast.MATRIX(max(Y)),parsertemp185166),-(parsertemp185166,min(Y))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2001.0 +*(/(2001.0,-(2001.0,1.0)),m2) +::STMT +MATRIX:w,parsertemp43626 +FLOAT:int200,float862,float690 +LITERAL_FLOAT:2.0,0.5 +INT:int774,int235 ++(*(0.5,%*%(t(w),rand(int774,int235,float690,float862))),*(2.0,sum(*(parsertemp43626,int200)))) +::STMT +MATRIX:R,parsertemp40216,parsertemp40225 +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),R) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:1.0,2.0 +-(^(2.0,max_depth),1.0) +::STMT +LITERAL_FLOAT:1.0,2001.0 +/(2001.0,-(2001.0,1.0)) +::STMT +FLOAT:int154,i +LITERAL_FLOAT:1.0,100.0 ++(*(*(-(i,int154),100.0),100.0),1.0) +::STMT +MATRIX:131_s +FLOAT:n2,int815 +LITERAL_FLOAT:0.050000000000000044,1.0 +*(0.050000000000000044,-(*(/(int815,131_s),n2),1.0)) +::STMT +MATRIX:b,E,X,sb +%*%(colSums(*(X,E)),+(b,sb)) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:-1.0 +*(+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p))),-1.0) +::STMT +FLOAT:obj,obj_new,gs +LITERAL_FLOAT:-0.5 +/(*(-0.5,gs),-(-(obj_new,obj),gs)) +::STMT +FLOAT:step +LITERAL_FLOAT:0.85 +*(step,0.85) +::STMT +MATRIX:w_X,z_LS,X +/(nrow(X),sum(*(w_X,z_LS))) +::STMT +MATRIX:parsertemp285531,z,parsertemp285533 +FLOAT:pp,sq_root_d,zq,parsertemp285523,parsertemp285538 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(z,parsertemp285533))),*(+(+(parsertemp285538,zq),sum(parsertemp285531)),/(+(parsertemp285523,sq_root_d),pp))) +::STMT +FLOAT:191_beta2,191_t +LITERAL_FLOAT:1.0 +^(191_beta2,+(191_t,1.0)) +::STMT +MATRIX:2814_K +LITERAL_FLOAT:0.0 +cast.FLOAT(-(0.0,2814_K)) +::STMT +MATRIX:dw,history +FLOAT:sigma,float741,alpha +-(max(history),*(*(*(float741,sigma),alpha),sum(*(dw,dw)))) +::STMT +MATRIX:X,parsertemp222929 ++(X,*(cast.FLOAT(parsertemp222929),-(X,X))) +::STMT +MATRIX:dout1 +FLOAT:192_beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,192_beta1),colSums(dout1)) +::STMT +FLOAT:lambda,beta +LITERAL_FLOAT:0.0,2.0 +sqrt(^(+(0.0,*(lambda,beta)),2.0)) +::STMT +MATRIX:C,parsertemp11064 +LITERAL_FLOAT:10000.0,100.0 +*(/(sum(==(parsertemp11064,C)),10000.0),100.0) +::STMT +FLOAT:N +LITERAL_FLOAT:1.0 +/(N,-(N,1.0)) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:2.0 +^(sum(residual_matrix),2.0) +::STMT +MATRIX:E,F +LITERAL_FLOAT:0.001 +sum(<(-(E,F),0.001)) +::STMT +MATRIX:parsertemp170505 +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(t(parsertemp170505),-1.0),2.0)) +::STMT +MATRIX:parsertemp1518,parsertemp1516,parsertemp1514 +FLOAT:parsertemp1519,n +LITERAL_FLOAT:0.0,1.0 +*(/(-(t(parsertemp1514),*(n,parsertemp1516)),-(n,1.0)),-(1.0,<=(/(parsertemp1518,parsertemp1519),0.0))) +::STMT +MATRIX:resp,X,parsertemp437188 +FLOAT:float191 +*(/(%*%(t(resp),X),t(+(parsertemp437188,float191))),%*%(t(resp),X)) +::STMT +LITERAL_FLOAT:225.0 +INT:int873,int730 +rand(int873,int730,225.0,225.0) +::STMT +MATRIX:X_batch,parsertemp389606,2364_2361_Y,parsertemp389586 +FLOAT:int440 +LITERAL_FLOAT:1.0 +%*%(t(*(-(2364_2361_Y,X_batch),-(int440,parsertemp389606))),/(-(exp(parsertemp389586),1.0),+(exp(parsertemp389586),1.0))) +::STMT +LITERAL_FLOAT:1.0 ++(+(1.0,1.0),1.0) +::STMT +MATRIX:2846_Q,X +FLOAT:int123,int579 +LITERAL_FLOAT:2.0 +-(+(rowSums(^(X,int123)),sum(^(2846_Q,int579))),*(2.0,%*%(X,t(2846_Q)))) +::STMT +MATRIX:s,w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:FN,FP,TN,TP +*(*(+(TP,FP),+(TP,FN)),+(TN,FP)) +::STMT +MATRIX:r,w +FLOAT:tau +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(r,r))),*(tau,sum(abs(w)))) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:parsertemp31191,parsertemp31198 +LITERAL_FLOAT:2.0,1500.0,7000.0 +^(+(/(/(parsertemp31190,parsertemp31191),7000.0),/(/(parsertemp31197,parsertemp31198),1500.0)),2.0) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y_prob,parsertemp171292,parsertemp171290 +FLOAT:float877 +%*%(+(*(/(Y_prob,parsertemp171290),-(float877,parsertemp171292)),is_LT_infinite),flip_neg) +::STMT +MATRIX:parsertemp171090,is_one_y_corr,t,parsertemp171096,parsertemp171080 +FLOAT:int787,float950 +LITERAL_FLOAT:1.0 ++(*(+(-(int787,t),/(parsertemp171090,parsertemp171096)),-(1.0,*(float950,parsertemp171080))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:simplex +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0 ++(num_func_invoc,-(ncol(simplex),1.0)) +::STMT +MATRIX:parsertemp220848,parsertemp220853,parsertemp220850,beta +FLOAT:float768 ++(parsertemp220853,*(beta,/(rowSums(parsertemp220850),+(parsertemp220848,float768)))) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +sqrt(/(*(m2,sum(W)),-(sum(W),1.0))) +::STMT +MATRIX:neighbors +diag(diag(neighbors)) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,2.0 +^(%*%(-(0.0,t(X)),y),2.0) +::STMT +MATRIX:S,addedX2 +FLOAT:level +==(%*%(S,t(addedX2)),level) +::STMT +MATRIX:p,e,u,G +LITERAL_FLOAT:0.15000000000000002,0.85 ++(*(0.85,%*%(G,p)),*(0.15000000000000002,%*%(%*%(e,u),p))) +::STMT +MATRIX:C,tmp,parsertemp265713 +FLOAT:Xm ++(Xm,trace(*(tmp,%*%(parsertemp265713,C)))) +::STMT +MATRIX:parsertemp42190,X +LITERAL_FLOAT:2.0 +-(parsertemp42190,/(X,2.0)) +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 +sum(^(s,2.0)) +::STMT +MATRIX:lambda,g,beta +%*%(t(+(g,*(lambda,beta))),+(g,*(lambda,beta))) +::STMT +MATRIX:dW2 +FLOAT:193_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,193_beta2),^(dW2,2.0)) +::STMT +MATRIX:parsertemp163717,p_gaps_vector +t(*(parsertemp163717,p_gaps_vector)) +::STMT +MATRIX:img_in1 +FLOAT:weight +LITERAL_FLOAT:1.0 +*(-(1.0,weight),img_in1) +::STMT +MATRIX:dout1,mb1 +FLOAT:parsertemp147007,192_t,192_lr,192_beta1,int736 +LITERAL_FLOAT:1.0 +*(/(*(192_lr,sqrt(parsertemp147007)),-(1.0,^(192_beta1,192_t))),+(*(192_beta1,mb1),*(-(int736,192_beta1),colSums(dout1)))) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046 +/(parsertemp169812,2.302585092994046) +::STMT +MATRIX:residuals_vector +LITERAL_FLOAT:0.0 +/(sum(residuals_vector),+(nrow(residuals_vector),0.0)) +::STMT +MATRIX:ZtZ,parsertemp265709,Xm,parsertemp265707,parsertemp265705,Z,parsertemp265702 +%*%(t(/(%*%(parsertemp265709,Z),sum(ZtZ))),/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(+(parsertemp265705,parsertemp265707)))) +::STMT +FLOAT:dd,step_sz +*(step_sz,dd) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:parsertemp31268,int795,W,float277 +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int795),CVars)),*(-(sum(WM),1.0),/(*(parsertemp31268,W),-(W,float277)))) +::STMT +MATRIX:ss,se +/(se,ss) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:-1.0 +%*%(diag(scale_X),%*%(*(t(X),-1.0),g_Y)) +::STMT +MATRIX:maskd1,out1 +FLOAT:p +LITERAL_FLOAT:0.0 +*(>(out1,0.0),/(maskd1,p)) +::STMT +MATRIX:V,W,parsertemp10741,H +LITERAL_FLOAT:1.0E-8 +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),1.0E-8))) +::STMT +MATRIX:parsertemp386448,withinEps +LITERAL_FLOAT:0.0,1.0 +>(colSums(>(*(parsertemp386448,withinEps),0.0)),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0 +*(exp(finite_linear_terms),-1.0) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 ++(-(nrow(X),sum(>=(X,x))),1.0) +::STMT +MATRIX:parsertemp122290,X2 +LITERAL_FLOAT:0.0,4.0 +|(<(t(colSums(X2)),4.0),==(t(%*%(parsertemp122290,X2)),0.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +INT:int258,int270 +%*%(+(rowSums(classFeatureCounts),*(50.0,1.0)),rand(int270,int258,1.0,1.0)) +::STMT +MATRIX:f,parsertemp472177,I,parsertemp472179 +LITERAL_FLOAT:2.0 +^(*(I,-(%*%(f,parsertemp472177),t(parsertemp472179))),2.0) +::STMT +MATRIX:parsertemp387552 +LITERAL_FLOAT:10.0 +^(10.0,parsertemp387552) +::STMT +MATRIX:parsertemp72182 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 ++(*(parsertemp72182,subvector_size),1.0) +::STMT +MATRIX:Y,parsertemp282723 +==(Y,cast.FLOAT(parsertemp282723)) +::STMT +MATRIX:Xm,parsertemp265733 +abs(/(sum(-(parsertemp265733,Xm)),sum(Xm))) +::STMT +FLOAT:end_stepsize,k,kmax +*(/(k,kmax),end_stepsize) +::STMT +MATRIX:parsertemp271862,parsertemp271860 +FLOAT:obj,parsertemp271888 +LITERAL_FLOAT:-0.5 +/(-(obj,parsertemp271888),*(-0.5,-(sum(parsertemp271860),sum(parsertemp271862)))) +::STMT +MATRIX:parsertemp500606,parsertemp500604,w +FLOAT:int50 +t(-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500606,int50)),w)) +::STMT +MATRIX:binary_array +LITERAL_FLOAT:1.0 ++(1.0,cast.FLOAT(binary_array)) +::STMT +MATRIX:R,dssp,dsep,parsertemp40232,parsertemp40220 +FLOAT:eAvg +/(/(-(+(R,dsep),rowSums(parsertemp40232)),-(+(R,dssp),rowSums(parsertemp40220))),eAvg) +::STMT +MATRIX:parsertemp386457,parsertemp386459,neighbors,corePts,withinEps,parsertemp386456 +LITERAL_FLOAT:0.0 +*(>(*(*(neighbors,corePts),withinEps),0.0),==(-(*(parsertemp386456,parsertemp386457),parsertemp386459),0.0)) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0,0.5 +round(+(0.5,/(parsertemp222331,200.0))) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27487 +LITERAL_FLOAT:1.0 +*(-(%*%(present_domain_vals_mat,CFreqs1),1.0),%*%(present_domain_vals_mat,parsertemp27487)) +::STMT +MATRIX:p,e,u +%*%(%*%(e,u),p) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamin +FLOAT:logU +LITERAL_FLOAT:0.0 +*(<(-(+(parsertemp220853,parsertemp220854),logU),0.0),betamin) +::STMT +MATRIX:b,E,X,sb +cast.FLOAT(%*%(colSums(*(X,E)),+(b,sb))) +::STMT +MATRIX:sb +FLOAT:delta +LITERAL_FLOAT:2.0 +-(sum(^(sb,2.0)),^(delta,2.0)) +::STMT +MATRIX:parsertemp171084,parsertemp171083 +LITERAL_FLOAT:0.010328,-2.0,0.802853 +*(sqrt(*(-2.0,parsertemp171083)),+(0.802853,*(sqrt(parsertemp171084),0.010328))) +::STMT +MATRIX:c,G +*(G,t(c)) +::STMT +MATRIX:parsertemp399242,W3_rand +FLOAT:int741,int312 +LITERAL_FLOAT:0.6546536707079771 +%*%(*(0.6546536707079771,W3_rand),t(/(-(parsertemp399242,int741),+(parsertemp399242,int312)))) +::STMT +FLOAT:parsertemp164939 +LITERAL_FLOAT:2.0,100.0 ++(2.0,*(100.0,parsertemp164939)) +::STMT +MATRIX:p,p2 +LITERAL_FLOAT:1.0E8 +>(abs(-(p2,p)),1.0E8) +::STMT +MATRIX:ytest,yhat +sum(-(ytest,yhat)) +::STMT +MATRIX:parsertemp221021 +LITERAL_FLOAT:1.0 ++(diag(parsertemp221021),1.0) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat +LITERAL_FLOAT:1.0 +-(%*%(present_domain_vals_mat,CFreqs1),1.0) +::STMT +MATRIX:G,authorities +/(%*%(G,authorities),max(%*%(G,authorities))) +::STMT +LITERAL_FLOAT:1.0,2003.0 +-(2003.0,1.0) +::STMT +MATRIX:parsertemp137847,keyPos1 +*(t(parsertemp137847),keyPos1) +::STMT +MATRIX:s,w,wnew,parsertemp44079 +LITERAL_FLOAT:-1.0,2.0,0.5 ++(*(0.5,%*%(t(wnew),+(w,s))),*(2.0,*(-1.0,sum(parsertemp44079)))) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379562,parsertemp379571,m_iter_err_sum,parsertemp379569 +FLOAT:i_process_item +LITERAL_FLOAT:1.0 +/(+(-(*(parsertemp379569,i_process_item),*(parsertemp379571,m_iter_err_sum)),+(colSums(parsertemp379562),m_iter_err_sum_squared)),-(i_process_item,1.0)) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:2.0 +^(+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p))),2.0) +::STMT +MATRIX:dX,v +FLOAT:lr,mu +-(*(mu,v),*(lr,dX)) +::STMT +FLOAT:246_AIC_best,246_thr +abs(*(246_thr,246_AIC_best)) +::STMT +MATRIX:X,Centering,ScaleFactor +%*%(t(/(-(X,Centering),ScaleFactor)),/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:d,X,logisticD +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:2.0 +sum(*(^(U,2.0),row_nonzeros)) +::STMT +MATRIX:s,w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(+(w,s)),+(w,s))) +::STMT +MATRIX:parsertemp410979,W,X,H,parsertemp410981,parsertemp410984 +/(*(W,%*%(/(X,parsertemp410984),t(H))),t(rowSums(/(parsertemp410979,parsertemp410981)))) +::STMT +MATRIX:S,parsertemp382904,V,W,row_nonzeros +LITERAL_FLOAT:1.0E-6 ++(%*%(*(W,%*%(S,parsertemp382904)),V),*(*(1.0E-6,S),row_nonzeros)) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:9999.0,10000.0 +*(9999.0,/(*(parsertemp31330,10000.0),9999.0)) +::STMT +LITERAL_FLOAT:3.0,2003.0 +-(2003.0,3.0) +::STMT +MATRIX:out +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(out),out))) +::STMT +MATRIX:upd_W1,W1_grad,W1 +FLOAT:parsertemp389637,mu,step ++(W1,-(*(mu,upd_W1),*(/(step,parsertemp389637),W1_grad))) +::STMT +MATRIX:ones,classFeatureCounts +FLOAT:float714,int456 +LITERAL_FLOAT:1.0 +/(+(classFeatureCounts,1.0),%*%(+(rowSums(classFeatureCounts),*(int456,float714)),ones)) +::STMT +LITERAL_FLOAT:2.0,2001.0 +^(2001.0,2.0) +::STMT +MATRIX:W1_rand,X,parsertemp400556,parsertemp400566 +FLOAT:float936 +LITERAL_FLOAT:0.08333333333333333 +%*%(*(0.08333333333333333,W1_rand),t(/(-(X,parsertemp400556),+(parsertemp400566,float936)))) +::STMT +FLOAT:avg_res,ytest,mean_y_test,int765,yhat,int958 +LITERAL_FLOAT:1.0,2.0 +/(-(^(-(ytest,yhat),2.0),*(1.0,^(avg_res,int958))),-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int765)))) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:col_nonzeros,U,V,row_nonzeros +FLOAT:int917,int674 ++(sum(*(^(U,int917),row_nonzeros)),sum(*(^(V,int674),col_nonzeros))) +::STMT +MATRIX:parsertemp24102 +FLOAT:num_bins +LITERAL_FLOAT:1.0 +*(>(+(round(parsertemp24102),1.0),num_bins),num_bins) +::STMT +MATRIX:parsertemp539204 +FLOAT:float276,float683 +LITERAL_FLOAT:0.6666666666666666 +-(max(^(/(parsertemp539204,float276),0.6666666666666666)),min(^(/(parsertemp539204,float683),0.6666666666666666))) +::STMT +MATRIX:r,d,Hd,parsertemp44001 +FLOAT:int112 +*(/(sum(^(r,int112)),cast.FLOAT(%*%(parsertemp44001,Hd))),d) +::STMT +MATRIX:m_active_flag +LITERAL_FLOAT:0.0 +t(==(m_active_flag,0.0)) +::STMT +LITERAL_FLOAT:1.0005002501250626 +1.0005002501250626 +::STMT +MATRIX:parsertemp170242,parsertemp170240,parsertemp170238 +FLOAT:float516,float545,float457 +LITERAL_FLOAT:1.0,1.421413741 +*(/(1.0,+(1.0,*(parsertemp170238,float545))),+(1.421413741,*(/(float457,parsertemp170240),+(float516,parsertemp170242)))) +::STMT +LITERAL_FLOAT:2.0,2003.0 +-(2003.0,2.0) +::STMT +MATRIX:t_gp,parsertemp170243,parsertemp170239 +FLOAT:float433 +LITERAL_FLOAT:1.0,1.421413741,-0.284496736 ++(-0.284496736,*(/(1.0,+(float433,parsertemp170239)),+(1.421413741,*(t_gp,parsertemp170243)))) +::STMT +MATRIX:X +FLOAT:int432 +LITERAL_FLOAT:1.0E-6 +<(sqrt(rowSums(^(X,int432))),1.0E-6) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +>=(rowSums(abs(A)),1.0) +::STMT +FLOAT:int709,b +LITERAL_FLOAT:2.0 +-(^(b,2.0),int709) +::STMT +MATRIX:B +FLOAT:M +/(nrow(B),M) +::STMT +MATRIX:simplex +LITERAL_FLOAT:0.0 ++(0.0,ncol(simplex)) +::STMT +MATRIX:minD +FLOAT:sumXsq ++(sumXsq,sum(minD)) +::STMT +MATRIX:H,parsertemp220860,parsertemp220861,beta +FLOAT:logU +LITERAL_FLOAT:0.0,2.0 +/(*(<(-(H,logU),0.0),+(beta,+(parsertemp220860,parsertemp220861))),2.0) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +<(abs(-(output1,dataset)),0.16) +::STMT +MATRIX:r,s,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(%*%(t(s),grad),%*%(t(s),r))) +::STMT +MATRIX:parsertemp131907,cumHistMul,offset,parsertemp132092,histMul,outBucket +-(offset,%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),-(cumHistMul,histMul))) +::STMT +LITERAL_FLOAT:-1.0,0.001 +*(0.001,-1.0) +::STMT +MATRIX:centroid_placer,X_samples +%*%(centroid_placer,%*%(centroid_placer,X_samples)) +::STMT +LITERAL_FLOAT:0.0,1.0 +/(1.0,0.0) +::STMT +LITERAL_FLOAT:1.0,2.0 +/(1.0,2.0) +::STMT +LITERAL_FLOAT:-1.0,2.0 +/(-1.0,2.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +/(*($1:ncol(X),+($1,1.0)),2.0) +::STMT +MATRIX:parsertemp165076,X,y +LITERAL_FLOAT:2.0 +/(sum(^(-(y,parsertemp165076),2.0)),nrow(X)) +::STMT +MATRIX:parsertemp170277 +LITERAL_FLOAT:3.141592653589793 +/(parsertemp170277,3.141592653589793) +::STMT +MATRIX:parsertemp403497,parsertemp403500,W3_rand +LITERAL_FLOAT:0.1651445647689541 +t(%*%(*(0.1651445647689541,W3_rand),t(/(parsertemp403497,parsertemp403500)))) +::STMT +MATRIX:parsertemp286536,parsertemp286535 +FLOAT:float220 +sqrt(cast.FLOAT(%*%(t(parsertemp286536),+(float220,parsertemp286535)))) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +/(*(n_risk,n_event_stratum),n_risk_stratum) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,80656.0 +*(-(i,1.0),80656.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0 +-(1.0,>=(y_corr,1.0)) +::STMT +MATRIX:means,parsertemp389215 +FLOAT:n +LITERAL_FLOAT:1.0 +/(*(-(/(parsertemp389215,n),*(means,means)),n),-(n,1.0)) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:1.0,2.0 +*(2.0,-(^(2.0,max_depth),1.0)) +::STMT +LITERAL_FLOAT:1.0,1.5 +/(1.0,1.5) +::STMT +FLOAT:e,mu +LITERAL_FLOAT:0.999,4.0 ++(mu,/(-(0.999,mu),-(4.0,e))) +::STMT +MATRIX:B2,ytest,Xtest,parsertemp387577 +cast.FLOAT(%*%(t(-(ytest,parsertemp387577)),-(ytest,%*%(Xtest,B2)))) +::STMT +MATRIX:r,obj,parsertemp44063,parsertemp44077,parsertemp44065,grad +FLOAT:float27,C,parsertemp44081 +LITERAL_FLOAT:-0.5 +/(-(obj,+(*(float27,parsertemp44077),*(C,parsertemp44081))),*(-0.5,-(%*%(parsertemp44063,grad),%*%(parsertemp44065,r)))) +::STMT +MATRIX:LT,parsertemp149320,parsertemp150469 +exp(-(LT,%*%(parsertemp149320,parsertemp150469))) +::STMT +FLOAT:i +LITERAL_FLOAT:80656.0 +*(i,80656.0) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:C +/(sum(*(r,r)),%*%(t(d),+(d,*(C,parsertemp43998)))) +::STMT +MATRIX:X,parsertemp220785 +FLOAT:int457,int358 +LITERAL_FLOAT:-2.0 ++(+(*(-2.0,%*%(X,parsertemp220785)),rowSums(^(X,int457))),t(rowSums(^(X,int358)))) +::STMT +MATRIX:D,parsertemp10961,parsertemp10958 ++(%*%(D,t(parsertemp10958)),t(parsertemp10961)) +::STMT +LITERAL_FLOAT:1.0,10.0 +/(1.0,10.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +INT:int623,int718 +%*%(+(rowSums(classFeatureCounts),*(105.0,1.0)),rand(int623,int718,1.0,1.0)) +::STMT +MATRIX:237_present_domain_vals_mat +LITERAL_FLOAT:10000.0 +-(10000.0,nrow(237_present_domain_vals_mat)) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +/(F,-(sum(F),1.0)) +::STMT +MATRIX:Q +FLOAT:int677 +LITERAL_FLOAT:1.0 +INT:int897,parsertemp500306 +%*%(rand(parsertemp500306,int897,1.0,1.0),t(rowSums(^(Q,int677)))) +::STMT +LITERAL_FLOAT:1.0,2001.0 ++(2001.0,1.0) +::STMT +MATRIX:r,g,z +LITERAL_FLOAT:0.5 +*(0.5,sum(*(z,+(r,g)))) +::STMT +LITERAL_FLOAT:1.0E-14 +1.0E-14 +::STMT +LITERAL_FLOAT:9.999999999999998E-15 +9.999999999999998E-15 +::STMT +MATRIX:pearson_residual_sq +LITERAL_FLOAT:9950.0 +/(sum(pearson_residual_sq),9950.0) +::STMT +FLOAT:parsertemp72162,M +LITERAL_FLOAT:1.0 +*(+(parsertemp72162,1.0),M) +::STMT +MATRIX:g_Y,lambda,parsertemp171599,scale_X,beta +FLOAT:int223 ++(*(cast.FLOAT(diag(scale_X)),%*%(-(int223,parsertemp171599),g_Y)),*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0 +^(diag(S),2.0) +::STMT +MATRIX:R,ones +%*%(t(+(R,diag(ones))),+(R,diag(ones))) +::STMT +MATRIX:scale_X,shift_X,X +LITERAL_FLOAT:2.0 +%*%(X,*(*(2.0,scale_X),shift_X)) +::STMT +MATRIX:P +LITERAL_FLOAT:1.0 +<=(rowSums(P),1.0) +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0,2.0 +*(1.0,^(/(cast.FLOAT(ytest),1.0),2.0)) +::STMT +LITERAL_FLOAT:5.0,2001.0 ++(2001.0,5.0) +::STMT +MATRIX:out1,187_dX,parsertemp146955 +FLOAT:beta1,int533 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(*(>(out1,int533),*(parsertemp146955,187_dX)))) +::STMT +LITERAL_FLOAT:3.0,2001.0 ++(2001.0,3.0) +::STMT +MATRIX:d,od,X,logisticD +FLOAT:C ++(d,*(C,%*%(t(X),*(logisticD,od)))) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +/(ncol(M),2.0) +::STMT +MATRIX:X_batch,maskd1,out1,185_dX,parsertemp146947,W2 +FLOAT:p,int850 +%*%(t(X_batch),*(*(>(out1,int850),/(maskd1,p)),%*%(*(parsertemp146947,185_dX),t(W2)))) +::STMT +MATRIX:M +sum(exp(-(M,max(M)))) +::STMT +FLOAT:int134,z,pp_CG,parsertemp170091 +LITERAL_FLOAT:0.5 +*(0.5,/(-(*(z,int134),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:Grad +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(Grad,-1.0),2.0)) +::STMT +MATRIX:sums +LITERAL_FLOAT:4.0 +/(sums,4.0) +::STMT +MATRIX:parsertemp221417 +FLOAT:float22 +LITERAL_FLOAT:0.1,2.0 +*(sum(^(-(float22,parsertemp221417),2.0)),0.1) +::STMT +MATRIX:t,parsertemp32854,parsertemp32848,Y,parsertemp32857,parsertemp32858 +cast.FLOAT(+(+(*(parsertemp32848,Y),*(t,Y)),*(*(t,parsertemp32854),+(parsertemp32857,parsertemp32858)))) +::STMT +MATRIX:lambda,parsertemp286549 +FLOAT:new_log_l +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,cast.FLOAT(%*%(lambda,parsertemp286549)))) +::STMT +MATRIX:parsertemp220786,parsertemp220783 +FLOAT:int927 +sqrt(+(+(*(int927,parsertemp220786),rowSums(parsertemp220783)),t(rowSums(parsertemp220783)))) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +t(-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:s,parsertemp44016,d +LITERAL_FLOAT:2.0 +^(%*%(t(-(s,parsertemp44016)),d),2.0) +::STMT +MATRIX:output_values +FLOAT:log_odds,learning_rate ++(log_odds,*(learning_rate,sum(output_values))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +/(X,2.0) +::STMT +MATRIX:parsertemp561025 +LITERAL_FLOAT:0.0 +/(parsertemp561025,0.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0,0.5 +*(-(1.0,*(2.0,y_corr)),>(y_corr,0.5)) +::STMT +MATRIX:prec_chol +LITERAL_FLOAT:2.0 +t(^(prec_chol,2.0)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170113,q_CG,int940,z,pq_CG,pp_CG,parsertemp170091 +*(+(+(*(parsertemp170113,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(-(*(z,int940),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:0.010328,-2.0 +*(sqrt(*(-2.0,parsertemp171083)),0.010328) +::STMT +FLOAT:parsertemp22454,parsertemp22485 +LITERAL_FLOAT:2.0 +exp(-(parsertemp22485,*(2.0,sqrt(parsertemp22454)))) +::STMT +MATRIX:sig_sq +sqrt(sig_sq) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0 +sqrt(*(-2.0,parsertemp171083)) +::STMT +MATRIX:parsertemp31910,X +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(/(nrow(X),t(parsertemp31910)),1.0)) +::STMT +MATRIX:252_Y,252_X,252_K +LITERAL_FLOAT:0.0 ++(*(-(0.0,cast.FLOAT(252_K)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2000.0 +*(*(-(2000.0,2.0),+(2000.0,1.0)),+(2000.0,3.0)) +::STMT +MATRIX:R,S,Grad +-(sum(*(S,Grad)),sum(*(S,R))) +::STMT +MATRIX:p,e,u,G +LITERAL_FLOAT:0.15000000000000002,0.85 ++(*(0.85,%*%(G,p)),*(0.15000000000000002,%*%(%*%(e,u),p))) +::STMT +MATRIX:f,parsertemp472172 +LITERAL_FLOAT:0.0 +rowSums(*(-(0.0,f),parsertemp472172)) +::STMT +FLOAT:int780,ss2,ssPrev,Xm,parsertemp265718 +LITERAL_FLOAT:4000.0 +/(/(-(+(Xm,ss2),*(int780,parsertemp265718)),4000.0),ssPrev) +::STMT +MATRIX:parsertemp107030 +LITERAL_FLOAT:1.0,7.0 ++(*(parsertemp107030,7.0),1.0) +::STMT +MATRIX:X,K +*(cast.FLOAT(K),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:Xm +rowSums(t(Xm)) +::STMT +MATRIX:parsertemp436659 +t(rowSums(parsertemp436659)) +::STMT +LITERAL_FLOAT:1.0E-5 +1.0E-5 +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:int550 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),%*%(t(d),+(d,*(int550,parsertemp43998)))) +::STMT +LITERAL_FLOAT:32.0 +INT:int197,int136 +rand(int197,int136,32.0,32.0) +::STMT +MATRIX:X,tS +FLOAT:l +t(colSums(==(%*%(X,tS),l))) +::STMT +MATRIX:Y_prob,Y +LITERAL_FLOAT:0.0 +sum(*(<=(Y_prob,0.0),abs(Y))) +::STMT +MATRIX:jaccardDist,adjacency +FLOAT:threshold +&(adjacency,>=(jaccardDist,threshold)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,4.0 +^(sqrt(*(1.0005002501250626,m2)),4.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),750.0)) +::STMT +MATRIX:_sbcvar11,43_r,43_c +LITERAL_FLOAT:2.0,1000.0 +^(-(_sbcvar11,/(%*%(43_r,43_c),1000.0)),2.0) +::STMT +MATRIX:G,authorities,hubs +LITERAL_FLOAT:2.0 +^(-(/(%*%(G,authorities),max(hubs)),hubs),2.0) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,3.0 +^(sqrt(*(1.0005002501250626,m2)),3.0) +::STMT +MATRIX:surv,n_risk +FLOAT:int594 +/(*(surv,sqrt(-(int594,surv))),sqrt(n_risk)) +::STMT +FLOAT:so_linear_approx +LITERAL_FLOAT:-0.5 +*(-0.5,so_linear_approx) +::STMT +FLOAT:delta +LITERAL_FLOAT:0.5 +*(0.5,delta) +::STMT +MATRIX:se_surv +FLOAT:z_alpha_2 +LITERAL_FLOAT:-1.0 +*(*(z_alpha_2,-1.0),se_surv) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005 +sqrt(*(1.0004995004995005,m2)) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:80.0 +/(classCounts,80.0) +::STMT +MATRIX:parsertemp379565,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0,2.0 +*(2.0,/(*(-(parsertemp379565,m_iter_err_sum),-1.0),i_process_item)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(X,2.0)),-(nrow(X),1.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int681,int584,int92,int34 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int681),/(negSampleVariances,int34)),2.0),+(/(^(posSampleVariances,int584),7.996E9),/(^(negSampleVariances,int92),3.37275E9))) +::STMT +LITERAL_FLOAT:1.0,20.0 +-(20.0,1.0) +::STMT +MATRIX:scores,parsertemp145878 +/(exp(-(scores,parsertemp145878)),rowSums(exp(-(scores,parsertemp145878)))) +::STMT +MATRIX:t_gp,parsertemp171332,pt_gp,parsertemp171331,Y,the_gauss_exp,parsertemp171327,parsertemp171316 +FLOAT:one_over_sqrt_two_pi +LITERAL_FLOAT:2.0,0.25 +/(*(one_over_sqrt_two_pi,+(-(Y,parsertemp171327),*(parsertemp171331,parsertemp171332))),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:ss +FLOAT:alpha +LITERAL_FLOAT:1.0,40.0 +*(-(1.0,alpha),-(/(40.0,ss),1.0)) +::STMT +LITERAL_FLOAT:0.3989422804014327 +0.3989422804014327 +::STMT +LITERAL_FLOAT:0.1 +0.1 +::STMT +LITERAL_FLOAT:-0.1 +-0.1 +::STMT +MATRIX:X +FLOAT:var_lag,parsertemp496688,parsertemp496694,var_coef,a0 +LITERAL_FLOAT:2.0 ++(parsertemp496694,/(^(cast.FLOAT(X),2.0),+(+(a0,parsertemp496688),*(var_coef,var_lag)))) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0 +/(parsertemp222331,200.0) +::STMT +MATRIX:parsertemp220903 +FLOAT:float857 +LITERAL_FLOAT:2.0,1.0E-5 +*(sum(^(-(float857,parsertemp220903),2.0)),1.0E-5) +::STMT +MATRIX:parsertemp399255,W4_rand +FLOAT:int818,int687 +LITERAL_FLOAT:0.08725945907447251 +%*%(*(0.08725945907447251,W4_rand),t(/(-(parsertemp399255,int687),+(parsertemp399255,int818)))) +::STMT +MATRIX:tmp +FLOAT:N,parsertemp274090 +LITERAL_FLOAT:0.0,1.0 +*(/(tmp,-(N,1.0)),-(1.0,<=(/(tmp,parsertemp274090),0.0))) +::STMT +MATRIX:W,H,parsertemp411105 +LITERAL_FLOAT:1.0E-8 ++(%*%(W,%*%(*(H,parsertemp411105),t(H))),1.0E-8) +::STMT +MATRIX:log_prob,X +LITERAL_FLOAT:1.8378770664093453,-0.5 +*(-0.5,+(*(ncol(X),1.8378770664093453),log_prob)) +::STMT +LITERAL_FLOAT:1.5000000000000002E-8 +1.5000000000000002E-8 +::STMT +MATRIX:parsertemp539203 +FLOAT:int993 +LITERAL_FLOAT:1.0,2.0,1.5 +max(^(/(*(parsertemp539203,int993),2.0),/(1.0,1.5))) +::STMT +FLOAT:width,x1,x2 +LITERAL_FLOAT:-1.0,2.0 +/(*(-1.0,^(-(x1,x2),2.0)),*(2.0,^(width,2.0))) +::STMT +MATRIX:images +LITERAL_FLOAT:255.0 +/(images,255.0) +::STMT +MATRIX:W,parsertemp411110,X,H +FLOAT:eps +*(W,/(%*%(X,t(H)),+(%*%(W,parsertemp411110),eps))) +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0 +/(cast.FLOAT(ytest),1.0) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2003.0 +*(4.0,-(^(2003.0,2.0),1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,1.0 +-(exp(-(0.0,linear_terms)),1.0) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0002795638803466 +*(m2X,1.0002795638803466) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:100.0 +/(classCounts,100.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:-1.0 +sum(*(*(%*%(parsertemp1904,y),-1.0),*(%*%(parsertemp1904,y),-1.0))) +::STMT +MATRIX:means,Y,vars +LITERAL_FLOAT:2.0 +sum(/(^(-(Y,means),2.0),vars)) +::STMT +MATRIX:parsertemp409788,parsertemp409797 +LITERAL_FLOAT:0.0 +t(+(-(0.0,t(parsertemp409788)),t(colSums(parsertemp409797)))) +::STMT +MATRIX:parsertemp386438,neighbors +FLOAT:eps +LITERAL_FLOAT:0.0 +rowSums(*(<=(-(neighbors,parsertemp386438),eps),<(0.0,-(neighbors,parsertemp386438)))) +::STMT +MATRIX:obj,parsertemp44077 +FLOAT:int642,parsertemp44079 +LITERAL_FLOAT:2.0,0.5 +-(cast.FLOAT(obj),+(*(0.5,cast.FLOAT(parsertemp44077)),*(2.0,*(int642,parsertemp44079)))) +::STMT +MATRIX:weight +LITERAL_FLOAT:133.0 +/(weight,133.0) +::STMT +MATRIX:F +/(%*%(rowSums(F),colSums(F)),sum(F)) +::STMT +LITERAL_FLOAT:0.025 +0.025 +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.001001001001001 +sqrt(*(42_m2X,1.001001001001001)) +::STMT +MATRIX:Y_Train,Y_Test +FLOAT:sumY,sumX,parsertemp251796,parsertemp251795 +abs(-(-(+(sumX,sumY),+(parsertemp251795,parsertemp251796)),+(sum(Y_Train),sum(Y_Test)))) +::STMT +MATRIX:V +FLOAT:var,mu +LITERAL_FLOAT:5.0 +>(V,+(mu,*(5.0,sqrt(var)))) +::STMT +MATRIX:V +FLOAT:var,mu +LITERAL_FLOAT:5.0 +<(V,-(mu,*(5.0,sqrt(var)))) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z,pp_CG +sqrt(-(*(cast.FLOAT(p_CG),cast.FLOAT(p_CG)),*(pp_CG,-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp171326,is_lt_pos,parsertemp171323,Y +FLOAT:float940 +LITERAL_FLOAT:0.3989422804014327 +*(0.3989422804014327,+(-(Y,*(parsertemp171326,is_lt_pos)),*(*(parsertemp171323,parsertemp171326),-(is_lt_pos,float940)))) +::STMT +FLOAT:vicinity,target_a0,a0 +LITERAL_FLOAT:1.0 ++(*(vicinity,target_a0),*(-(1.0,vicinity),a0)) +::STMT +MATRIX:_sbcvar92,220_r,220_c +LITERAL_FLOAT:0.0,1.0E-4 +*(==(/(%*%(220_r,220_c),sum(_sbcvar92)),0.0),1.0E-4) +::STMT +MATRIX:p,q,lambda +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(p,q))),+(q,*(lambda,p))) +::STMT +MATRIX:r +FLOAT:int383 +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(-(int383,r),2.0)),9.999999999999998E-15) +::STMT +LITERAL_FLOAT:1.0,2.0,1500.0 +*(^(1500.0,2.0),-(1500.0,1.0)) +::STMT +MATRIX:B,parsertemp410245,X_t +LITERAL_FLOAT:0.0,2.0 +/(-(0.0,parsertemp410245),*(2.0,exp(%*%(X_t,B)))) +::STMT +MATRIX:r,Hd +FLOAT:c +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(c,Hd))) +::STMT +MATRIX:Y +FLOAT:class +LITERAL_FLOAT:2.0 +*(2.0,==(Y,class)) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:0.0 +>(rowSums(|(<(length,qLow),>(length,qUp))),0.0) +::STMT +MATRIX:var_X_cols,parsertemp429917,parsertemp429915 +FLOAT:int636 +LITERAL_FLOAT:0.0,1.0,299.0 ++(*(/(-(parsertemp429915,parsertemp429917),299.0),-(1.0,<=(var_X_cols,int636))),<=(/(-(parsertemp429915,parsertemp429917),299.0),0.0)) +::STMT +MATRIX:parsertemp43635 +FLOAT:float100 +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(float100,parsertemp43635),2.0))) +::STMT +FLOAT:window_size,n +LITERAL_FLOAT:2.0 ++(-(n,window_size),2.0) +::STMT +MATRIX:R,w +FLOAT:int794,int742 +INT:parsertemp31673,int163 ++(R,diag(*(rand(parsertemp31673,int163,int742,int794),cast.FLOAT(w)))) +::STMT +MATRIX:cumLeftHist,parsertemp132494,parsertemp132506,leftHist,outBucket ++(%*%(==(outBucket,t(parsertemp132494)),-(cumLeftHist,leftHist)),parsertemp132506) +::STMT +MATRIX:parsertemp72182 +LITERAL_FLOAT:1.0,8.0 ++(*(parsertemp72182,8.0),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1048.0 +-(1048.0,idx) +::STMT +MATRIX:parsertemp13626,parsertemp13624 +FLOAT:int992,43_q,int581 +LITERAL_FLOAT:1.0,1000.0 +/(sum(/(^(parsertemp13626,int581),/(parsertemp13624,int992))),*(1000.0,-(43_q,1.0))) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 +<(-(subspace_idx,*(parsertemp72201,subvector_size)),1.0) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:-1.0 +exp(*(*(y,-1.0),+(o,os))) +::STMT +MATRIX:atan_linear_terms +LITERAL_FLOAT:3.141592653589793,0.5 +-(0.5,/(atan_linear_terms,3.141592653589793)) +::STMT +MATRIX:linear_terms,Y +FLOAT:var_power +LITERAL_FLOAT:-1.0 +*(^(linear_terms,*(var_power,-1.0)),-(Y,linear_terms)) +::STMT +MATRIX:w,X,y +%*%(t(-(%*%(X,w),y)),-(%*%(X,w),y)) +::STMT +MATRIX:H,betamax,Hneg,Hpos,beta +FLOAT:INF,logU +LITERAL_FLOAT:0.0,2.0 +*(*(2.0,>=(-(H,logU),0.0)),==(+(*(Hpos,betamax),*(Hneg,beta)),INF)) +::STMT +LITERAL_FLOAT:1.0E-4 +1.0E-4 +::STMT +MATRIX:X,parsertemp16876 +FLOAT:epsilon,int288 ++(sqrt(rowSums(^(X,int288))),*(<(sqrt(parsertemp16876),epsilon),epsilon)) +::STMT +LITERAL_FLOAT:1400.0,20.0 +*(1400.0,20.0) +::STMT +MATRIX:lt_pos_neg +LITERAL_FLOAT:0.5 +-(0.5,lt_pos_neg) +::STMT +MATRIX:parsertemp389219,tmp,X,parsertemp389212 +FLOAT:int464 +LITERAL_FLOAT:1.0E-17 +/(-(%*%(tmp,X),parsertemp389212),+(sqrt(/(parsertemp389219,int464)),1.0E-17)) +::STMT +MATRIX:Y,linear_terms,vec1,is_y_0,parsertemp171270 +LITERAL_FLOAT:0.0 +-(-(/(+(Y,is_y_0),+(parsertemp171270,is_y_0)),==(Y,0.0)),*(*(Y,vec1),linear_terms)) +::STMT +MATRIX:Bx,Yd,Yu +/(-(Yu,Yd),*(Bx,Bx)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),2.0),+(sum(W),1.0)),+(sum(round(W)),3.0)) +::STMT +MATRIX:cm,FD +FLOAT:n +LITERAL_FLOAT:1.0 ++(+(FD,==(cm,1.0)),==(t(cm),n)) +::STMT +MATRIX:r,alpha,Hd +*(-(r,*(cast.FLOAT(alpha),Hd)),-(r,*(cast.FLOAT(alpha),Hd))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +exp(*(2.0,X)) +::STMT +MATRIX:g,parsertemp169907 +FLOAT:parsertemp169913 +*(sum(*(+(g,parsertemp169907),+(g,parsertemp169907))),parsertemp169913) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +LITERAL_FLOAT:1.0,-0.36651292058166435 ++(-(parsertemp171113,*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +FLOAT:int112 +LITERAL_FLOAT:2.0 +INT:int809,parsertemp282730 +rand(parsertemp282730,int809,int112,2.0) +::STMT +MATRIX:vI +FLOAT:beg +LITERAL_FLOAT:1.0 +-(+(cast.FLOAT(vI),beg),1.0) +::STMT +MATRIX:parsertemp557211 +LITERAL_FLOAT:0.0 +==(diag(parsertemp557211),0.0) +::STMT +FLOAT:var,m4 +LITERAL_FLOAT:3.0,4.0 +-(/(m4,^(sqrt(var),4.0)),3.0) +::STMT +MATRIX:lambda,B_new +FLOAT:int37 +LITERAL_FLOAT:0.5 +*(0.5,sum(*(lambda,^(B_new,int37)))) +::STMT +MATRIX:parsertemp413082 +LITERAL_FLOAT:1.0 +-(max(round(parsertemp413082)),1.0) +::STMT +MATRIX:parsertemp410190,b,parsertemp410188 +cast.FLOAT(%*%(%*%(t(b),-(parsertemp410188,parsertemp410190)),b)) +::STMT +MATRIX:_sbcvar96,_sbcvar95,221_CMeans +FLOAT:int455 +LITERAL_FLOAT:2.0 +sum(*(%*%(_sbcvar95,_sbcvar96),^(+(221_CMeans,int455),2.0))) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,parsertemp410050,d_r_rev,Hd_2_num,D_r_rev +colSums(*(-(/(X_Xd_exp_Xb_rev_agg,D_r_rev),/(Hd_2_num,parsertemp410050)),d_r_rev)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +%*%(diag(scale_X),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:lambda,parsertemp148883,parsertemp148882 +FLOAT:int12 +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp148882,parsertemp148883),*(lambda,int12)),2.0)) +::STMT +MATRIX:img_in +FLOAT:h +LITERAL_FLOAT:2.0 +/(-(nrow(img_in),h),2.0) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +/(-(2.0,var_power),link_power) +::STMT +FLOAT:dummy_coding_beg_col,dummy_coding_end_col +LITERAL_FLOAT:1.0 ++(-(dummy_coding_end_col,dummy_coding_beg_col),1.0) +::STMT +MATRIX:y_batch,parsertemp146892 +FLOAT:int243 +/(sum(*(-(int243,y_batch),parsertemp146892)),nrow(y_batch)) +::STMT +LITERAL_FLOAT:1.421413741 +1.421413741 +::STMT +MATRIX:P,parsertemp220889,Z,parsertemp220891 +FLOAT:int562,int464,int63,parsertemp220894 +rowSums(*(-(*(P,int562),/(Z,parsertemp220894)),*(/(int464,parsertemp220891),+(parsertemp220889,int63)))) +::STMT +MATRIX:316_unnorm_probs,316_scores +abs(-(/(exp(316_scores),rowSums(316_unnorm_probs)),/(exp(316_scores),rowSums(316_unnorm_probs)))) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0,40.0 +-(/(40.0,ss),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1024.0 +-(1024.0,idx) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,3.0 +-(3.0,+(current_hash_value,1.0)) +::STMT +MATRIX:tmp +FLOAT:int239,N +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(/(tmp,-(N,int239)),0.0)) +::STMT +MATRIX:r,d,parsertemp43999 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),cast.FLOAT(%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:parsertemp387501 +LITERAL_FLOAT:1.0 +cast.FLOAT(+(parsertemp387501,1.0)) +::STMT +FLOAT:cvk +LITERAL_FLOAT:300.0 +/(300.0,cvk) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 ++(rowSums(classFeatureCounts),*(105.0,1.0)) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:1000.0 +*(/(1000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +MATRIX:addedE,addedX +/(sum(addedE),nrow(addedX)) +::STMT +MATRIX:X_Train,X_Test,X,Y,Y_Train,Y_Test +-(-(+(sum(X),sum(Y)),+(sum(X_Train),sum(X_Test))),+(sum(Y_Train),sum(Y_Test))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 ++(colSums(resp),2.22E-16) +::STMT +MATRIX:X_train +LITERAL_FLOAT:2.0 +sqrt(/(2.0,ncol(X_train))) +::STMT +FLOAT:b,c,rad +LITERAL_FLOAT:-1.0,2.0 +/(*(*(2.0,c),-1.0),+(b,rad)) +::STMT +LITERAL_FLOAT:0.802853 +0.802853 +::STMT +MATRIX:parsertemp394992,parsertemp394989,W3_rand +LITERAL_FLOAT:0.21483446221182986 +t(%*%(*(0.21483446221182986,W3_rand),t(/(parsertemp394989,parsertemp394992)))) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +-(*(/(2.0,-(check_max,check_min)),Y),/(+(check_min,check_max),-(check_max,check_min))) +::STMT +MATRIX:2434_2432_Y,W4_rand +FLOAT:float108 +LITERAL_FLOAT:2.0 +*(2.0,t(%*%(*(float108,W4_rand),t(2434_2432_Y)))) +::STMT +MATRIX:inactive_set,w +LITERAL_FLOAT:0.0 +abs(-(inactive_set,!=(w,0.0))) +::STMT +MATRIX:p,e,u,G +FLOAT:alpha +LITERAL_FLOAT:1.0 ++(*(alpha,%*%(G,p)),*(-(1.0,alpha),%*%(%*%(e,u),p))) +::STMT +LITERAL_FLOAT:80.0,1200.0 +*(1200.0,80.0) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,2.0,4.0 +-(+(-(n,4.0),1.0),2.0) +::STMT +MATRIX:parsertemp443564,parsertemp443530,parsertemp443567,mean,parsertemp443973,X +FLOAT:float834 ++(/(-(%*%(parsertemp443564,X),%*%(parsertemp443567,mean)),sum(+(parsertemp443530,float834))),diag(parsertemp443973)) +::STMT +MATRIX:2701_mask +LITERAL_FLOAT:0.5 +/(2701_mask,0.5) +::STMT +MATRIX:X,mask +FLOAT:p +/(*(X,mask),p) +::STMT +MATRIX:X,parsertemp382984 +LITERAL_FLOAT:0.0 +-(ncol(X),sum(!=(t(parsertemp382984),0.0))) +::STMT +MATRIX:parsertemp2782,parsertemp2786 +FLOAT:dd,parsertemp2779,step_sz,wd +-(step_sz,/(-(+(wd,parsertemp2779),sum(parsertemp2782)),+(dd,sum(parsertemp2786)))) +::STMT +MATRIX:parsertemp410245,parsertemp410247 +LITERAL_FLOAT:0.0,2.0,0.6666666666666666 +^(/(-(0.0,parsertemp410245),*(2.0,exp(parsertemp410247))),0.6666666666666666) +::STMT +FLOAT:factor_up,parsertemp195892,int529 +LITERAL_FLOAT:1.0,2.0 +/(-(-(*(int529,factor_up),parsertemp195892),1.0),2.0) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2 +LITERAL_FLOAT:1.0E-8 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.8 +*(0.8,upd_W1) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(X,X)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(50.0,1.0))) +::STMT +FLOAT:C,Hf,Wf +LITERAL_FLOAT:2.0 +sqrt(/(2.0,*(*(C,Hf),Wf))) +::STMT +MATRIX:R +FLOAT:int548 +LITERAL_FLOAT:0.0 +sum(==(colSums(!=(R,int548)),0.0)) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,1.0,2.0,1.5 +^(/(*(parsertemp539203,-1.0),2.0),/(1.0,1.5)) +::STMT +MATRIX:X,Y,K +FLOAT:x,int118 +*(+(*(*(K,int118),-(X,X)),-(Y,Y)),/(-(x,X),-(X,X))) +::STMT +MATRIX:cdf_min_distances +LITERAL_FLOAT:0.0,1.0 +INT:int159,num_runs +*(rand(int159,num_runs,0.0,1.0),cdf_min_distances) +::STMT +FLOAT:m2X,m2Y +LITERAL_FLOAT:1.000010000100001 +*(sqrt(*(m2X,1.000010000100001)),sqrt(*(m2Y,1.000010000100001))) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:-1.0 +*(rowSums(Y),exp(*(exp(linear_terms),-1.0))) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.000100010001 +*(cmLabels,1.000100010001) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:0.5 +*(0.5,sum(*(*(sv,out),*(sv,out)))) +::STMT +MATRIX:y +LITERAL_FLOAT:1.0,-1.0 +*(/(1.0,nrow(y)),*(y,-1.0)) +::STMT +MATRIX:current_node +FLOAT:cur_node_index ++(cur_node_index,cast.FLOAT(current_node)) +::STMT +MATRIX:Kss,parsertemp387410 +sqrt(abs(-(cast.FLOAT(Kss),cast.FLOAT(parsertemp387410)))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 +sum(+(colSums(resp),2.22E-16)) +::STMT +MATRIX:xs +LITERAL_FLOAT:100.0,4.5 +-(100.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp410978,W,H +rowSums(/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +MATRIX:z,beta ++(beta,cast.FLOAT(z)) +::STMT +MATRIX:X +FLOAT:int902 +t(sqrt(rowSums(^(X,int902)))) +::STMT +MATRIX:R +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>=(R,minSup),>(R,0.0)) +::STMT +MATRIX:parsertemp12898,CFreqs +FLOAT:int517 +LITERAL_FLOAT:1.0 +/(sum(*(CFreqs,^(parsertemp12898,int517))),-(nrow(CFreqs),1.0)) +::STMT +FLOAT:float605,int893,float444,int152 +LITERAL_FLOAT:1.0,3.0,6.0,2001.0 +/(*(*(6.0,2001.0),-(2001.0,1.0)),*(*(-(int893,float444),+(int152,float605)),+(2001.0,3.0))) +::STMT +MATRIX:parsertemp150380 +LITERAL_FLOAT:0.0,0.16 +sum(==(<(abs(parsertemp150380),0.16),0.0)) +::STMT +MATRIX:237_CVars,parsertemp29525,237_CFreqs,parsertemp29520 +LITERAL_FLOAT:1.0,10000.0 +/(/(sum(*(237_CFreqs,parsertemp29520)),-(nrow(237_CFreqs),1.0)),/(sum(*(parsertemp29525,237_CVars)),-(10000.0,nrow(237_CFreqs)))) +::STMT +LITERAL_FLOAT:96.0 +INT:int523,int607 +rand(int523,int607,96.0,96.0) +::STMT +MATRIX:colDuplicates,adjacency +LITERAL_FLOAT:0.0 +*(colDuplicates,>(rowSums(adjacency),0.0)) +::STMT +MATRIX:cdf_min_distances,random_row +t(colSums(<(cdf_min_distances,*(random_row,cdf_min_distances)))) +::STMT +MATRIX:s,d,alpha +t(+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp472298,I +LITERAL_FLOAT:0.0 +==(!=(*(t(parsertemp472298),I),0.0),0.0) +::STMT +MATRIX:parsertemp171318 +FLOAT:int591 +LITERAL_FLOAT:2.0,0.15915494309189535 +*(exp(/(-(int591,parsertemp171318),2.0)),0.15915494309189535) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0005 +*(m2X,1.0005) +::STMT +MATRIX:H,betamax,beta +FLOAT:logU +LITERAL_FLOAT:0.0 ++(*(>=(-(H,logU),0.0),betamax),*(<(-(H,logU),0.0),beta)) +::STMT +MATRIX:key_unique,key +==(key_unique,t(key)) +::STMT +MATRIX:e_r_rev_agg,parsertemp409787,parsertemp409796 +LITERAL_FLOAT:0.0 ++(-(0.0,t(colSums(parsertemp409787))),t(colSums(/(parsertemp409796,e_r_rev_agg)))) +::STMT +MATRIX:parsertemp132498,offset,parsertemp132494,rightHist,mask,outBucket +LITERAL_FLOAT:1.0 +/(-(-(offset,%*%(mask,parsertemp132498)),1.0),%*%(==(outBucket,t(parsertemp132494)),rightHist)) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +LITERAL_FLOAT:2.0 +/(sum(^(-(r,parsertemp44050),2.0)),norm_r2) +::STMT +MATRIX:y_prob,ones_ctg +LITERAL_FLOAT:1.0 +*(y_prob,%*%(y_prob,-(1.0,diag(ones_ctg)))) +::STMT +MATRIX:tmp +LITERAL_FLOAT:1.0 +*(1.0,cast.FLOAT(%*%(t(tmp),tmp))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +-(exp(*(2.0,X)),1.0) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,X),-(X,X))) +::STMT +MATRIX:p,A,r,parsertemp51660 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp51660)),%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:dout1,mb1 +FLOAT:192_beta1 +LITERAL_FLOAT:1.0 ++(*(192_beta1,mb1),*(-(1.0,192_beta1),colSums(dout1))) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:9999.0,10000.0 +/(*(parsertemp31330,10000.0),9999.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,sum(X)),2.0) +::STMT +MATRIX:w +sum(abs(w)) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:1.0 +/(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),1.0) +::STMT +LITERAL_FLOAT:1.00001 +1.00001 +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:1.0,20.0 +-(/(/(se,ss),/(sum(e),20.0)),1.0) +::STMT +FLOAT:parsertemp65,parsertemp66,mu +LITERAL_FLOAT:5.0 ++(mu,*(5.0,sqrt(/(parsertemp65,parsertemp66)))) +::STMT +MATRIX:BLOCKS +FLOAT:current_hash_value +LITERAL_FLOAT:1.0 +-(nrow(BLOCKS),+(current_hash_value,1.0)) +::STMT +MATRIX:parsertemp170158,parsertemp170136 +FLOAT:r_CG,g_reg,parsertemp170165,278_sq_root_d,z,parsertemp170150 +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170165,z),sum(parsertemp170158)),/(+(parsertemp170150,278_sq_root_d),sum(parsertemp170136)))) +::STMT +MATRIX:W,parsertemp411198,X,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 +%*%(/(X,+(%*%(W,H),1.0E-8)),t(/(*(H,parsertemp411198),t(parsertemp411200)))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,960.0 +-(1.0,/(960.0,num_records)) +::STMT +MATRIX:H2_prime,H3_prime,W2,W3,parsertemp389610 +%*%(*(H2_prime,%*%(*(H3_prime,parsertemp389610),W3)),W2) +::STMT +MATRIX:R,dssp +FLOAT:4_n +LITERAL_FLOAT:1.0 +-(/(4_n,+(R,dssp)),1.0) +::STMT +FLOAT:neg_log_l_change_predicted,log_l_change +LITERAL_FLOAT:-1.0 +/(*(log_l_change,-1.0),neg_log_l_change_predicted) +::STMT +MATRIX:tmp_c +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(tmp_c,*(-(i,1.0),12.0)) +::STMT +LITERAL_FLOAT:300.0,1.0 ++(300.0,1.0) +::STMT +MATRIX:s,sts,d,parsertemp44023 +FLOAT:delta2 +LITERAL_FLOAT:2.0 ++(^(%*%(t(s),d),2.0),*(cast.FLOAT(%*%(parsertemp44023,d)),-(delta2,cast.FLOAT(sts)))) +::STMT +MATRIX:U,V,X,parsertemp382841,row_nonzeros +FLOAT:reg,int524 ++(%*%(*(!=(X,int524),-(parsertemp382841,X)),V),*(*(reg,U),row_nonzeros)) +::STMT +MATRIX:C,Xm,parsertemp265702 +%*%(colSums(%*%(%*%(Xm,parsertemp265702),t(C))),rowSums(t(Xm))) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +-(cast.MATRIX(max(Y)),parsertemp185166) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),-1.0) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:2.0 +*(*(^(n_risk_stratum,2.0),*(n_risk,n_event_stratum)),-(n_risk_stratum,n_event_stratum)) +::STMT +MATRIX:A,scale_lambda,scale_X,shift_X,parsertemp115882 +LITERAL_FLOAT:0.001 ++(+(%*%(diag(scale_X),t(parsertemp115882)),%*%(shift_X,A)),diag(*(scale_lambda,0.001))) +::STMT +MATRIX:parsertemp286680,lambda,scale_X,gXY,beta +cast.FLOAT(%*%(t(+(scale_X,parsertemp286680)),+(*(scale_X,gXY),*(lambda,beta)))) +::STMT +MATRIX:parsertemp443534,resp,parsertemp443566,parsertemp443533,X,weight +LITERAL_FLOAT:2.22E-16 +/(-(%*%(t(X),X),%*%(*(parsertemp443566,weight),/(parsertemp443533,parsertemp443534))),sum(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:_funvar2124,parsertemp437267,parsertemp437277,parsertemp437272 +-(+(_funvar2124,parsertemp437267),+(parsertemp437272,parsertemp437277)) +::STMT +LITERAL_FLOAT:0.19999999999999996 +0.19999999999999996 +::STMT +MATRIX:m_err +/(colSums(m_err),sum(colSums(m_err))) +::STMT +FLOAT:check_max,check_min +/(+(check_min,check_max),-(check_max,check_min)) +::STMT +MATRIX:p_CG,z +*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))) +::STMT +LITERAL_FLOAT:100.0,0.8 +*(100.0,0.8) +::STMT +MATRIX:s,w +cast.FLOAT(%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:int443,int775,weight,prob_true,prob_false +LITERAL_FLOAT:1.0 +*(weight,-(1.0,+(^(prob_true,int443),^(prob_false,int775)))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(rowSums(*(X,X)),t(^(prec_chol,2.0))) +::STMT +MATRIX:tmp,leftIdx +%*%(tmp,%*%(t(tmp),leftIdx)) +::STMT +LITERAL_FLOAT:0.2 +0.2 +::STMT +MATRIX:w,X,y +LITERAL_FLOAT:-1.0 +*(*(y,-1.0),%*%(X,w)) +::STMT +MATRIX:parsertemp220844,ZERODIAG,beta +rowSums(*(exp(*(parsertemp220844,beta)),ZERODIAG)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +*(scale_X,%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:newbeta,lambda +FLOAT:int214 +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(lambda),^(newbeta,int214)))) +::STMT +MATRIX:79_77_X_row_norm,Y_block,parsertemp17170,79_77_Y_row_norm,X_block +LITERAL_FLOAT:0.9 +>(/(%*%(X_block,t(Y_block)),%*%(+(79_77_X_row_norm,parsertemp17170),t(79_77_Y_row_norm))),0.9) +::STMT +LITERAL_FLOAT:0.0 +INT:int576,int409 +cast.FLOAT(rand(int409,int576,0.0,0.0)) +::STMT +MATRIX:var_X_cols,parsertemp1517,parsertemp1515 +FLOAT:int932,int191,int490,n +LITERAL_FLOAT:0.0,1.0 ++(*(/(-(parsertemp1515,parsertemp1517),-(n,int932)),-(1.0,<=(var_X_cols,int191))),<=(/(-(parsertemp1515,parsertemp1517),-(n,int490)),0.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +*(^(exp(linear_terms),0.0),exp(linear_terms)) +::STMT +MATRIX:parsertemp42200,parsertemp42201,_sbcvar330 +FLOAT:meanX +LITERAL_FLOAT:1.0,0.5 +*(/(_sbcvar330,-(sum(_sbcvar330),1.0)),-(+(-(parsertemp42200,parsertemp42201),0.5),meanX)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(exp(*(linear_terms,-1.0)),-1.0) +::STMT +MATRIX:parsertemp552345,tab,catTotal +LITERAL_FLOAT:-1.0 +*(*(/(tab,catTotal),-1.0),parsertemp552345) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +-(>=(+(m_active_flag,m_active_flag_tmp),1.0),1.0) +::STMT +FLOAT:n_false,n_true,n_vars +/(+(n_true,n_false),n_vars) +::STMT +MATRIX:G,minDist +LITERAL_FLOAT:0.0 +*(!=(G,0.0),minDist) +::STMT +LITERAL_FLOAT:0.05 +0.05 +::STMT +LITERAL_FLOAT:-0.05 +-0.05 +::STMT +MATRIX:id +diag(==(id,cast.FLOAT(id))) +::STMT +MATRIX:grad +LITERAL_FLOAT:2.0 +sqrt(sum(^(grad,2.0))) +::STMT +MATRIX:select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev +colSums(*(/(%*%(select,X_exp_Xb_rev_agg),D_r_rev),d_r_rev)) +::STMT +MATRIX:parsertemp43993,d,X,Hd,parsertemp44001 +*(cast.FLOAT(/(sum(parsertemp43993),%*%(parsertemp44001,Hd))),%*%(X,d)) +::STMT +MATRIX:parsertemp10964,C +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp10964,C)),100.0),100.0) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int113 +LITERAL_FLOAT:2.0,99.0 +^(/(-(colSums(parsertemp31022),*(int113,parsertemp31024)),99.0),2.0) +::STMT +MATRIX:r,parsertemp44050 +LITERAL_FLOAT:2.0 +sqrt(sum(^(-(r,parsertemp44050),2.0))) +::STMT +MATRIX:Y_counts,Y +%*%(Y_counts,/(colSums(Y),sum(Y_counts))) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(mu,^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.4 +0.4 +::STMT +MATRIX:classFeatureCounts +rowSums(classFeatureCounts) +::STMT +MATRIX:parsertemp116065,p,r,lambda,shift_X,parsertemp116069 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp116069)),+(+(parsertemp116065,shift_X),*(lambda,p)))) +::STMT +MATRIX:TKC +cast.FLOAT(/(TKC,TKC)) +::STMT +MATRIX:p_LS,X +%*%(%*%(t(X),X),p_LS) +::STMT +FLOAT:m2,wt,float608 +LITERAL_FLOAT:3.0 +^(sqrt(/(*(m2,wt),-(wt,float608))),3.0) +::STMT +MATRIX:p_LS,tmp +FLOAT:norm_r2_LS +/(norm_r2_LS,cast.FLOAT(%*%(t(p_LS),tmp))) +::STMT +LITERAL_FLOAT:0.6546536707079771 +0.6546536707079771 +::STMT +FLOAT:parsertemp149336,obj,parsertemp149333,float101,qk,parsertemp149340 +/(-(obj,+(+(parsertemp149333,parsertemp149336),*(float101,parsertemp149340))),qk) +::STMT +MATRIX:d_r_rev,X_exp_Xb_rev_agg,D_r_rev +t(colSums(*(/(X_exp_Xb_rev_agg,D_r_rev),d_r_rev))) +::STMT +FLOAT:log_l,new_log_l ++(abs(log_l),abs(new_log_l)) +::STMT +MATRIX:d,parsertemp410053 +cast.FLOAT(%*%(t(d),t(colSums(parsertemp410053)))) +::STMT +MATRIX:Y_counts,means,parsertemp560511 +sum(*(Y_counts,rowSums(*(means,parsertemp560511)))) +::STMT +MATRIX:Y,Xd,Xw +FLOAT:step_sz +*(Y,+(Xw,*(step_sz,Xd))) +::STMT +MATRIX:2697_b,parsertemp459149,2697_W,outd3 +-(+(%*%(outd3,2697_W),2697_b),parsertemp459149) +::STMT +MATRIX:B,X_t +LITERAL_FLOAT:2.0 +*(2.0,exp(%*%(X_t,B))) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:0.0 +exp(*(-(0.0,D),beta)) +::STMT +MATRIX:r,s,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(%*%(t(s),grad),%*%(t(s),r))) +::STMT +MATRIX:p,lambda,parsertemp1590,parsertemp1589 +sum(*(p,+(%*%(parsertemp1589,parsertemp1590),*(lambda,p)))) +::STMT +LITERAL_FLOAT:0.050000000000000044 +0.050000000000000044 +::STMT +FLOAT:m2,float774,wt +LITERAL_FLOAT:4.0 +^(sqrt(/(*(m2,wt),-(wt,float774))),4.0) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +FLOAT:C +%*%(t(d),+(d,*(C,%*%(parsertemp43996,parsertemp43997)))) +::STMT +LITERAL_FLOAT:750.0 +*(750.0,750.0) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +FLOAT:int929 +*(-(/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev),/(*(X_exp_Xb_rev_agg,Xd_exp_Xb_rev_agg),^(D_r_rev,int929))),d_r_rev) +::STMT +MATRIX:dout1 +LITERAL_FLOAT:2.0 +^(colSums(dout1),2.0) +::STMT +MATRIX:X +FLOAT:int174 +max(sqrt(rowSums(^(X,int174)))) +::STMT +MATRIX:p_LS,parsertemp170552 +FLOAT:lambda_LS +*(cast.FLOAT(p_LS),+(*(cast.FLOAT(parsertemp170552),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS)))) +::STMT +LITERAL_FLOAT:2.0,0.5,-0.5 +INT:int818,int737 +sum(^(rand(int737,int818,-0.5,0.5),2.0)) +::STMT +FLOAT:nFeats +LITERAL_FLOAT:3.141592653589793,2.0 +^(*(2.0,3.141592653589793),nFeats) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2703_X,2702_X +FLOAT:float56,int493 +%*%(t(2703_X),*(*(>(2702_X,int493),/(2701_mask,float56)),%*%(-(2699_dtemp,parsertemp459178),t(2700_W)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,64.0 +-(+(i,64.0),1.0) +::STMT +LITERAL_FLOAT:0.8 +0.8 +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:6.144102863722254 +/(6.144102863722254,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:parsertemp44025,s,d +FLOAT:delta2 ++(*(%*%(t(s),d),%*%(t(s),d)),*(%*%(t(d),d),-(delta2,%*%(parsertemp44025,s)))) +::STMT +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 ++(*(sample_block_size,num_samples),1.0) +::STMT +MATRIX:b4,W4,parsertemp389337 +LITERAL_FLOAT:2.0 +*(2.0,t(+(%*%(W4,parsertemp389337),b4))) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:0.0 +*(scale_X,-(0.0,%*%(t(X),g_Y))) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept +LITERAL_FLOAT:3.0 +*(3.0,+(%*%(features,beta_unscaled),intercept)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,4.0 +&(>(R,0.0),>=(R,4.0)) +::STMT +MATRIX:tmp,X,parsertemp389212 +-(%*%(tmp,X),parsertemp389212) +::STMT +LITERAL_FLOAT:0.16 +0.16 +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0 +colSums(rowSums(^(-(vectors,pq_result),2.0))) +::STMT +MATRIX:parsertemp285516 +FLOAT:pp,parsertemp285518,parsertemp285520 +LITERAL_FLOAT:-1.0 +/(-(*(sum(parsertemp285516),-1.0),sqrt(-(parsertemp285518,parsertemp285520))),pp) +::STMT +MATRIX:221_present_domain_vals_mat,parsertemp27770 +t(sqrt(%*%(221_present_domain_vals_mat,parsertemp27770))) +::STMT +MATRIX:WM,Y +/(sum(*(Y,WM)),sum(WM)) +::STMT +MATRIX:X_nonzero_ind +LITERAL_FLOAT:0.0,6.0 +-(6.0,sum(!=(rowSums(X_nonzero_ind),0.0))) +::STMT +MATRIX:m_active_flag_tmp +LITERAL_FLOAT:1.0 +sum(!=(m_active_flag_tmp,1.0)) +::STMT +MATRIX:d,parsertemp410052,d_r_rev +%*%(t(d),t(colSums(*(parsertemp410052,d_r_rev)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +/(norm_r2,sum(*(p,+(q,q)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,t(colSums(X))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0,1.0 +*(<=(Xtest_dists,1.0),<(0.0,Xtest_dists)) +::STMT +MATRIX:parsertemp393595,tmp,X,parsertemp393475,parsertemp393466 +LITERAL_FLOAT:1.0,1.0E-17 +-(/(-(exp(parsertemp393595),1.0),+(exp(parsertemp393595),1.0)),/(-(%*%(tmp,X),parsertemp393466),+(sqrt(parsertemp393475),1.0E-17))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,5.0 +*(5.0,sqrt(*(1.0005002501250626,m2))) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int713 +LITERAL_FLOAT:1.0,2000.0 +/(/(-(colSums(parsertemp31104),*(int713,parsertemp31106)),-(2000.0,1.0)),2000.0) +::STMT +FLOAT:K +LITERAL_FLOAT:301.0 +*(301.0,K) +::STMT +MATRIX:lambda,g,parsertemp285556,beta +cast.FLOAT(%*%(t(+(g,parsertemp285556)),+(g,*(lambda,beta)))) +::STMT +MATRIX:distT +LITERAL_FLOAT:0.0 +sum(!=(distT,0.0)) +::STMT +MATRIX:parsertemp137844 +rev(t(parsertemp137844)) +::STMT +MATRIX:d_r +t(rev(d_r)) +::STMT +FLOAT:B,R,s +LITERAL_FLOAT:1.0 +/(/(B,R),+(s,1.0)) +::STMT +LITERAL_FLOAT:2.0,150.0 +^(150.0,2.0) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j,V1 +FLOAT:I_i1i2 +sum(*(V1,-(I_i1i2,/(n_risk_i2j,n_risk_stratum)))) +::STMT +FLOAT:float246,d_eee,x +LITERAL_FLOAT:2.302585092994046 +*(x,exp(*(2.302585092994046,-(float246,d_eee)))) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y_prob,Y,parsertemp171293 +*(Y,%*%(+(*(Y_prob,parsertemp171293),is_LT_infinite),flip_neg)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(500.0,1.0))) +::STMT +MATRIX:g_Y,parsertemp171599,scale_X,shift_X,gXY +FLOAT:int545 ++(%*%(diag(scale_X),%*%(*(parsertemp171599,int545),g_Y)),%*%(shift_X,gXY)) +::STMT +MATRIX:cdf_min_distances,random_row +<(cdf_min_distances,*(random_row,cdf_min_distances)) +::STMT +MATRIX:parsertemp1532,y +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(%*%(parsertemp1532,y),2.0)),9.999999999999998E-15) +::STMT +MATRIX:clusterMembers,adjacency +LITERAL_FLOAT:0.0 +>(*(clusterMembers,>(rowSums(adjacency),0.0)),0.0) +::STMT +MATRIX:ts +FLOAT:q +-(q,*(cast.FLOAT(ts),cast.FLOAT(ts))) +::STMT +FLOAT:max_features,n +/(^(n,max_features),n) +::STMT +LITERAL_FLOAT:1.000010000100001 +1.000010000100001 +::STMT +LITERAL_FLOAT:0.02 +0.02 +::STMT +FLOAT:i +LITERAL_FLOAT:100.0 +*(*(i,100.0),100.0) +::STMT +MATRIX:parsertemp410118,g0_1 +LITERAL_FLOAT:2.0 +sum(^(+(g0_1,t(parsertemp410118)),2.0)) +::STMT +MATRIX:d,dtd,parsertemp44021 +FLOAT:sts,delta2 +LITERAL_FLOAT:2.0 +sqrt(+(^(%*%(parsertemp44021,d),2.0),*(cast.FLOAT(dtd),-(delta2,sts)))) +::STMT +LITERAL_FLOAT:64.0 +INT:int753,int690 +rand(int690,int753,64.0,64.0) +::STMT +MATRIX:287_x,287_y,one_featureX +LITERAL_FLOAT:2.0 +<(one_featureX,/(+(cast.FLOAT(287_x),cast.FLOAT(287_y)),2.0)) +::STMT +MATRIX:Ileft +FLOAT:min_leaf +>=(rowSums(Ileft),min_leaf) +::STMT +MATRIX:parsertemp472315,parsertemp472326 +FLOAT:beg ++(-(nrow(parsertemp472315),cast.FLOAT(parsertemp472326)),beg) +::STMT +MATRIX:parsertemp402078,W3_rand +FLOAT:int259,int106 +LITERAL_FLOAT:0.1092173494617922 +%*%(*(0.1092173494617922,W3_rand),t(/(-(parsertemp402078,int106),+(parsertemp402078,int259)))) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(t(X),-1.0) +::STMT +LITERAL_FLOAT:0.10940797384659613 +0.10940797384659613 +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005 +*(1.0005,m2) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +-(_sbcvar11,/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +LITERAL_FLOAT:200.0,1.0 ++(200.0,1.0) +::STMT +FLOAT:i2,n +LITERAL_FLOAT:24.0 +-(n,*(24.0,i2)) +::STMT +MATRIX:mb1,parsertemp146957,188_dX +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mb1),*(-(1.0,beta1),colSums(*(parsertemp146957,188_dX)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +^(exp(linear_terms),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +^(exp(linear_terms),-1.0) +::STMT +MATRIX:tmp,Y +1-*(Y,tmp) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,int444,z,pp_CG +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(pp_CG,-(^(z,int444),trust_delta_sq))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:-1.0 +^(linear_terms,*(var_power,-1.0)) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:parsertemp176421,k,parsertemp176418 +|(<(X_adapted,-(sqrt(parsertemp176421),*(k,y_hat))),>(X_adapted,+(sqrt(parsertemp176418),*(k,y_hat)))) +::STMT +MATRIX:X_adapted,yhat +FLOAT:int587,int291,parsertemp176418 +|(<(X_adapted,-(sqrt(parsertemp176418),*(int587,yhat))),>(X_adapted,+(sqrt(parsertemp176418),*(int291,yhat)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +^(exp(linear_terms),0.0) +::STMT +MATRIX:parsertemp413082 +LITERAL_FLOAT:1.0,21.0 +*(21.0,-(max(round(parsertemp413082)),1.0)) +::STMT +MATRIX:y_train,prediction +LITERAL_FLOAT:0.5 +sum(==(y_train,>(prediction,0.5))) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:C +/(sum(*(r,r)),%*%(t(d),+(d,*(C,parsertemp43998)))) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int586 +LITERAL_FLOAT:149.0,2.0 +^(/(-(colSums(parsertemp31029),*(int586,parsertemp31031)),149.0),2.0) +::STMT +MATRIX:_sbcvar264,_sbcvar262 +FLOAT:int495,int563,parsertemp31330 +LITERAL_FLOAT:9999.0 +/(sum(*(-(_sbcvar262,int495),_sbcvar264)),*(9999.0,/(*(parsertemp31330,int563),9999.0))) +::STMT +MATRIX:p,r +FLOAT:norm_r2,int58 +*(/(sum(^(r,int58)),norm_r2),p) +::STMT +MATRIX:A,CVars,CFreqs +FLOAT:W,int623,parsertemp12882,float120 +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int623),CVars)),*(-(nrow(A),1.0),/(*(parsertemp12882,W),-(W,float120)))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +-(nrow(X),1.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171100,parsertemp171086,parsertemp171097 +FLOAT:float279,float397 +LITERAL_FLOAT:1.0 +-(+(*(+(parsertemp171086,parsertemp171097),-(float279,parsertemp171100)),/(is_one_y_corr,-(float397,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +FLOAT:502_strideh,502_padh,parsertemp193094,int645,502_Hf +LITERAL_FLOAT:0.0 ++(+(-(*(502_strideh,parsertemp193094),*(int645,502_padh)),502_Hf),0.0) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int928,int149 +LITERAL_FLOAT:1.0,7000.0 +/(-(colSums(^(posSamples,int149)),*(7000.0,^(posSampleMeans,int928))),-(7000.0,1.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,32.0 +&(>(R,0.0),>=(R,32.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int366,int542,int961,int998 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int961),/(negSampleVariances,int366)),2.0),+(/(^(posSampleVariances,int542),3.42951E11),/(^(negSampleVariances,int998),3.37275E9))) +::STMT +MATRIX:d,od,X,logisticD +LITERAL_FLOAT:2.0 ++(d,*(2.0,%*%(t(X),*(logisticD,od)))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +-(_sbcvar78,/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +LITERAL_FLOAT:-0.36651292058166435 +-(parsertemp171113,*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr))) +::STMT +FLOAT:C,K +LITERAL_FLOAT:2.0 +^(*(C,K),2.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0 +*(-1.0,link_power) +::STMT +MATRIX:parsertemp44025,s,d +FLOAT:delta2 ++(*(%*%(t(s),d),%*%(t(s),d)),*(%*%(t(d),d),-(delta2,%*%(parsertemp44025,s)))) +::STMT +MATRIX:Y_prob,Y +LITERAL_FLOAT:0.0 +*(<=(Y_prob,0.0),abs(Y)) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 +*(10.0,sqrt(approx_sample_size)) +::STMT +MATRIX:is_row_in_samples,parsertemp77566 +LITERAL_FLOAT:7075.0 +-(7075.0,*(is_row_in_samples,parsertemp77566)) +::STMT +MATRIX:dout1 +FLOAT:192_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,192_beta2),^(colSums(dout1),2.0)) +::STMT +FLOAT:parsertemp170147,parsertemp170145,p_CG,z +LITERAL_FLOAT:-1.0,2.0 +/(-(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170145,parsertemp170147))),sum(^(p_CG,2.0))) +::STMT +FLOAT:m2,float248,mu,wt +/(sqrt(/(*(m2,wt),-(wt,float248))),mu) +::STMT +FLOAT:x,parsertemp169816,float183 +round(*(x,exp(*(float183,parsertemp169816)))) +::STMT +MATRIX:scale_lambda,X +LITERAL_FLOAT:1.0E-7 ++(%*%(t(X),X),diag(*(scale_lambda,1.0E-7))) +::STMT +MATRIX:cdf_min_distances +LITERAL_FLOAT:0.0,1.0 +INT:int795,num_runs +<(cdf_min_distances,*(rand(int795,num_runs,0.0,1.0),cdf_min_distances)) +::STMT +FLOAT:trust_delta_sq,p_CG,z,pp_CG +sqrt(-(*(*(z,p_CG),*(z,p_CG)),*(pp_CG,-(z,trust_delta_sq)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0E-4 +<=(abs(-(A,t(A))),+(1.0E-4,abs(t(A)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 ++(1.0,exp(linear_terms)) +::STMT +MATRIX:parsertemp220889,Y,parsertemp221025,parsertemp220891 +FLOAT:int275,int359,int130 +LITERAL_FLOAT:1.0 +/(*(/(1.0,+(Y,int130)),+(diag(parsertemp221025),1.0)),sum(*(/(int275,parsertemp220891),+(parsertemp220889,int359)))) +::STMT +FLOAT:int242,lratio_t +LITERAL_FLOAT:1.0,50.0 +-(1.0,exp(/(*(lratio_t,int242),50.0))) +::STMT +MATRIX:Y_prob +FLOAT:int909 +LITERAL_FLOAT:0.0,1.0 ++(*(Y_prob,-(1.0,<=(Y_prob,int909))),<=(Y_prob,0.0)) +::STMT +MATRIX:m_err +/(colSums(m_err),cast.FLOAT(rowSums(colSums(m_err)))) +::STMT +LITERAL_FLOAT:1.0,1000.0 +/(1000.0,-(1000.0,1.0)) +::STMT +MATRIX:parsertemp389186,parsertemp389189 +LITERAL_FLOAT:1.0,2.0 +^(/(-(exp(parsertemp389186),1.0),+(exp(parsertemp389189),1.0)),2.0) +::STMT +MATRIX:logisticnew +LITERAL_FLOAT:1.0 +-(1.0,logisticnew) +::STMT +MATRIX:W1_rand,stds,parsertemp397732 +LITERAL_FLOAT:0.086386842558136 +t(%*%(*(0.086386842558136,W1_rand),t(/(parsertemp397732,stds)))) +::STMT +MATRIX:parsertemp183431,X +FLOAT:N +LITERAL_FLOAT:1.0 +*(/(N,-(N,1.0)),%*%(t(/(parsertemp183431,N)),/(colSums(X),N))) +::STMT +MATRIX:s,w +t(+(w,s)) +::STMT +FLOAT:norm_grad +LITERAL_FLOAT:0.1 +*(0.1,norm_grad) +::STMT +MATRIX:I1 +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(I1)) +::STMT +MATRIX:Nc +==(Nc,max(Nc)) +::STMT +MATRIX:parsertemp175077,parsertemp175081,R1 +LITERAL_FLOAT:1.0E-6 +<(abs(-(R1,/(parsertemp175077,parsertemp175081))),1.0E-6) +::STMT +FLOAT:parsertemp386966 +sum(cast.MATRIX(parsertemp386966)) +::STMT +FLOAT:n_components,cov_param,n_features ++(+(cov_param,*(n_features,n_components)),n_components) +::STMT +LITERAL_FLOAT:0.282842712474619 +0.282842712474619 +::STMT +LITERAL_FLOAT:1.0,0.8 +-(1.0,0.8) +::STMT +MATRIX:X,K +LITERAL_FLOAT:-1.0 +*(*(K,-1.0),-(X,X)) +::STMT +MATRIX:parsertemp397837,W4_rand +FLOAT:int375,int658 +LITERAL_FLOAT:0.0873148795050037 +%*%(*(0.0873148795050037,W4_rand),t(/(-(parsertemp397837,int658),+(parsertemp397837,int375)))) +::STMT +MATRIX:parsertemp42200,parsertemp42201,F +FLOAT:int25,int329,meanX +LITERAL_FLOAT:1.0 +*(/(F,-(sum(F),1.0)),-(+(-(parsertemp42200,parsertemp42201),/(int329,int25)),meanX)) +::STMT +MATRIX:parsertemp170136 +FLOAT:trust_delta_sq,p_CG,z +sqrt(-(*(*(z,p_CG),*(z,p_CG)),*(sum(parsertemp170136),-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamax,Hneg,Hpos,beta +LITERAL_FLOAT:0.0,3.4011973816621555,1.0E20 +*(>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),!=(+(*(Hpos,betamax),*(Hneg,beta)),1.0E20)) +::STMT +MATRIX:X,Y,K +-(*(K,-(X,X)),-(Y,Y)) +::STMT +MATRIX:R +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>(R,0.0),>=(R,minSup)) +::STMT +MATRIX:std,rad,dtd +/(-(rad,std),dtd) +::STMT +MATRIX:lambda,B,S +LITERAL_FLOAT:2.0 +sum(*(lambda,^(+(B,S),2.0))) +::STMT +MATRIX:R,S,parsertemp40218,parsertemp40215 +FLOAT:level +-(+(R,rowSums(==(parsertemp40215,level))),rowSums(==(%*%(S,parsertemp40218),level))) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +*(^(2000.0,2.0),-(2000.0,1.0)) +::STMT +FLOAT:a,x +LITERAL_FLOAT:2.0 +*(a,^(x,2.0)) +::STMT +MATRIX:hubs +LITERAL_FLOAT:2.0 +abs(sum(^(-(hubs,hubs),2.0))) +::STMT +MATRIX:is_unsafe,parsertemp1518 +FLOAT:parsertemp1519,int493 +LITERAL_FLOAT:0.0 +sqrt(+(*(/(parsertemp1518,parsertemp1519),-(int493,is_unsafe)),<=(/(parsertemp1518,parsertemp1519),0.0))) +::STMT +MATRIX:diff,mask +LITERAL_FLOAT:0.0 +*(diff,==(mask,0.0)) +::STMT +MATRIX:parsertemp73634 +LITERAL_FLOAT:16.0,1.0 ++(*(parsertemp73634,16.0),1.0) +::STMT +LITERAL_FLOAT:0.16823164622761327 +0.16823164622761327 +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 ++(rowSums(Y),==(rowSums(Y),0.0)) +::STMT +MATRIX:simplex +LITERAL_FLOAT:4.0 +/(-(rowSums(simplex),simplex),4.0) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015 +cast.FLOAT(%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +FLOAT:eta,s +LITERAL_FLOAT:-1.0 +^(eta,*(s,-1.0)) +::STMT +MATRIX:2814_t +FLOAT:parsertemp477829,parsertemp477814,2814_K,int626,2814_X,2814_Y,inp_x +*(cast.FLOAT(2814_t),+(*(-(2814_K,2814_Y),-(int626,parsertemp477814)),*(+(parsertemp477829,2814_Y),/(inp_x,2814_X)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,t(X)) +::STMT +MATRIX:parsertemp409789,parsertemp409798 +FLOAT:int986 +LITERAL_FLOAT:2.0 +sum(^(+(*(parsertemp409789,int986),t(parsertemp409798)),2.0)) +::STMT +MATRIX:scale_X,X,parsertemp115854 +LITERAL_FLOAT:0.0 +*(-(0.0,/(t(parsertemp115854),nrow(X))),scale_X) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015,delta2 +-(delta2,%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27487 +LITERAL_FLOAT:1.0 +sum(*(-(%*%(present_domain_vals_mat,CFreqs1),1.0),%*%(present_domain_vals_mat,parsertemp27487))) +::STMT +MATRIX:W,X,H,parsertemp411101 +FLOAT:eps +/(%*%(t(W),X),+(%*%(%*%(parsertemp411101,W),H),eps)) +::STMT +MATRIX:classFeatureCounts +FLOAT:float640,int90 +LITERAL_FLOAT:1.0 +INT:int694,int227 +/(+(classFeatureCounts,1.0),%*%(+(rowSums(classFeatureCounts),*(int90,float640)),rand(int227,int694,1.0,1.0))) +::STMT +MATRIX:tmp +FLOAT:parsertemp477715,X,x,Y,K +LITERAL_FLOAT:1.0 ++(*(-(*(K,X),-(Y,Y)),-(1.0,/(parsertemp477715,X))),*(cast.FLOAT(tmp),/(-(x,X),-(X,X)))) +::STMT +MATRIX:t,parsertemp171083 +FLOAT:float208,float321 +LITERAL_FLOAT:0.189269,1.432788 ++(1.432788,*(sqrt(*(float321,parsertemp171083)),+(0.189269,*(t,float208)))) +::STMT +LITERAL_FLOAT:0.05469029540078189 +0.05469029540078189 +::STMT +MATRIX:parsertemp22268,parsertemp22266 +FLOAT:q,int631,int578 +LITERAL_FLOAT:1.0,10000.0 +/(sum(/(^(parsertemp22268,int578),/(parsertemp22266,int631))),*(10000.0,-(q,1.0))) +::STMT +MATRIX:b2,176_mask,W2,175_out +FLOAT:p ++(%*%(/(*(175_out,176_mask),p),W2),b2) +::STMT +FLOAT:window_size,q,parsertemp181039,parsertemp181046 +LITERAL_FLOAT:1.0,2.0 +*(*(2.0,window_size),-(1.0,/(-(q,parsertemp181039),*(window_size,parsertemp181046)))) +::STMT +FLOAT:std,arch_coef,noise,a0 +LITERAL_FLOAT:2.0 ++(a0,*(arch_coef,^(*(noise,std),2.0))) \ No newline at end of file