Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite Discovery and Generation Framework #2207

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class DMLOptions {
public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes
public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed
public boolean statsNGramsUseLineage = true; // If N-Grams use lineage for data-dependent tracking
public boolean applyGeneratedRewrites = false; // If generated rewrites should be applied
public boolean fedStats = false; // Whether to record and print the federated statistics
public int fedStatsCount = 10; // Default federated statistics count
public boolean memStats = false; // max memory statistics
Expand Down Expand Up @@ -246,6 +247,8 @@ else if (lineageType.equalsIgnoreCase("debugger"))
}
}

dmlOptions.applyGeneratedRewrites = line.hasOption("applyGeneratedRewrites");

dmlOptions.fedStats = line.hasOption("fedStats");
if (dmlOptions.fedStats) {
String fedStatsCount = line.getOptionValue("fedStats");
Expand Down Expand Up @@ -372,6 +375,7 @@ private static Options createCLIOptions() {
Option ngramsOpt = OptionBuilder//.withArgName("ngrams")
.withDescription("monitors and reports the most occurring n-grams; -ngrams <comma separated n's> <topK>")
.hasOptionalArgs(2).create("ngrams");
Option applyGeneratedRewritesOpt = OptionBuilder.withArgName("applyGeneratedRewrites").withDescription("if automatically generated rewrites should be applied").create("applyGeneratedRewrites");
Option fedStatsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off")
.hasOptionalArg().create("fedStats");
Expand Down Expand Up @@ -434,6 +438,7 @@ private static Options createCLIOptions() {
options.addOption(cleanOpt);
options.addOption(statsOpt);
options.addOption(ngramsOpt);
options.addOption(applyGeneratedRewritesOpt);
options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import java.util.Date;
import java.util.Map;
import java.util.Scanner;
import java.util.function.BiConsumer;
import java.util.function.Function;

import org.apache.commons.cli.AlreadySelectedException;
import org.apache.commons.cli.HelpFormatter;
Expand Down Expand Up @@ -106,6 +108,7 @@ public class DMLScript
public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams;
// Set if N-Grams use lineage for data-dependent tracking
public static boolean STATISTICS_NGRAMS_USE_LINEAGE = DMLOptions.defaultOptions.statsNGramsUseLineage;
public static boolean APPLY_GENERATED_REWRITES = DMLOptions.defaultOptions.applyGeneratedRewrites;
// Set statistics maximum wrap length
public static int STATISTICS_MAX_WRAP_LEN = 30;
// Enable/disable to print federated statistics
Expand Down Expand Up @@ -168,6 +171,9 @@ public class DMLScript
public static String _uuid = IDHandler.createDistributedUniqueID();
private static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

// Optional hook, disabled when null. During execute() it is applied to the program
// right before the HOP DAG rewrite step; if it returns false, execute() returns early.
public static Function<DMLProgram, Boolean> preHopInterceptor = null; // Intercepts HOPs before they are rewritten
// Optional hook, disabled when null. During execute() it is applied to the program
// right after the HOP DAG rewrite step and before lop construction; if it returns
// false, execute() returns early.
public static Function<DMLProgram, Boolean> hopInterceptor = null; // Intercepts HOPs after they are rewritten

///////////////////////////////
// public external interface
////////
Expand Down Expand Up @@ -261,6 +267,7 @@ public static boolean executeScript( String[] args )
STATISTICS_NGRAMS = dmlOptions.statsNGrams;
STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes;
STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams;
APPLY_GENERATED_REWRITES = dmlOptions.applyGeneratedRewrites;
FED_STATISTICS = dmlOptions.fedStats;
FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
JMLC_MEM_STATISTICS = dmlOptions.memStats;
Expand Down Expand Up @@ -456,9 +463,15 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri

//init working directories (before usage by following compilation steps)
initHadoopExecution( ConfigurationManager.getDMLConfig() );


if (preHopInterceptor != null && !preHopInterceptor.apply(prog))
return;

//Step 5: rewrite HOP DAGs (incl IPA and memory estimates)
dmlt.rewriteHopsDAG(prog);

if (hopInterceptor != null && !hopInterceptor.apply(prog))
return;

//Step 6: construct lops (incl exec type and op selection)
dmlt.constructLops(prog);
Expand Down
102 changes: 102 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,30 @@ public static DataGenOp copyDataGenOp( DataGenOp inputGen, double scale, double

return datagen;
}

/**
 * Builds a RAND data generator whose min and max are both set to the given
 * value, using the supplied hops directly as the row and column parameters.
 *
 * @param rows  hop passed as the RAND rows parameter
 * @param cols  hop passed as the RAND cols parameter
 * @param value constant used for both RAND min and RAND max
 * @return the newly created data generator hop
 */
public static Hop createDataGenOpFromDims( Hop rows, Hop cols, double value ) {
	// min == max, hence a single literal is shared by both parameters
	Hop constant = new LiteralOp(value);

	HashMap<String, Hop> randArgs = new HashMap<>();
	randArgs.put(DataExpression.RAND_ROWS, rows);
	randArgs.put(DataExpression.RAND_COLS, cols);
	randArgs.put(DataExpression.RAND_MIN, constant);
	randArgs.put(DataExpression.RAND_MAX, constant);
	randArgs.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM));
	randArgs.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0));
	randArgs.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0));
	randArgs.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );

	//note internal refresh size information
	DataIdentifier target = new DataIdentifier("tmp");
	Hop gen = new DataGenOp(OpOpDG.RAND, target, randArgs);
	// NOTE(review): block size is hard-coded here rather than taken from an input hop
	gen.setBlocksize(1000);

	// a constant zero output has no non-zero cells
	if( value == 0 ) {
		gen.setNnz(0);
	}

	return gen;
}

public static Hop createDataGenOp( Hop rowInput, Hop colInput, double value )
{
Expand Down Expand Up @@ -661,6 +685,84 @@ public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean ou
bop.refreshSizeInformation();
return bop;
}

/**
 * Creates a binary operation whose data type and value type are derived from
 * both inputs via {@link #getImplicitDataType} and {@link #getImplicitValueType}.
 * Exists to fix issues with createBinary, which does not always correctly set
 * value types (e.g. INT-MATRIX+FLOAT-SCALAR -> bop(+)::INT).
 *
 * @param input1 left operand
 * @param input2 right operand
 * @param op     binary operator
 * @return the new binary hop, with size information refreshed
 */
public static BinaryOp createAutoGeneratedBinary(Hop input1, Hop input2, OpOp2 op) {
	// the first matrix operand (if any) provides name, block size and line numbers;
	// without any matrix operand the first input is used
	Hop reference = input1;
	if( !input1.getDataType().isMatrix() && input2.getDataType().isMatrix() )
		reference = input2;

	String opName = reference.getName();
	DataType resultDt = getImplicitDataType(input1, input2);
	ValueType resultVt = getImplicitValueType(input1, input2);
	BinaryOp result = new BinaryOp(opName, resultDt, resultVt, op, input1, input2);

	//cleanup value type for relational operations
	if( result.isPPredOperation() && result.getDataType().isScalar() ) {
		result.setValueType(ValueType.BOOLEAN);
	}

	result.setOuterVectorOperation(false);
	result.setBlocksize(reference.getBlocksize());
	copyLineNumbers(reference, result);
	result.refreshSizeInformation();
	return result;
}

/**
 * Determines the data type resulting from combining the given inputs.
 *
 * @param inputs operand hops (at least one is expected)
 * @return the data type of the first matrix input, or the data type of the
 *         first input if none of them is a matrix
 */
public static DataType getImplicitDataType(Hop... inputs) {
	for (Hop candidate : inputs) {
		if (candidate.getDataType().isMatrix()) {
			return candidate.getDataType();
		}
	}

	// no matrix operand found: fall back to the first input
	return inputs[0].getDataType();
}

/**
 * Determines the value type resulting from combining the given inputs.
 * FP64 wins immediately; otherwise the highest-ranked type seen among
 * FP32, INT64, INT32 and BOOLEAN is chosen. Other value types are ignored.
 *
 * @param inputs operand hops (at least one is expected)
 * @return the combined value type, or the value type of the first input if
 *         none of the inputs has one of the types listed above
 */
public static ValueType getImplicitValueType(Hop... inputs) {
	ValueType widest = null;
	for (Hop input : inputs) {
		ValueType current = input.getValueType();
		switch (current) {
			case FP64:
				// nothing ranks higher, stop searching
				return current;
			case FP32:
				// only FP64 ranks higher and that case returns right away
				widest = current;
				break;
			case INT64:
			case INT32:
			case BOOLEAN:
				widest = implicitValueType(widest, current);
				break;
			default:
				// all remaining value types do not take part in the promotion
				break;
		}
	}

	return widest != null ? widest : inputs[0].getValueType();
}

/**
 * Picks the higher-ranked of two value types according to {@link #getTypeRank}.
 *
 * @param type1 first value type, may be null
 * @param type2 second value type, may be null
 * @return null if neither type has a rank; otherwise type1 if it ranks
 *         strictly higher than type2, else type2
 */
private static ValueType implicitValueType(ValueType type1, ValueType type2) {
	final int firstRank = getTypeRank(type1);
	final int secondRank = getTypeRank(type2);

	boolean neitherRanked = firstRank == Integer.MIN_VALUE && secondRank == Integer.MIN_VALUE;
	if (neitherRanked) {
		return null;
	}

	// ties resolve to the second type
	if (firstRank > secondRank) {
		return type1;
	}
	return type2;
}

/**
 * Maps a value type to its promotion rank:
 * FP64 (5) > FP32 (4) > INT64 (3) > INT32 (2) > BOOLEAN (1).
 *
 * @param vt value type, may be null
 * @return the rank, or Integer.MIN_VALUE for null and all other value types
 */
private static int getTypeRank(ValueType vt) {
	int rank = Integer.MIN_VALUE;

	if (vt != null) {
		switch (vt) {
			case FP64:
				rank = 5;
				break;
			case FP32:
				rank = 4;
				break;
			case INT64:
				rank = 3;
				break;
			case INT32:
				rank = 2;
				break;
			case BOOLEAN:
				rank = 1;
				break;
			default:
				// unranked value type, keep Integer.MIN_VALUE
				break;
		}
	}

	return rank;
}

public static AggUnaryOp createSum( Hop input ) {
return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewriter.generated.GeneratedRewriteClass;
import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated;
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
Expand Down Expand Up @@ -83,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if ( DMLScript.APPLY_GENERATED_REWRITES ) {
_dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
}
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again)
Expand Down Expand Up @@ -124,6 +130,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if ( DMLScript.USE_ACCELERATOR ){
_dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites
}
if ( DMLScript.APPLY_GENERATED_REWRITES ) {
_dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
}
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
Expand Down
Loading
Loading