Skip to content

Commit

Permalink
Updates to CompressedMatrixBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Jan 21, 2025
1 parent 399a459 commit fb0d81a
Showing 1 changed file with 64 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import java.io.ObjectOutput;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

Expand All @@ -42,9 +44,11 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.compress.lib.CLALibCMOps;
Expand Down Expand Up @@ -99,14 +103,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName());
private static final long serialVersionUID = 73193720143154058L;

/**
* Debugging flag for Compressed Matrices
*/
/** Debugging flag for Compressed Matrices */
public static boolean debug = false;

/**
* Column groups
*/
/** Disallow caching of uncompressed Block */
public static boolean allowCachingUncompressed = true;

/** Column groups */
protected transient List<AColGroup> _colGroups;

/**
Expand All @@ -119,6 +122,9 @@ public class CompressedMatrixBlock extends MatrixBlock {
*/
protected transient SoftReference<MatrixBlock> decompressedVersion;

/** Cached Memory size */
protected transient long cachedMemorySize = -1;

public CompressedMatrixBlock() {
super(true);
sparse = false;
Expand Down Expand Up @@ -169,7 +175,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) {
clen = uncompressedMatrixBlock.getNumColumns();
sparse = false;
nonZeros = uncompressedMatrixBlock.getNonZeros();
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
if(!(uncompressedMatrixBlock instanceof CompressedMatrixBlock)) {
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
}
}

/**
Expand All @@ -189,6 +197,7 @@ public CompressedMatrixBlock(int rl, int cl, long nnz, boolean overlapping, List
this.nonZeros = nnz;
this.overlappingColGroups = overlapping;
this._colGroups = groups;
getInMemorySize(); // cache memory size
}

@Override
Expand All @@ -204,6 +213,7 @@ public void reset(int rl, int cl, boolean sp, long estnnz, double val) {
* @param cg The column group to use after.
*/
public void allocateColGroup(AColGroup cg) {
cachedMemorySize = -1;
_colGroups = new ArrayList<>(1);
_colGroups.add(cg);
}
Expand Down Expand Up @@ -270,6 +280,12 @@ public synchronized MatrixBlock decompress(int k) {

ret = CLALibDecompress.decompress(this, k);

if(ret.getNonZeros() <= 0) {
LOG.warn("Decompress incorrectly set nnz to 0 or -1");
ret.recomputeNonZeros(k);
}
ret.examSparsity(k);

// Set soft reference to the decompressed version
decompressedVersion = new SoftReference<>(ret);

Expand All @@ -290,7 +306,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
* @return The cached decompressed matrix, if it does not exist return null
*/
public MatrixBlock getCachedDecompressed() {
if(decompressedVersion != null) {
if( allowCachingUncompressed && decompressedVersion != null) {
final MatrixBlock mb = decompressedVersion.get();
if(mb != null) {
DMLCompressionStatistics.addDecompressCacheCount();
Expand All @@ -302,6 +318,7 @@ public MatrixBlock getCachedDecompressed() {
}

public CompressedMatrixBlock squash(int k) {
cachedMemorySize = -1;
return CLALibSquash.squash(this, k);
}

Expand Down Expand Up @@ -377,12 +394,27 @@ public long estimateSizeInMemory() {
* @return an upper bound on the memory used to store this compressed block considering class overhead.
*/
public long estimateCompressedSizeInMemory() {
long total = baseSizeInMemory();

for(AColGroup grp : _colGroups)
total += grp.estimateInMemorySize();
if(cachedMemorySize <= -1L) {

long total = baseSizeInMemory();
// take into consideration duplicate dictionaries
Set<IDictionary> dicts = new HashSet<>();
for(AColGroup grp : _colGroups){
if(grp instanceof ADictBasedColGroup){
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
if(dicts.contains(dg))
total -= dg.getInMemorySize();
dicts.add(dg);
}
total += grp.estimateInMemorySize();
}
cachedMemorySize = total;
return total;

return total;
}
else
return cachedMemorySize;
}

public static long baseSizeInMemory() {
Expand All @@ -392,6 +424,7 @@ public static long baseSizeInMemory() {
total += 8; // Col Group Ref
total += 8; // v reference
total += 8; // soft reference to decompressed version
total += 8; // long cached memory size
total += 1 + 7; // Booleans plus padding

total += 40; // Col Group Array List
Expand Down Expand Up @@ -431,6 +464,7 @@ public long estimateSizeOnDisk() {

@Override
public void readFields(DataInput in) throws IOException {
cachedMemorySize = -1;
// deserialize compressed block
rlen = in.readInt();
clen = in.readInt();
Expand Down Expand Up @@ -736,8 +770,22 @@ public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows,

@Override
public boolean isEmptyBlock(boolean safe) {
final long nonZeros = getNonZeros();
return _colGroups == null || nonZeros == 0 || (nonZeros == -1 && recomputeNonZeros() == 0);
if(nonZeros > 1)
return false;
else if(_colGroups == null || nonZeros == 0)
return true;
else{
if(nonZeros == -1){
// try to use column groups
for(AColGroup g : _colGroups)
if(!g.isEmpty())
return false;
// Otherwise recompute non zeros.
recomputeNonZeros();
}

return getNonZeros() == 0;
}
}

@Override
Expand Down Expand Up @@ -1045,6 +1093,7 @@ public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareD
}

private void copyCompressedMatrix(CompressedMatrixBlock that) {
cachedMemorySize = -1;
this.rlen = that.getNumRows();
this.clen = that.getNumColumns();
this.sparseBlock = null;
Expand All @@ -1059,7 +1108,7 @@ private void copyCompressedMatrix(CompressedMatrixBlock that) {
}

public SoftReference<MatrixBlock> getSoftReferenceToDecompressed() {
return decompressedVersion;
return allowCachingUncompressed ? decompressedVersion : null;
}

public void clearSoftReferenceToDecompressed() {
Expand Down

0 comments on commit fb0d81a

Please sign in to comment.