Skip to content

Commit

Permalink
rewrite ... not allowed if federated in
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Jan 17, 2025
1 parent 67e43ee commit d9ebcf0
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
Expand Down Expand Up @@ -209,6 +210,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);

hi = fuseSeqAndTableExpand(hi);
}

hop.setVisited();
Expand Down Expand Up @@ -2913,4 +2916,24 @@ private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
}
return hi;
}


private static Hop fuseSeqAndTableExpand(Hop hi) {

if(TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES && hi instanceof TernaryOp ) {
TernaryOp thop = (TernaryOp) hi;
thop.getOp();

if(thop.isSequenceRewriteApplicable(true) && thop.findExecTypeTernaryOp() == ExecType.CP &&
thop.getInput(1).getForcedExecType() != Types.ExecType.FED) {
Hop input1 = thop.getInput(0);
if(input1 instanceof DataGenOp){
Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
HopRewriteUtils.replaceChildReference(hi, input1, literal);
}
}

}
return hi;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
Expand Down Expand Up @@ -199,7 +198,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))

hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m))
hi = fuseSeqAndTableExpand(hi);

//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
if( !descendFirst )
Expand Down Expand Up @@ -2197,22 +2195,4 @@ private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
}
}
}

private static Hop fuseSeqAndTableExpand(Hop hi) {

if(TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES && hi instanceof TernaryOp ) {
TernaryOp thop = (TernaryOp) hi;
thop.getOp();

if(thop.isSequenceRewriteApplicable(true) && thop.findExecTypeTernaryOp() == ExecType.CP) {
Hop input1 = thop.getInput(0);
if(input1 instanceof DataGenOp){
Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
HopRewriteUtils.replaceChildReference(hi, input1, literal);
}
}

}
return hi;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ private void runTestMM(String fileX, String fileY, long driverMemory, int number

// original compilation used for comparison
Program expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, new StringBuilder());

Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream()
.filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode)))
Expand All @@ -257,7 +257,7 @@ private void runTestTSMM(String fileX, long driverMemory, int numberExecutors, l
}
// original compilation used for comparison
Program expectedProgram = ResourceCompiler.compile(HOME+"mm_transpose_test.dml", nvargs);
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, new StringBuilder());
Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream()
.filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode)))
.findFirst();
Expand All @@ -273,22 +273,23 @@ private void runTestAlgorithm(String dmlScript, long driverMemory, int numberExe
Map<String, String> nvargs) throws IOException {
// pre-compiled program using default values to be used as source for the recompilation
Program precompiledProgram = generateInitialProgram(HOME+dmlScript, nvargs);
System.out.println("precompiled");
System.out.println(Explain.explain(precompiledProgram));
StringBuilder sb = new StringBuilder();
sb.append("\n\nprecompiled\n");
sb.append(Explain.explain(precompiledProgram));
if (numberExecutors > 0) {
ResourceCompiler.setSparkClusterResourceConfigs(driverMemory, driverThreads, numberExecutors, executorMemory, executorThreads);
} else {
ResourceCompiler.setSingleNodeResourceConfigs(driverMemory, driverThreads);
}
// original compilation used for comparison
Program expectedProgram = ResourceCompiler.compile(HOME+dmlScript, nvargs);
System.out.println("expected");
System.out.println(Explain.explain(expectedProgram));
runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
sb.append("\n\nexpected\n");
sb.append(Explain.explain(expectedProgram));
runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, sb);
}

private Program runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory) {
if (DEBUG_MODE) System.out.println(Explain.explain(expectedProgram));
private Program runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory, StringBuilder sb) {
if (DEBUG_MODE) sb.append(Explain.explain(expectedProgram));
Program recompiledProgram;
if (numberExecutors == 0) {
ResourceCompiler.setSingleNodeResourceConfigs(driverMemory, driverThreads);
Expand All @@ -303,19 +304,19 @@ private Program runTest(Program precompiledProgram, Program expectedProgram, lon
);
recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram);
}
System.out.println("recompiled");
System.out.println(Explain.explain(recompiledProgram));
sb.append("\n\nrecompiled\n");
sb.append(Explain.explain(recompiledProgram));

if (DEBUG_MODE) System.out.println(Explain.explain(recompiledProgram));
assertEqualPrograms(expectedProgram, recompiledProgram);
if (DEBUG_MODE) sb.append(Explain.explain(recompiledProgram));
assertEqualPrograms(expectedProgram, recompiledProgram, sb);
return recompiledProgram;
}

private void assertEqualPrograms(Program expected, Program actual) {
private void assertEqualPrograms(Program expected, Program actual, StringBuilder sb) {
// strip empty blocks basic program blocks
String expectedProgramExplained = stripGeneralAndReplaceRandoms(Explain.explain(expected));
String actualProgramExplained = stripGeneralAndReplaceRandoms(Explain.explain(actual));
Assert.assertEquals(expectedProgramExplained, actualProgramExplained);
Assert.assertEquals(sb.toString(), expectedProgramExplained, actualProgramExplained);
}

private String stripGeneralAndReplaceRandoms(String explainedProgram) {
Expand Down

0 comments on commit d9ebcf0

Please sign in to comment.