Skip to content

Commit

Permalink
fixed ctable with seq fuse
Browse files Browse the repository at this point in the history
rewrite fused ctable with given output dim (disaled: performance decrease, need to fix it first)
  • Loading branch information
e-strauss committed Feb 4, 2025
1 parent 749ec56 commit 407090d
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 56 deletions.
7 changes: 3 additions & 4 deletions scripts/builtin/kmeans.dml
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,10 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
}

# Find the closest centroid for each record
# P = D <= minD;
P = D <= minD;
# If some records belong to multiple centroids, share them equally
# P = P / rowSums (P);
P = table(seq(1,nrow(D)), rowIndexMin(D))
# P = table(seq(1,nrow(D)),compress(rowIndexMin(D)))
P = P / rowSums (P);
# P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids)
# Compute the column normalization factor for P
P_denom = colSums (P);
# Compute new centroids as weighted averages over the records
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left )

try
{
// TODO: to rewrite is not currently not triggered if outdim are given --> getInput().size()>=3
// currently disabled due performance decrease
if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) )
{
Hop input1 = getInput().get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

Expand Down Expand Up @@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int

try {
final int[] map = new int[seqHeight];
int maxCol = constructInitialMapping(map, A, k);
Pair<Integer, Integer> meta = constructInitialMapping(map, A, k, nColOut);
int maxCol = meta.getKey();
int nZeros = meta.getValue();
boolean containsNull = maxCol < 0;
maxCol = Math.abs(maxCol);

boolean cutOff = false;
if(nColOut == -1)
nColOut = maxCol;
else if(nColOut < maxCol)
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
cutOff = true;

final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
if(containsNull)
correctNulls(map, nColOut);
if(nColOut == 0) // edge case of empty zero dimension block.
return new MatrixBlock(seqHeight, 0, 0.0);
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
}
catch(Exception e) {
throw new RuntimeException("Failed table seq operator", e);
Expand Down Expand Up @@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) {
return nNulls;
}

private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
private static Pair<Integer,Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
if(A.isEmpty() || A.isInSparseFormat())
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
final MatrixBlock Ac;
Expand All @@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
try {

int blkz = Math.max((map.length / k), 1000);
List<Future<Integer>> tasks = new ArrayList<>();
List<Future<Pair<Integer,Integer>>> tasks = new ArrayList<>();
for(int i = 0; i < map.length; i += blkz) {
final int start = i;
final int end = Math.min(i + blkz, map.length);
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end)));
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol)));
}

int maxCol = 0;
for(Future<Integer> f : tasks) {
int tmp = f.get();
if(Math.abs(tmp) > Math.abs(maxCol))
maxCol = tmp;
int zeros = 0;
for(Future<Pair<Integer,Integer>> f : tasks) {
int tmpMaxCol = f.get().getKey();
int tmpZeros = f.get().getValue();
if(Math.abs(tmpMaxCol) > Math.abs(maxCol))
maxCol = tmpMaxCol;
zeros += tmpZeros;
}
return maxCol;
return new Pair<Integer,Integer>(maxCol, zeros);
}
catch(Exception e) {
throw new DMLRuntimeException(e);
Expand All @@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {

}

private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {

int maxCol = 0;
boolean containsNull = false;

int zeros = 0;
int notHandledNulls = 0;
final double[] aVals = A.getDenseBlockValues();

for(int i = start; i < end; i++) {
final double v2 = aVals[i];
if(Double.isNaN(v2)) {
map[i] = -1; // assign temporarily to -1
containsNull = true;
}
else {
// safe casts to long for consistent behavior with indexing
int col = UtilFunctions.toInt(v2);
if(col <= 0)
throw new DMLRuntimeException(
int colUnsafe = UtilFunctions.toInt(v2);
if(colUnsafe <= 0)
throw new DMLRuntimeException(
"Erroneous input while computing the contingency table (value <= zero): " + v2);
boolean invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol);
final int colSafe = invalid ? maxOutCol : colUnsafe - 1;
zeros += invalid ? 1 : 0;
notHandledNulls += Double.isNaN(v2) ? maxOutCol : 0;
maxCol = Math.max(colUnsafe, maxCol);
map[i] = colSafe;
}

map[i] = col - 1;
// maintain max seen col
maxCol = Math.max(col, maxCol);
}
if (notHandledNulls < 0){
maxCol *= -1;
}

return containsNull ? maxCol * -1 : maxCol;
return new Pair<Integer, Integer>(maxCol, zeros);
}

public static boolean compressedTableSeq() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,17 @@ public void processInstruction(ExecutionContext ec) {

boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);
if ( outputDimsKnown ) {
int inputRows = matBlock1.getNumRows();
int inputCols = matBlock1.getNumColumns();
boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols);
//only create result block if dense; it is important not to aggregate on sparse result
//blocks because it would implicitly turn the O(N) algorithm into O(N log N).
if( !sparse )
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
if(_isExpand){
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, true);
} else {
int inputRows = matBlock1.getNumRows();
int inputCols = matBlock1.getNumColumns();
boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols);
//only create result block if dense; it is important not to aggregate on sparse result
//blocks because it would implicitly turn the O(N) algorithm into O(N log N).
if( !sparse )
resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
}
}

switch(ctableOp) {
Expand All @@ -140,7 +144,8 @@ public void processInstruction(ExecutionContext ec) {
}
matBlock2 = ec.getMatrixInput(input2.getName());
cst1 = ec.getScalarInput(input3).getDoubleValue();
resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock, true, _k);
resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock,
!outputDimsKnown, _k);
break;
case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
// F=ctable(A,1) or F = ctable(A,1,1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,11 +1044,13 @@ public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, double w

}

private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret, boolean updateClen) {
private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret,
boolean updateClen) {
if(ret == null) {
ret = new MatrixBlock();
updateClen = true;
}
int outCols = updateClen ? -1 : ret.getNumColumns();
final int rlen = seqHeight;
// prepare allocation of CSR sparse block
final int[] rowPointers = new int[rlen + 1];
Expand All @@ -1060,14 +1062,14 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
ret.sparse = true;
ret.denseBlock = null;
// construct sparse CSR block from filled arrays
SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, rlen);
SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, seqHeight);
ret.sparseBlock = csr;
int blkz = Math.min(1024, rlen);
int blkz = Math.min(1024, seqHeight);
int maxcol = 0;
boolean containsNull = false;
for(int i = 0; i < rlen; i += blkz) {
for(int i = 0; i < seqHeight; i += blkz) {
// blocked execution for earlier JIT compilation
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, rlen));
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, seqHeight), (int) outCols);
if(t < 0) {
t = Math.abs(t);
containsNull = true;
Expand All @@ -1078,14 +1080,15 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
if(containsNull)
csr.compact();

rowPointers[rlen] = rlen;
rowPointers[seqHeight] = seqHeight;
ret.setNonZeros(ret.sparseBlock.size());
if(updateClen)
ret.setNumColumns(maxcol);
ret.setNumColumns(outCols == -1 ? maxcol : (int) outCols);
return ret;
}

private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl, int ru) {
private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl,
int ru, int maxOutCol) {

// prepare allocation of CSR sparse block
final int[] rowPointers = csr.rowPointers();
Expand All @@ -1096,7 +1099,7 @@ private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final Ma
int maxCol = 0;

for(int i = rl; i < ru; i++) {
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values);
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values, maxOutCol);
if(c < 0)
containsNull = true;
else
Expand All @@ -1114,7 +1117,7 @@ private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updat
ret.clen = maxCol;
}

public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals) {
public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals, int maxOutCol) {
// If any of the values are NaN (i.e., missing) then
// we skip this tuple, proceed to the next tuple
if(Double.isNaN(v2))
Expand All @@ -1124,10 +1127,12 @@ public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, d
int col = UtilFunctions.toInt(v2);
if(col <= 0)
throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2);

// set weight as value (expand is guaranteed to address different cells)
retIx[row] = col - 1;
retVals[row] = w;
// maxOutCol = - 1 if not specified --> TRUE
if(col <= maxOutCol){
// set weight as value (expand is guaranteed to address different cells)
retIx[row] = col - 1;
retVals[row] = w;
}

// maintain max seen col
return col;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {

// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
"single=" + String.valueOf(singleWorker).toUpperCase(), "runs=" + String.valueOf(runs), "out=" + output("Z")};

Expand Down

0 comments on commit 407090d

Please sign in to comment.