From 749ec564eb4e9d56b0da4227b4d85a1ada0de61a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 31 Jan 2025 00:48:47 +0100 Subject: [PATCH] [SYSTEMDS-3824] Decompressing Transpose Sebastian Baunsgaard introduced a new CLALib for Reorg, specifically Transpose e-strauss applied minor changes: - a manual rewrite in bultin kmeans script to use argmin (reduced runtime by 18%) - added new decompressing transpose to DenseBlock from SparseBlock for ColGroupDDC - fixed bug in sparsity evaluation in decompressed transposed (switch nrow w/ ncol) - minor bug fix in regarding the cached decompression count --- scripts/builtin/kmeans.dml | 6 +- .../compress/CompressedMatrixBlock.java | 19 +-- .../compress/colgroup/ColGroupDDC.java | 16 +- .../runtime/compress/lib/CLALibReorg.java | 158 ++++++++++++++++++ .../cp/CompressionCPInstruction.java | 4 +- .../compress/configuration/CompressBase.java | 8 +- 6 files changed, 189 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml index 7fdd320a164..6052e5e6e92 100644 --- a/scripts/builtin/kmeans.dml +++ b/scripts/builtin/kmeans.dml @@ -145,9 +145,11 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer } # Find the closest centroid for each record - P = D <= minD; + # P = D <= minD; # If some records belong to multiple centroids, share them equally - P = P / rowSums (P); + # P = P / rowSums (P); + P = table(seq(1,nrow(D)), rowIndexMin(D)) + # P = table(seq(1,nrow(D)),compress(rowIndexMin(D))) # Compute the column normalization factor for P P_denom = colSums (P); # Compute new centroids as weighted averages over the records diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index c78d651ff00..48637595741 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -59,6 +59,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; import org.apache.sysds.runtime.compress.lib.CLALibReplace; +import org.apache.sysds.runtime.compress.lib.CLALibReorg; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; @@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double @Override public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) { - if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) { - MatrixBlock tmp = decompress(op.getNumThreads()); - long nz = tmp.setNonZeros(tmp.getNonZeros()); - tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); - tmp.setNonZeros(nz); - return tmp; - } - else { - // Allow transpose to be compressed output. In general we need to have a transposed flag on - // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 - String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName(); - MatrixBlock tmp = getUncompressed(message, op.getNumThreads()); - return tmp.reorgOperations(op, ret, startRow, startColumn, length); - } - + return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length); } public boolean isOverlapping() { @@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype @Override public MatrixBlock transpose(int k) { - return getUncompressed().transpose(k); + return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index c1b9c65f229..e55a24e56f5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i @Override protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { - throw new NotImplementedException(); + for(int i = rl; i < ru; i++) { + final int vr = _data.getIndex(i); + if(sb.isEmpty(vr)) + continue; + final int apos = sb.pos(vr); + final int alen = sb.size(vr) + apos; + final int[] aix = sb.indexes(vr); + final double[] aval = sb.values(vr); + for(int j = apos; j < alen; j++) { + final int rowOut = _colIndexes.get(aix[j]); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); + c[off + i] += aval[j]; + } + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java new file mode 100644 index 00000000000..d587d26c3cb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class CLALibReorg { + + protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName()); + + public static boolean warned = false; + + public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow, + int startColumn, int length) { + // SwapIndex is transpose + if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) { + MatrixBlock tmp = cmb.decompress(op.getNumThreads()); + long nz = tmp.setNonZeros(tmp.getNonZeros()); + if(tmp.isInSparseFormat()) + return LibMatrixReorg.transpose(tmp); // edge case... + else + tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); + tmp.setNonZeros(nz); + return tmp; + } + else if(op.fn instanceof SwapIndex) { + MatrixBlock tmp = cmb.getCachedDecompressed(); + if(tmp != null) + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + // Allow transpose to be compressed output. In general we need to have a transposed flag on + // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 + return transpose(cmb, ret, op.getNumThreads()); + } + else { + String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null; + MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads()); + warned = true; + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + } + } + + private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + + final long nnz = cmb.getNonZeros(); + final int nRow = cmb.getNumRows(); + final int nCol = cmb.getNumColumns(); + final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol,nRow, nnz); + if(sparseOut) + return transposeSparse(cmb, ret, k, nRow, nCol, nnz); + else + return transposeDense(cmb, ret, k, nRow, nCol, nnz); + } + + private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, + long nnz) { + if(ret == null) + ret = new MatrixBlock(nCol, nRow, true, nnz); + else + ret.reset(nCol, nRow, true, nnz); + + ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR); + + final int nColOut = ret.getNumColumns(); + + if(k > 1 && cmb.getColGroups().size() > 1) + decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k); + else + decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut); + + return ret; + } + + private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, + long nnz) { + if(ret == null) + ret = new MatrixBlock(nCol, nRow, false, nnz); + else + ret.reset(nCol, nRow, false, nnz); + + // TODO: parallelize + ret.allocateDenseBlock(); + + decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow); + return ret; + } + + private static void decompressToTransposedDense(DenseBlock ret, List groups, int rlen, int rl, int ru) { + for(int i = 0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + g.decompressToDenseBlockTransposed(ret, rl, ru); + } + } + + private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List groups, + int nColOut) { + for(int i = 0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + g.decompressToSparseBlockTransposed(ret, nColOut); + } + } + + private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List groups, int nColOut, + int k) { + final ExecutorService pool = CommonThreadPool.get(k); + try { + final List> tasks = new ArrayList<>(groups.size()); + + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut))); + } + + for(Future f : tasks) + f.get(); + + } + catch(Exception e) { + throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e); + } + finally { + pool.shutdown(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index 7d0d9f78704..4216385b722 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) { else if(ec.isMatrixObject(input1.getName())) processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root); else { - throw new NotImplementedException("Not supported other types of input for compression than frame and matrix"); + LOG.warn("Compression on Scalar should not happen"); + ScalarObject Scalar = ec.getScalarInput(input1); + ec.setScalarOutput(output.getName(),Scalar); } } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java index 3d61b4942c7..1ddfc09258b 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java @@ -32,6 +32,8 @@ import org.apache.sysds.utils.Statistics; import org.junit.Assert; +import java.io.ByteArrayOutputStream; + public abstract class CompressBase extends AutomatedTestBase { // private static final Log LOG = LogFactory.getLog(CompressBase.class.getName()); @@ -66,7 +68,8 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType, fullDMLScriptName = SCRIPT_DIR + "/functions/compress/compress_" + name + ".dml"; programArgs = new String[] {"-stats", "100", "-nvargs", "A=" + input("A")}; - String out = runTest(null).toString(); + ByteArrayOutputStream tmp = runTest(null); + String out = tmp != null ? runTest(null).toString() : ""; int decompressCount = DMLCompressionStatistics.getDecompressionCount(); long compressionCount = (instType == ExecType.SPARK) ? Statistics @@ -74,7 +77,8 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType, DMLCompressionStatistics.reset(); Assert.assertEquals(out + "\ncompression count wrong : ", compressionCount, compressionCountsExpected); - Assert.assertTrue(out + "\nDecompression count wrong : ", + Assert.assertTrue(out + "\nDecompression count wrong : " + decompressCount + + (decompressionCountExpected >= 0 ? " [expected: " + decompressionCountExpected+ "]" : ""), decompressionCountExpected >= 0 ? decompressionCountExpected == decompressCount : decompressCount > 1); }