Skip to content

Commit

Permalink
[SYSTEMDS-3824] Decompressing Transpose
Browse files Browse the repository at this point in the history
Sebastian Baunsgaard <[email protected]> introduced a new CLALib for Reorg, specifically Transpose

e-strauss <[email protected]> applied minor changes:
- a manual rewrite in the builtin kmeans script to use argmin (reduced runtime by 18%)
- added new decompressing transpose to DenseBlock from SparseBlock for ColGroupDDC
- fixed bug in sparsity evaluation in decompressed transposed (switch nrow w/ ncol)
- minor bug fix regarding the cached decompression count
  • Loading branch information
Baunsgaard authored and e-strauss committed Feb 4, 2025
1 parent e022eaf commit 749ec56
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 22 deletions.
6 changes: 4 additions & 2 deletions scripts/builtin/kmeans.dml
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
}

# Find the closest centroid for each record
P = D <= minD;
# P = D <= minD;
# If some records belong to multiple centroids, share them equally
P = P / rowSums (P);
# P = P / rowSums (P);
P = table(seq(1,nrow(D)), rowIndexMin(D))
# P = table(seq(1,nrow(D)),compress(rowIndexMin(D)))
# Compute the column normalization factor for P
P_denom = colSums (P);
# Compute new centroids as weighted averages over the records
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
Expand Down Expand Up @@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double

@Override
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
MatrixBlock tmp = decompress(op.getNumThreads());
long nz = tmp.setNonZeros(tmp.getNonZeros());
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
tmp.setNonZeros(nz);
return tmp;
}
else {
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}

return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
}

public boolean isOverlapping() {
Expand Down Expand Up @@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype

@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i

@Override
protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
throw new NotImplementedException();
for(int i = rl; i < ru; i++) {
final int vr = _data.getIndex(i);
if(sb.isEmpty(vr))
continue;
final int apos = sb.pos(vr);
final int alen = sb.size(vr) + apos;
final int[] aix = sb.indexes(vr);
final double[] aval = sb.values(vr);
for(int j = apos; j < alen; j++) {
final int rowOut = _colIndexes.get(aix[j]);
final double[] c = db.values(rowOut);
final int off = db.pos(rowOut);
c[off + i] += aval[j];
}
}
}

@Override
Expand Down
158 changes: 158 additions & 0 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Library of reorganization operations on compressed matrix blocks.
 *
 * Transpose (SwapIndex) is handled by dedicated code paths that decompress directly into a transposed output. Every
 * other reorg function is executed on the uncompressed matrix.
 */
public class CLALibReorg {

	protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());

	/**
	 * Set to true after the first generic (non-transpose) fallback in {@link #reorg}; subsequent fallbacks pass a null
	 * message to getUncompressed. Plain non-volatile static, so concurrent callers may each see it as false once.
	 */
	public static boolean warned = false;

	/**
	 * Apply a reorg operation to a compressed matrix block.
	 *
	 * @param cmb         The compressed input
	 * @param op          The reorg operator; op.fn selects the operation and op.getNumThreads() the parallelization
	 * @param ret         Optional output block, reused by the general transpose path when not null. It is ignored by
	 *                    the single column transpose path.
	 * @param startRow    Forwarded to MatrixBlock.reorgOperations in the uncompressed paths
	 * @param startColumn Forwarded to MatrixBlock.reorgOperations in the uncompressed paths
	 * @param length      Forwarded to MatrixBlock.reorgOperations in the uncompressed paths
	 * @return The reorganized, uncompressed matrix block
	 */
	public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
		int startColumn, int length) {
		// SwapIndex is transpose
		if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
			// Single column input: decompress and build the result with swapped dimensions.
			MatrixBlock tmp = cmb.decompress(op.getNumThreads());
			long nz = tmp.setNonZeros(tmp.getNonZeros());
			if(tmp.isInSparseFormat())
				return LibMatrixReorg.transpose(tmp); // edge case...
			else
				// Dense: construct a new block from the dense values with rows and columns swapped.
				tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
			tmp.setNonZeros(nz);
			return tmp;
		}
		else if(op.fn instanceof SwapIndex) {
			// If a decompressed version is already cached, transpose that instead of decompressing again.
			MatrixBlock tmp = cmb.getCachedDecompressed();
			if(tmp != null)
				return tmp.reorgOperations(op, ret, startRow, startColumn, length);
			// Allow transpose to be compressed output. In general we need to have a transposed flag on
			// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
			return transpose(cmb, ret, op.getNumThreads());
		}
		else {
			// Generic fallback: only the first call supplies a descriptive message (presumably used for a
			// decompression warning -- confirm against getUncompressed).
			String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
			MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
			warned = true;
			return tmp.reorgOperations(op, ret, startRow, startColumn, length);
		}
	}

	/**
	 * Transpose by decompressing the column groups directly into a transposed output. The output format (sparse or
	 * dense) is decided via MatrixBlock.evalSparseFormatInMemory on the output dimensions (nCol x nRow).
	 *
	 * @param cmb The compressed input
	 * @param ret Optional output block to reuse, may be null
	 * @param k   The parallelization degree
	 * @return The transposed block
	 */
	private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {

		final long nnz = cmb.getNonZeros();
		final int nRow = cmb.getNumRows();
		final int nCol = cmb.getNumColumns();
		// Evaluate sparsity with the dimensions of the transposed output.
		final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol, nRow, nnz);
		if(sparseOut)
			return transposeSparse(cmb, ret, k, nRow, nCol, nnz);
		else
			return transposeDense(cmb, ret, k, nRow, nCol, nnz);
	}

	/**
	 * Transpose into a sparse MCSR output of dimensions nCol x nRow.
	 *
	 * NOTE(review): the non-zero count of ret is not set explicitly in this method -- confirm it is maintained
	 * elsewhere.
	 *
	 * @param cmb  The compressed input
	 * @param ret  Optional output block to reset and reuse, may be null
	 * @param k    The parallelization degree; the parallel path is taken only if k > 1 and there is more than one
	 *             column group
	 * @param nRow Number of rows of the input
	 * @param nCol Number of columns of the input
	 * @param nnz  Non-zero count passed to the output block constructor / reset
	 * @return The transposed sparse block
	 */
	private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
		long nnz) {
		if(ret == null)
			ret = new MatrixBlock(nCol, nRow, true, nnz);
		else
			ret.reset(nCol, nRow, true, nnz);

		ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);

		final int nColOut = ret.getNumColumns();

		if(k > 1 && cmb.getColGroups().size() > 1)
			decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k);
		else
			decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut);

		return ret;
	}

	/**
	 * Transpose into a dense output of dimensions nCol x nRow. Currently single threaded; k is unused.
	 *
	 * NOTE(review): the non-zero count of ret is not set explicitly in this method -- confirm it is maintained
	 * elsewhere.
	 *
	 * @param cmb  The compressed input
	 * @param ret  Optional output block to reset and reuse, may be null
	 * @param k    The parallelization degree (unused until this path is parallelized)
	 * @param nRow Number of rows of the input
	 * @param nCol Number of columns of the input
	 * @param nnz  Non-zero count passed to the output block constructor / reset
	 * @return The transposed dense block
	 */
	private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
		long nnz) {
		if(ret == null)
			ret = new MatrixBlock(nCol, nRow, false, nnz);
		else
			ret.reset(nCol, nRow, false, nnz);

		// TODO: parallelize
		ret.allocateDenseBlock();

		decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);
		return ret;
	}

	/**
	 * Decompress the input rows [rl, ru) of every column group into the transposed dense output.
	 *
	 * @param ret    The dense output block
	 * @param groups The column groups to decompress
	 * @param rlen   Number of rows of the input (currently unused)
	 * @param rl     First input row (inclusive)
	 * @param ru     Last input row (exclusive)
	 */
	private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
		for(int i = 0; i < groups.size(); i++) {
			AColGroup g = groups.get(i);
			g.decompressToDenseBlockTransposed(ret, rl, ru);
		}
	}

	/**
	 * Decompress every column group, one after another, into the transposed sparse output.
	 *
	 * @param ret     The sparse output block
	 * @param groups  The column groups to decompress
	 * @param nColOut Number of columns of the output
	 */
	private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
		int nColOut) {
		for(int i = 0; i < groups.size(); i++) {
			AColGroup g = groups.get(i);
			g.decompressToSparseBlockTransposed(ret, nColOut);
		}
	}

	/**
	 * Decompress the column groups into the transposed sparse output, submitting one task per column group to a
	 * thread pool and waiting for all of them.
	 *
	 * @param ret     The sparse output block, shared by all tasks
	 * @param groups  The column groups to decompress
	 * @param nColOut Number of columns of the output
	 * @param k       The parallelization degree used to obtain the thread pool
	 * @throws DMLCompressionException if a task fails or the waiting thread is interrupted
	 */
	private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
		int k) {
		final ExecutorService pool = CommonThreadPool.get(k);
		try {
			final List<Future<?>> tasks = new ArrayList<>(groups.size());

			for(int i = 0; i < groups.size(); i++) {
				final AColGroup g = groups.get(i);
				tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
			}

			for(Future<?> f : tasks)
				f.get();

		}
		catch(InterruptedException e) {
			// Restore the interrupt flag so code further up the stack can still observe the interruption.
			Thread.currentThread().interrupt();
			throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
		}
		catch(Exception e) {
			throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
		}
		finally {
			pool.shutdown();
		}
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) {
else if(ec.isMatrixObject(input1.getName()))
processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root);
else {
throw new NotImplementedException("Not supported other types of input for compression than frame and matrix");
LOG.warn("Compression on Scalar should not happen");
ScalarObject Scalar = ec.getScalarInput(input1);
ec.setScalarOutput(output.getName(),Scalar);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;

import java.io.ByteArrayOutputStream;

public abstract class CompressBase extends AutomatedTestBase {
// private static final Log LOG = LogFactory.getLog(CompressBase.class.getName());

Expand Down Expand Up @@ -66,15 +68,17 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType,
fullDMLScriptName = SCRIPT_DIR + "/functions/compress/compress_" + name + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs", "A=" + input("A")};

String out = runTest(null).toString();
ByteArrayOutputStream tmp = runTest(null);
String out = tmp != null ? runTest(null).toString() : "";

int decompressCount = DMLCompressionStatistics.getDecompressionCount();
long compressionCount = (instType == ExecType.SPARK) ? Statistics
.getCPHeavyHitterCount("sp_compress") : Statistics.getCPHeavyHitterCount("compress");
DMLCompressionStatistics.reset();

Assert.assertEquals(out + "\ncompression count wrong : ", compressionCount, compressionCountsExpected);
Assert.assertTrue(out + "\nDecompression count wrong : ",
Assert.assertTrue(out + "\nDecompression count wrong : " + decompressCount +
(decompressionCountExpected >= 0 ? " [expected: " + decompressionCountExpected+ "]" : ""),
decompressionCountExpected >= 0 ? decompressionCountExpected == decompressCount : decompressCount > 1);

}
Expand Down

0 comments on commit 749ec56

Please sign in to comment.