Skip to content

Commit

Permalink
[SYSTEMDS-3827] CLA MultiCBind
Browse files Browse the repository at this point in the history
This commit adds specialized support for n-way cbind in compressed space.

Closes #2208
  • Loading branch information
Baunsgaard committed Feb 3, 2025
1 parent cdff385 commit b751389
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.compress.lib.CLALibCMOps;
import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
Expand Down Expand Up @@ -556,8 +556,8 @@ public MatrixBlock binaryOperationsLeft(BinaryOperator op, MatrixValue thatValue

@Override
public MatrixBlock append(MatrixBlock[] that, MatrixBlock ret, boolean cbind) {
if(cbind && that.length == 1)
return CLALibAppend.append(this, that[0], InfrastructureAnalyzer.getLocalParallelism());
if(cbind)
return CLALibCBind.cbind(this, that, InfrastructureAnalyzer.getLocalParallelism());
else {
MatrixBlock left = getUncompressed("append list or r-bind not supported in compressed");
MatrixBlock[] thatUC = new MatrixBlock[that.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ public AMapToData getMapToData() {
public double getSparsity() {
// Always reports a sparsity of 1.0; the value is hard-coded and does not depend on any state of the group.
return 1.0;
}

@Override
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
Expand Down Expand Up @@ -710,12 +710,10 @@ protected void decompressToSparseBlockTransposedDenseDictionary(SparseBlockMCSR
public AColGroup combineWithSameIndex(int nRow, int nCol, AColGroup right) {
	// Purpose: column-bind this group with 'right' into a single group.
	// Only a ColGroupConst on the right is merged here; any other group type is handed to the parent implementation.
	if(right instanceof ColGroupConst) {
		// Offset the right-hand column indices by nCol before merging them with this group's indices.
		final IColIndex shiftedRight = right.getColIndices().shift(nCol);
		final IColIndex mergedIndex = _colIndexes.combine(shiftedRight);

		// Column-bind the two dictionaries; the widths tell the factory how many columns each side contributes.
		final IDictionary rightDict = ((ColGroupConst) right).getDictionary();
		final IDictionary mergedDict = DictionaryFactory.cBindDictionaries(_dict, rightDict, this.getNumCols(),
			right.getNumCols());

		return create(mergedIndex, mergedDict);
	}
	return super.combineWithSameIndex(nRow, nCol, right);
}

@Override
Expand All @@ -737,10 +735,11 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List<AColGroup> right)
for(int i = 0; i < right.size(); i++) {
AColGroup g = right.get(i);

if(!(g instanceof ColGroupConst) || !(g instanceof ColGroupEmpty)) {
if(!(g instanceof ColGroupConst) && !(g instanceof ColGroupEmpty)) {
return super.combineWithSameIndex(nRow, nCol, right);
}
}

IColIndex combinedIndex = _colIndexes;
int i = 0;
for(AColGroup g : right) {
Expand All @@ -751,7 +750,7 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List<AColGroup> right)

return create(combinedIndex, combined);
}

@Override
protected boolean allowShallowIdentityRightMult() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,25 +549,12 @@ public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, List<AColGro
final IColIndex combinedColIndex = combineColIndexes(nCol, right);
final double[] combinedDefaultTuple = IContainDefaultTuple.combineDefaultTuples(_reference, right);

// return new ColGroupDDC(combinedColIndex, combined, _data, getCachedCounts());
return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
getCachedCounts());
return new ColGroupSDCFOR(combinedColIndex, this.getNumRows(), combined, _indexes, _data, getCachedCounts(),
combinedDefaultTuple);
}

@Override
public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, AColGroup right) {
// if(right instanceof ColGroupSDCZeros){
// ColGroupSDCZeros rightSDC = ((ColGroupSDCZeros) right);
// IDictionary b = rightSDC.getDictionary();
// IDictionary combined = DictionaryFactory.cBindDictionaries(_dict, b, this.getNumCols(), right.getNumCols());
// IColIndex combinedColIndex = _colIndexes.combine(right.getColIndices().shift(nCol));
// double[] combinedDefaultTuple = new double[_reference.length + right.getNumCols()];
// System.arraycopy(_reference, 0, combinedDefaultTuple, 0, _reference.length);

// return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
// getCachedCounts());
// }
// else{
ColGroupSDCFOR rightSDC = ((ColGroupSDCFOR) right);
IDictionary b = rightSDC.getDictionary();
IDictionary combined = DictionaryFactory.cBindDictionaries(_dict, b, this.getNumCols(), right.getNumCols());
Expand All @@ -576,9 +563,8 @@ public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, AColGroup ri
System.arraycopy(_reference, 0, combinedDefaultTuple, 0, _reference.length);
System.arraycopy(rightSDC._reference, 0, combinedDefaultTuple, _reference.length, rightSDC._reference.length);

return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
getCachedCounts());
// }
return new ColGroupSDCFOR(combinedColIndex, this.getNumRows(), combined, _indexes, _data, getCachedCounts(),
combinedDefaultTuple);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,60 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public final class CLALibAppend {
public final class CLALibCBind {

private CLALibAppend(){
private CLALibCBind() {
// private constructor.
}

private static final Log LOG = LogFactory.getLog(CLALibAppend.class.getName());
private static final Log LOG = LogFactory.getLog(CLALibCBind.class.getName());

public static MatrixBlock append(MatrixBlock left, MatrixBlock right, int k) {
/**
 * Column-bind an array of matrix blocks onto {@code left}.
 *
 * <p>
 * A single right-hand block is delegated to the pairwise {@link #cbind(MatrixBlock, MatrixBlock, int)}. With more
 * than one block, the n-way compressed path is taken only when {@code left} and every entry of {@code right} is a
 * {@link CompressedMatrixBlock}; otherwise the blocks are bound one after another with the pairwise method.
 *
 * @param left  The left-most block.
 * @param right The blocks to bind to the right of {@code left}, in order.
 * @param k     The parallelization degree, forwarded to the helpers.
 * @return The column-bound result.
 * @throws DMLCompressionException If any step fails; the original exception is attached as the cause.
 */
public static MatrixBlock cbind(MatrixBlock left, MatrixBlock[] right, int k) {
	try {
		if(right.length == 1)
			return cbind(left, right[0], k);

		// The n-way compressed path casts left as well as every right block, so left has to take part in the
		// check. Without it an uncompressed left with compressed rights would fail on the cast below.
		boolean allCompressed = left instanceof CompressedMatrixBlock;
		for(int i = 0; i < right.length && allCompressed; i++)
			allCompressed = right[i] instanceof CompressedMatrixBlock;

		if(allCompressed)
			return cbindAllCompressed((CompressedMatrixBlock) left, right, k);
		else
			return cbindAllNormalCompressed(left, right, k);
	}
	catch(InterruptedException e) {
		// Restore the interrupt flag so callers further up can still observe the interruption.
		Thread.currentThread().interrupt();
		throw new DMLCompressionException("Failed to Cbind with compressed input", e);
	}
	catch(Exception e) {
		throw new DMLCompressionException("Failed to Cbind with compressed input", e);
	}
}

/**
 * Fallback n-way cbind: binds the right-hand blocks onto the running result one at a time using the pairwise
 * cbind.
 *
 * @param left  The left-most block.
 * @param right The blocks to bind, applied in array order.
 * @param k     The parallelization degree passed to each pairwise cbind.
 * @return The result after the last block has been bound; {@code left} itself if {@code right} is empty.
 */
private static MatrixBlock cbindAllNormalCompressed(MatrixBlock left, MatrixBlock[] right, int k) {
	MatrixBlock acc = left;
	for(MatrixBlock next : right)
		acc = cbind(acc, next, k);
	return acc;
}

public static MatrixBlock cbind(MatrixBlock left, MatrixBlock right, int k) {

final int m = left.getNumRows();
final int n = left.getNumColumns() + right.getNumColumns();
Expand All @@ -66,15 +99,96 @@ else if(right.isEmpty() && left instanceof CompressedMatrixBlock)
final double spar = (left.getNonZeros() + right.getNonZeros()) / ((double) m * n);
final double estSizeUncompressed = MatrixBlock.estimateSizeInMemory(m, n, spar);
final double estSizeCompressed = left.getInMemorySize() + right.getInMemorySize();
// if(isAligned((CompressedMatrixBlock) left, (CompressedMatrixBlock) right))
// return combineCompressed((CompressedMatrixBlock) left, (CompressedMatrixBlock) right);
// else
if(estSizeUncompressed < estSizeCompressed)
return uc(left).append(uc(right), null);
else if(left instanceof CompressedMatrixBlock)
return appendRightUncompressed((CompressedMatrixBlock) left, right, m, n);
else
return appendLeftUncompressed(left, (CompressedMatrixBlock) right, m, n);
}
if(isAligned((CompressedMatrixBlock) left, (CompressedMatrixBlock) right))
return combineCompressed((CompressedMatrixBlock) left, (CompressedMatrixBlock) right);
else
return append((CompressedMatrixBlock) left, (CompressedMatrixBlock) right, m, n);
}

/**
 * N-way cbind where every input is expected to be a {@link CompressedMatrixBlock}.
 *
 * <p>
 * The aligned fast path is used only if every right-hand block has the same number of columns as {@code left} and
 * is reported as aligned with it. The first block that fails either test sends the whole call to the sequential
 * pairwise fallback.
 *
 * @param left  The compressed left-most block.
 * @param right The right-hand blocks; each one is cast to {@link CompressedMatrixBlock}.
 * @param k     The parallelization degree.
 * @return The column-bound result.
 * @throws InterruptedException If the aligned path is interrupted while waiting for its tasks.
 * @throws ExecutionException   If a task of the aligned path fails.
 */
private static MatrixBlock cbindAllCompressed(CompressedMatrixBlock left, MatrixBlock[] right, int k)
	throws InterruptedException, ExecutionException {
	final int leftCols = left.getNumColumns();
	for(MatrixBlock candidate : right) {
		final CompressedMatrixBlock compressed = (CompressedMatrixBlock) candidate;
		final boolean sameWidth = leftCols == candidate.getNumColumns();
		if(!(sameWidth && isAligned(left, compressed)))
			return cbindAllNormalCompressed(left, right, k);
	}
	return cbindAllCompressedAligned(left, right, k);
}

/**
 * Check whether two compressed blocks can be column-bound group by group.
 *
 * <p>
 * The blocks count as aligned when they have the same number of columns and, for every column group of
 * {@code left}, the group of {@code right} that holds the left group's first column index has the same index
 * structure and the same number of columns.
 *
 * @param left  The left compressed block.
 * @param right The right compressed block.
 * @return {@code true} if every left group has a structurally matching right group, {@code false} otherwise.
 */
private static boolean isAligned(CompressedMatrixBlock left, CompressedMatrixBlock right) {
	// The loop below only walks the left groups. With differing widths either a left column has no counterpart on
	// the right, or right-hand groups would go unmatched and be dropped by a group-wise combine.
	if(left.getNumColumns() != right.getNumColumns())
		return false;

	final List<AColGroup> gl = left.getColGroups();
	for(int j = 0; j < gl.size(); j++) {
		final AColGroup glj = gl.get(j);
		final int aColumnInGroup = glj.getColIndices().get(0);
		final AColGroup grj = right.getColGroupForColumn(aColumnInGroup);

		// Defensive: treat a missing right-hand group as not aligned instead of dereferencing it.
		if(grj == null || !glj.sameIndexStructure(grj) || glj.getNumCols() != grj.getNumCols())
			return false;
	}
	return true;
}

/**
 * Column-bind two compressed blocks by merging their column groups pairwise.
 *
 * <p>
 * Each left group is combined with the right group that holds the left group's first column index. The result has
 * the left row count, the summed column count and the summed non-zero count.
 *
 * @param left  The left compressed block.
 * @param right The right compressed block.
 * @return A new compressed block built from the combined groups.
 */
private static CompressedMatrixBlock combineCompressed(CompressedMatrixBlock left, CompressedMatrixBlock right) {
	final int nRow = left.getNumRows();
	final int nColLeft = left.getNumColumns();
	final List<AColGroup> leftGroups = left.getColGroups();
	final List<AColGroup> merged = new ArrayList<>(leftGroups.size());

	// Sequential for now; each pair is independent of the others.
	for(AColGroup leftGroup : leftGroups) {
		final int probeColumn = leftGroup.getColIndices().get(0);
		final AColGroup rightGroup = right.getColGroupForColumn(probeColumn);
		merged.add(leftGroup.combineWithSameIndex(nRow, nColLeft, rightGroup));
	}

	final int nColOut = nColLeft + right.getNumColumns();
	return new CompressedMatrixBlock(nRow, nColOut, left.getNonZeros() + right.getNonZeros(), false, merged);
}

/**
 * N-way cbind of compressed blocks, combining the column groups in parallel.
 *
 * <p>
 * One task is submitted per column group of {@code left}. Each task collects, from every right-hand block, the
 * group returned for the left group's first column index and combines them with the left group in a single call.
 *
 * @param left  The compressed left-most block.
 * @param right The right-hand blocks; each one is cast to {@link CompressedMatrixBlock} inside the tasks.
 * @param k     The parallelization degree used to obtain the thread pool.
 * @return A new compressed block holding the combined column groups.
 * @throws InterruptedException If waiting for a task result is interrupted.
 * @throws ExecutionException   If one of the combine tasks fails.
 */
private static CompressedMatrixBlock cbindAllCompressedAligned(CompressedMatrixBlock left, MatrixBlock[] right,
final int k) throws InterruptedException, ExecutionException {

final ExecutorService pool = CommonThreadPool.get(k);
try {
final List<AColGroup> gl = left.getColGroups();
final List<Future<AColGroup>> tasks = new ArrayList<>();
final int nCol = left.getNumColumns();
final int nRow = left.getNumRows();
// Submit one combine task per left column group.
for(int i = 0; i < gl.size(); i++) {
final AColGroup gli = gl.get(i);
tasks.add(pool.submit(() -> {
List<AColGroup> combines = new ArrayList<>();
// The first column index of the left group is used to look up the matching group in each right block.
final int cId = gli.getColIndices().get(0);
for(int j = 0; j < right.length; j++) {
combines.add(((CompressedMatrixBlock) right[j]).getColGroupForColumn(cId));
}
return gli.combineWithSameIndex(nRow, nCol, combines);
}));
}

// Collect results in submission order, so retCG follows the order of the left groups.
final List<AColGroup> retCG = new ArrayList<>(gl.size());
for(Future<AColGroup> t : tasks)
retCG.add(t.get());

// Assumes every right block has nCol columns; the visible caller cbindAllCompressed checks this beforehand.
int totalCol = nCol + right.length * nCol;

// NOTE(review): -1 is passed in the non-zero count position; presumably it marks the count as unknown. Confirm.
return new CompressedMatrixBlock(left.getNumRows(), totalCol, -1, false, retCG);
}
finally {
// Always release the pool, also when a task failed or the wait was interrupted.
pool.shutdown();
}

// NOTE(review): the statement below uses m and n, which are not declared in this method, and it follows a try
// block that always returns or throws. It looks like a leftover line from the previous cbind implementation
// shown by the diff view. Confirm against the real file.
return append((CompressedMatrixBlock) left, (CompressedMatrixBlock) right, m, n);
}

private static MatrixBlock appendLeftUncompressed(MatrixBlock left, CompressedMatrixBlock right, final int m,
Expand Down Expand Up @@ -123,17 +237,17 @@ private static MatrixBlock append(CompressedMatrixBlock left, CompressedMatrixBl
ret.setNonZeros(left.getNonZeros() + right.getNonZeros());
ret.setOverlapping(left.isOverlapping() || right.isOverlapping());

final double compressedSize = ret.getInMemorySize();
final double uncompressedSize = MatrixBlock.estimateSizeInMemory(m, n, ret.getSparsity());
// final double compressedSize = ret.getInMemorySize();
// final double uncompressedSize = MatrixBlock.estimateSizeInMemory(m, n, ret.getSparsity());

if(compressedSize < uncompressedSize)
return ret;
else {
final double ratio = uncompressedSize / compressedSize;
String message = String.format("Decompressing c bind matrix because it had to small compression ratio: %2.3f",
ratio);
return ret.getUncompressed(message);
}
// if(compressedSize < uncompressedSize)
return ret;
// else {
// final double ratio = uncompressedSize / compressedSize;
// String message = String.format("Decompressing c bind matrix because it had to small compression ratio: %2.3f",
// ratio);
// return ret.getUncompressed(message);
// }
}

private static MatrixBlock appendRightEmpty(CompressedMatrixBlock left, MatrixBlock right, int m, int n) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
Expand All @@ -46,8 +46,9 @@ public void processInstruction(ExecutionContext ec) {
validateInput(matBlock1, matBlock2);

MatrixBlock ret;
if(matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock)
ret = CLALibAppend.append(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism());
if(_type == AppendType.CBIND &&
(matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock))
ret = CLALibCBind.cbind(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism());
else
ret = matBlock1.append(matBlock2, new MatrixBlock(), _type == AppendType.CBIND);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
Expand Down Expand Up @@ -3654,10 +3655,19 @@ public final MatrixBlock append(MatrixBlock that, MatrixBlock ret ) {
return append(that, ret, true); //default cbind
}

public static MatrixBlock append(List<MatrixBlock> that,MatrixBlock ret, boolean cbind, int k ){
MatrixBlock[] th = new MatrixBlock[that.size() -1];
for(int i = 0; i < that.size() -1; i++)
th[i] = that.get(i+1);
/**
 * Append a list of matrix blocks into a single block.
 *
 * <p>
 * The first block of the list acts as the receiver; all remaining blocks are passed, in list order, to its
 * {@code append(MatrixBlock[], MatrixBlock, boolean)} method.
 *
 * @param that  The blocks to append; the element at index 0 is the receiver of the append call.
 * @param ret   The output block.
 * @param cbind Whether the blocks are appended column-wise (cbind).
 * @param k     The parallelization degree (not used by this method's body).
 * @return The appended matrix.
 */
public static MatrixBlock append(List<MatrixBlock> that, MatrixBlock ret, boolean cbind, int k) {
	final int nRest = that.size() - 1;
	final MatrixBlock[] rest = new MatrixBlock[nRest];
	for(int pos = 1; pos <= nRest; pos++)
		rest[pos - 1] = that.get(pos);
	final MatrixBlock first = that.get(0);
	return first.append(rest, ret, cbind);
}

Expand Down Expand Up @@ -3716,6 +3726,13 @@ private final int computeNNzRow(MatrixBlock[] that, int row) {
public MatrixBlock append(MatrixBlock[] that, MatrixBlock result, boolean cbind) {
checkDimensionsForAppend(that, cbind);

for(int k = 0; k < that.length; k++)
if( that[k] instanceof CompressedMatrixBlock){
if(that.length == 1 && cbind)
return CLALibCBind.cbind(this, that[0], 1);
that[k] = CompressedMatrixBlock.getUncompressed(that[k], "Append N");
}

final int m = cbind ? rlen : combinedRows(that);
final int n = cbind ? combinedCols(that) : clen;
final long nnz = calculateCombinedNNz(that);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
Expand Down Expand Up @@ -395,4 +396,10 @@ public void manyRowsButNotQuite() {
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(m1).getLeft();
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
}


/**
 * The n-way cbind must fail with an exception when it is given a null left block and a null right block.
 */
@Test(expected = Exception.class)
public void cbindWithError(){
// Both the left block and the single right-hand entry are null; any exception type satisfies the test.
CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
}
}
Loading

0 comments on commit b751389

Please sign in to comment.