Skip to content

Commit

Permalink
[SYSTEMDS-3815] Fused table sequence
Browse files Browse the repository at this point in the history
This commit contains a new fused operator for:

table(seq(1, nrow(A)), A, w)

This removes the need to materialize a vector of incrementing integers with as many rows as A. We previously supported this pattern through the rexpand operator; however, that implementation still allocated the seq vector.

We see a 1.4x improvement in the rexpand operator itself, and removing the seq allocation on top of that improves it further to 1.72x.

The change is not fully integrated into the Federated Instructions and needs additional work. The current workaround tries to compile the previous instruction for federated use cases.

Closes #2181
  • Loading branch information
Baunsgaard committed Jan 21, 2025
1 parent 9484f11 commit eb2ca62
Show file tree
Hide file tree
Showing 23 changed files with 1,199 additions and 100 deletions.
15 changes: 12 additions & 3 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ else if( isCTableReshapeRewriteApplicable(et, ternaryOp) ) {
}

Ctable ternary = new Ctable(inputLops, ternaryOp,
getDataType(), getValueType(), ignoreZeros, outputEmptyBlocks, et);
getDataType(), getValueType(), ignoreZeros, outputEmptyBlocks, et, OptimizerUtils.getConstrainedNumThreads(getMaxNumThreads()));

ternary.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), -1);
setLineNumbers(ternary);
Expand Down Expand Up @@ -480,6 +480,10 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
}


/**
 * Returns the execution type of this ternary operator.
 *
 * If an execution type has already been assigned ({@code _etype} is non-null)
 * it is returned as is; otherwise the result of {@link #optFindExecType(boolean)}
 * called with {@code OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE} is returned.
 *
 * @return the assigned or newly determined execution type
 */
public ExecType findExecTypeTernaryOp(){
	if( _etype != null )
		return _etype;
	return optFindExecType(OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE);
}

@Override
protected ExecType optFindExecType(boolean transitive)
{
Expand Down Expand Up @@ -637,7 +641,7 @@ && getInput().get(1) == that2.getInput().get(1)
return ret;
}

private boolean isSequenceRewriteApplicable( boolean left )
public boolean isSequenceRewriteApplicable( boolean left )
{
boolean ret = false;

Expand All @@ -651,7 +655,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
{
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
if( input1.getDataType() == DataType.MATRIX && input2.getDataType() == DataType.MATRIX )
if( (input1.getDataType() == DataType.MATRIX
|| input1.getDataType() == DataType.SCALAR )
&& input2.getDataType() == DataType.MATRIX )
{
//probe rewrite on left input
if( left && input1 instanceof DataGenOp )
Expand All @@ -663,6 +669,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
|| dgop.getIncrementValue()==1.0; //set by recompiler
}
}
if( left && input1 instanceof LiteralOp && ((LiteralOp)input1).getStringValue().contains("seq(")){
ret = true;
}
//probe rewrite on right input
if( !left && input2 instanceof DataGenOp )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
Expand Down Expand Up @@ -209,6 +210,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
if( !descendFirst )
rule_AlgebraicSimplification(hi, descendFirst);

hi = fuseSeqAndTableExpand(hi);
}

hop.setVisited();
Expand Down Expand Up @@ -2913,4 +2916,24 @@ private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
}
return hi;
}


/**
 * Prepares a ternary table operation for the fused table-sequence operator by
 * replacing a data-generated first input with a string literal placeholder.
 *
 * The rewrite only applies if all of the following hold:
 * <ul>
 * <li>{@code TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES} is enabled and {@code hi} is a {@link TernaryOp},</li>
 * <li>{@code isSequenceRewriteApplicable(true)} accepts the left input,</li>
 * <li>the operator's execution type resolves to {@code ExecType.CP}, and</li>
 * <li>the second input is not forced to {@code ExecType.FED}.</li>
 * </ul>
 * In that case, a first input that is a {@link DataGenOp} is replaced by a
 * {@link LiteralOp} holding the string {@code "seq(1, <dim1>)"}, where
 * {@code <dim1>} is the value of {@code getDim1()} on the replaced input.
 *
 * @param hi the hop to inspect; its child reference may be replaced in place
 * @return the same hop {@code hi} that was passed in
 */
private static Hop fuseSeqAndTableExpand(Hop hi) {
	if( !TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES || !(hi instanceof TernaryOp) )
		return hi;

	TernaryOp thop = (TernaryOp) hi;

	// evaluation order matters: the sequence check runs before the exec type lookup
	boolean applicable = thop.isSequenceRewriteApplicable(true)
		&& thop.findExecTypeTernaryOp() == ExecType.CP
		&& thop.getInput(1).getForcedExecType() != ExecType.FED;
	if( !applicable )
		return hi;

	Hop input1 = thop.getInput(0);
	if( input1 instanceof DataGenOp ) {
		// NOTE(review): getDim1() is used as is; presumably the dimension is
		// known at this point -- confirm behavior for unknown dimensions.
		Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
		HopRewriteUtils.replaceChildReference(hi, input1, literal);
	}
	return hi;
}
}
12 changes: 9 additions & 3 deletions src/main/java/org/apache/sysds/lops/Ctable.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class Ctable extends Lop
{
private final boolean _ignoreZeros;
private final boolean _outputEmptyBlocks;
private final int _numThreads;

public enum OperationTypes {
CTABLE_TRANSFORM,
Expand All @@ -58,15 +59,16 @@ public boolean hasThirdInput() {
OperationTypes operation;


public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
this(inputLops, op, dt, vt, false, true, et);
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et, int k) {
this(inputLops, op, dt, vt, false, true, et, k);
}

public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, boolean outputEmptyBlocks, ExecType et) {
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, boolean outputEmptyBlocks, ExecType et, int k) {
super(Lop.Type.Ctable, dt, vt);
init(inputLops, op, et);
_ignoreZeros = ignoreZeros;
_outputEmptyBlocks = outputEmptyBlocks;
_numThreads = k;
}

private void init(Lop[] inputLops, OperationTypes op, ExecType et) {
Expand Down Expand Up @@ -175,6 +177,10 @@ public String getInstructions(String input1, String input2, String input3, Strin
sb.append( OPERAND_DELIMITOR );
sb.append( _outputEmptyBlocks );
}
else {
sb.append( OPERAND_DELIMITOR );
sb.append(_numThreads);
}

return sb.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1987,8 +1987,8 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
case DECOMPRESS:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(getFirstExpr().getOutput().getDataType());
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ public double parseSPInst(SPInstruction inst) throws CostEstimationException {
SparkCostUtils.getMatMulChainInstTime(mmchaininst, input1, input2, input3, output, driverMetrics, executorMetrics);
} else if (inst instanceof CtableSPInstruction) {
CtableSPInstruction tableInst = (CtableSPInstruction) inst;
VarStats input1 = getStats(tableInst.input1.getName());
VarStats input1 = getStatsWithDefaultScalar(tableInst.input1.getName());
VarStats input2 = getStatsWithDefaultScalar(tableInst.input2.getName());
VarStats input3 = getStatsWithDefaultScalar(tableInst.input3.getName());
double loadTime = loadRDDStatsAndEstimateTime(input1) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
Expand Down Expand Up @@ -1281,6 +1282,11 @@ public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
}

/**
 * Reshapes this compressed matrix block by delegating to
 * {@code CLALibReshape.reshape}.
 *
 * @param rows  number of rows, passed through to the library call
 * @param cols  number of columns, passed through to the library call
 * @param byRow flag passed through to the library call
 * @return the block returned by {@code CLALibReshape.reshape}
 */
@Override
public MatrixBlock reshape(int rows,int cols, boolean byRow){
	final MatrixBlock reshaped = CLALibReshape.reshape(this, rows, cols, byRow);
	return reshaped;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,8 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in
public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) {
DenseBlock db = that.getDenseBlock();
DenseBlock retDB = ret.getDenseBlock();
if(rl == ru - 1)
leftMMIdentityPreAggregateDenseSingleRow(db.values(rl), db.pos(rl), retDB.values(rl), retDB.pos(rl), cl, cu);
else
throw new NotImplementedException();
for(int i = rl; i < ru; i++)
leftMMIdentityPreAggregateDenseSingleRow(db.values(i), db.pos(i), retDB.values(i), retDB.pos(i), cl, cu);
}

@Override
Expand Down Expand Up @@ -632,7 +630,8 @@ public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, i
}
}

final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen) {
final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k,
int vLen) {
// vVec = vVec.broadcast(aa);
final int offj = k * jd;
final int end = endT + offj;
Expand Down Expand Up @@ -919,16 +918,16 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret,

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
final SparseBlock sb = selection.getSparseBlock();
final DenseBlock retB = ret.getDenseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;
final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
final SparseBlock sb = selection.getSparseBlock();
final DenseBlock retB = ret.getDenseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;
final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
}


Expand All @@ -946,22 +945,21 @@ private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos,
for(int rc = cl; rc < cu; rc++, pos++) {
final int idx = _data.getIndex(rc);
if(idx != nVal)
values2[_colIndexes.get(idx)] += values[pos];
values2[pos2 + _colIndexes.get(idx)] += values[pos];
}
}
else {
for(int rc = cl; rc < cu; rc++, pos++)
values2[_colIndexes.get(_data.getIndex(rc))] += values[pos];
values2[pos2 + _colIndexes.get(_data.getIndex(rc))] += values[pos];
}
}
}


private void leftMMIdentityPreAggregateDenseSingleRowRangeIndex(double[] values, int pos, double[] values2, int pos2,
int cl, int cu) {
IdentityDictionary a = (IdentityDictionary) _dict;

final int firstCol = _colIndexes.get(0);
final int firstCol = pos2 + _colIndexes.get(0);
pos += cl; // left side matrix position offset.
if(a.withEmpty()) {
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;
Expand Down
Loading

0 comments on commit eb2ca62

Please sign in to comment.