Skip to content

Commit

Permalink
[SYSTEMDS-3660] GPU cache eviction operator and related rewrite
Browse files Browse the repository at this point in the history
This patch introduces a new operator, _evict, to clean up the free
pointers cached in the lineage cache. A shift in the allocation pattern
leads to large eviction overhead and memory fragmentation. To address
that, we speculatively clear a fraction of the free pointers. Currently,
we place an _evict before every mini-batch processing loop.
  • Loading branch information
phaniarnab committed Dec 28, 2023
1 parent 4b0133f commit 33149f8
Show file tree
Hide file tree
Showing 12 changed files with 953 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ public enum OpOp1 {
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/conf/ConfigurationManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ public static boolean isRuleBasedGPUPlacement() {
|| OptimizerUtils.RULE_BASED_GPU_EXEC));
}

/**
 * Reports whether automatic placement of GPU lineage cache eviction is enabled.
 *
 * @return the current value of {@code OptimizerUtils.AUTO_GPU_CACHE_EVICTION}
 */
public static boolean isAutoEvictionEnabled() {
return OptimizerUtils.AUTO_GPU_CACHE_EVICTION;
}

public static ILinearize.DagLinearization getLinearizationOrder() {
if (OptimizerUtils.COST_BASED_ORDERING)
return ILinearize.DagLinearization.AUTO;
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ public enum MemoryManager {
*/
public static boolean RULE_BASED_GPU_EXEC = false;

/**
* Enables the automatic placement of GPU lineage cache eviction (_evict) instructions
*/

public static boolean AUTO_GPU_CACHE_EVICTION = true;

//////////////////////
// Optimizer levels //
//////////////////////
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/hops/UnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ public Lop constructLops()
case LOCAL:
ret = new Local(input.constructLops(), getDataType(), getValueType());
break;
case _EVICT:
ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
break;
default:
final boolean isScalarIn = getInput().get(0).getDataType() == DataType.SCALAR;
if(getDataType() == DataType.SCALAR // value type casts or matrix to scalar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public LopRewriter() {
_lopSBRuleSet.add(new RewriteAddBroadcastLop());
_lopSBRuleSet.add(new RewriteAddChkpointLop());
_lopSBRuleSet.add(new RewriteAddChkpointInLoop());
_lopSBRuleSet.add(new RewriteAddGPUEvictLop());
// TODO: A rewrite pass to remove less effective chkpoints
// Last rewrite to reset Lop IDs in a depth-first manner
_lopSBRuleSet.add(new RewriteFixIDs());
Expand Down
115 changes: 115 additions & 0 deletions src/main/java/org/apache/sysds/lops/rewrite/RewriteAddGPUEvictLop.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.lops.rewrite;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.lops.BinaryScalar;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

import java.util.ArrayList;
import java.util.List;

/**
 * Lop rewrite that places a speculative GPU lineage cache eviction in front of
 * mini-batch loops. When a for-loop's first body block slices a transient-read
 * dataset with the typical beg/end mini-batch pattern, a new statement block
 * holding a single {@code _evict} instruction is inserted directly before the loop.
 */
public class RewriteAddGPUEvictLop extends LopRewriteRule
{
	// Literal input of the generated _evict instruction, expressed in percent
	// (EvictCPInstruction divides this value by 100 to obtain the eviction fraction).
	private static final int EVICT_PERCENT = 100;

	/**
	 * Inserts an {@code _evict} statement block before {@code sb} if {@code sb}
	 * is a for-loop that performs mini-batch slicing.
	 *
	 * @param sb statement block to inspect; may be null
	 * @return a list holding only {@code sb} if the rewrite does not apply,
	 *         otherwise the new eviction block followed by {@code sb}
	 */
	@Override
	public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
		// TODO: Move this as a Statement block rewrite
		if (!ConfigurationManager.isAutoEvictionEnabled())
			return unchanged(sb);

		// instanceof is false for null, so this also covers sb == null
		if (!(sb instanceof ForStatementBlock)
			|| !DMLScript.USE_ACCELERATOR || LineageCacheConfig.ReuseCacheType.isNone())
			return unchanged(sb);

		// Collect the LOPs of the first statement block of the loop body
		// (only this first block is inspected for the slicing pattern)
		ForStatement fstmt = (ForStatement) sb.getStatement(0);
		if (fstmt.getBody() == null || fstmt.getBody().isEmpty())
			return unchanged(sb); //empty loop body, nothing to match
		StatementBlock csb = fstmt.getBody().get(0);
		ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(csb);

		// Check if this loop is for mini-batch processing
		if (!findMiniBatchSlicing(lops))
			return unchanged(sb);

		// Insert statement block with _evict instruction before the loop
		ArrayList<StatementBlock> ret = new ArrayList<>();
		ret.add(createEvictBlock(sb));
		ret.add(sb);
		return ret;
	}

	@Override
	public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
		// No rewrite across statement blocks
		return sbs;
	}

	// Wraps the input in a single-element list. Unlike List.of, this tolerates
	// a null input instead of throwing a NullPointerException.
	private static List<StatementBlock> unchanged(StatementBlock sb) {
		List<StatementBlock> ret = new ArrayList<>(1);
		ret.add(sb);
		return ret;
	}

	// Builds the statement block that holds the _evict instruction, taking the
	// DML program and parse info from the loop block it will precede.
	private static StatementBlock createEvictBlock(StatementBlock sb) {
		StatementBlock sb0 = new StatementBlock();
		sb0.setDMLProg(sb.getDMLProg());
		sb0.setParseInfo(sb);
		sb0.setLiveIn(new VariableSet());
		sb0.setLiveOut(new VariableSet());
		// Create both lops and hops (hops for recompilation)
		// TODO: Add another input for the backend (GPU/CPU/Spark)
		ArrayList<Lop> newlops = new ArrayList<>();
		ArrayList<Hop> newhops = new ArrayList<>();
		Lop fr = Data.createLiteralLop(Types.ValueType.INT64, Integer.toString(EVICT_PERCENT));
		fr.getOutputParameters().setDimensions(0, 0, 0, -1);
		UnaryCP evict = new UnaryCP(fr, Types.OpOp1._EVICT, fr.getDataType(), fr.getValueType(), Types.ExecType.CP);
		Hop in = new LiteralOp(EVICT_PERCENT);
		Hop evictHop = new UnaryOp("tmp", Types.DataType.SCALAR, Types.ValueType.INT64, Types.OpOp1._EVICT, in);
		newlops.add(evict);
		newhops.add(evictHop);
		sb0.setLops(newlops);
		sb0.setHops(newhops);
		return sb0;
	}

	// To verify mini-batch processing, match the below pattern
	// beg = ((i-1) * batch_size) %% N + 1;
	// end = min(N, beg+batch_size-1);
	// X_batch = X[beg:end];
	private boolean findMiniBatchSlicing(ArrayList<Lop> lops) {
		// Defensive: no lop list means there is nothing to match
		if (lops == null)
			return false;
		for (Lop l : lops) {
			if (!(l instanceof RightIndex))
				continue;
			ArrayList<Lop> inputs = l.getInputs();
			// The pattern reads inputs 0..2; skip anything with fewer inputs
			if (inputs == null || inputs.size() < 3)
				continue;
			if (inputs.get(0) instanceof Data && ((Data) inputs.get(0)).isTransientRead()
				&& inputs.get(0).getInputs().size() == 0 //input1 is the dataset
				&& inputs.get(1) instanceof BinaryScalar //input2 is beg
				&& ((BinaryScalar) inputs.get(1)).getOperationType() == Types.OpOp2.PLUS
				&& inputs.get(2) instanceof BinaryScalar //input3 is end
				&& ((BinaryScalar) inputs.get(2)).getOperationType() == Types.OpOp2.MIN)
				return true;
		}
		return false;
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.DeCompressionCPInstruction;
import org.apache.sysds.runtime.instructions.cp.DnnCPInstruction;
import org.apache.sysds.runtime.instructions.cp.EvictCPInstruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.LocalCPInstruction;
Expand Down Expand Up @@ -337,6 +338,7 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( DeCompression.OPCODE, CPType.DeCompression);
String2CPInstructionType.put( "spoof", CPType.SpoofFused);
String2CPInstructionType.put( "prefetch", CPType.Prefetch);
String2CPInstructionType.put( "_evict", CPType.EvictLineageCache);
String2CPInstructionType.put( "broadcast", CPType.Broadcast);
String2CPInstructionType.put( "trigremote", CPType.TrigRemote);
String2CPInstructionType.put( Local.OPCODE, CPType.Local);
Expand Down Expand Up @@ -483,6 +485,9 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str )

case Broadcast:
return BroadcastCPInstruction.parseInstruction(str);

case EvictLineageCache:
return EvictCPInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public enum CPType {
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused,
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
EvictLineageCache,
NoOp,
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.cp;

import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * CP instruction for the {@code _evict} opcode. It reads a scalar percentage
 * and asks the GPU lineage cache to evict the corresponding fraction of its
 * cached entries.
 */
public class EvictCPInstruction extends UnaryCPInstruction
{
	private EvictCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
		super(CPType.EvictLineageCache, op, in, out, opcode, istr);
	}

	/**
	 * Builds the instruction from its string form.
	 *
	 * @param str instruction string; validated with {@code checkNumFields(str, 3)}
	 * @return the parsed instruction (created without an operator)
	 */
	public static EvictCPInstruction parseInstruction(String str) {
		InstructionUtils.checkNumFields(str, 3);
		// fields[0] = opcode, fields[1] = input operand, fields[2] = output operand
		final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
		final CPOperand input = new CPOperand(fields[1]);
		final CPOperand output = new CPOperand(fields[2]);
		return new EvictCPInstruction(null, input, output, fields[0], str);
	}

	/**
	 * Evicts a fraction of the cached objects. The scalar input is read as a
	 * percentage and converted to a fraction before it is handed to the cache.
	 * The output operand is not written by this method.
	 *
	 * @param ec execution context used to resolve the scalar input
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		final long percent = ec.getScalarInput(input1).getLongValue();
		LineageGPUCacheEviction.removeAllEntries(percent / 100.0);
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,23 @@ private static void removeEntry(LineageCacheEntry e) {
}
}

public static void removeAllEntries() {
// Speculative eviction
/**
 * Speculatively frees cached free GPU pointers, one size class at a time.
 * For every size present in {@code freeQueues}, entries are polled from the
 * free queue and their GPU pointers released until the per-size limit is
 * exceeded or the queue yields no further entry.
 *
 * @param evictFrac fraction of each free list to evict; the {@code _evict}
 *                  instruction passes its percentage input divided by 100
 */
public static void removeAllEntries(double evictFrac) {
	// Iterate over a copy of the keys (presumably because polling may modify freeQueues - TODO confirm)
	List<Long> sizes = new ArrayList<>(freeQueues.keySet());
	for (Long size : sizes) {
		TreeSet<LineageCacheEntry> freeList = freeQueues.get(size);
		if (freeList == null)
			continue; //defensive: no free list for this size
		int evictLim = (int) (freeList.size() * evictFrac);
		int evictCount = 0;
		LineageCacheEntry le;
		// Check the limit BEFORE polling the next entry. Previously the next entry was
		// polled first and then dropped on break, i.e., taken from the free queue
		// without its pointer ever being freed.
		// NOTE(review): as before, up to evictLim + 1 entries are freed per size - confirm
		// whether the extra entry is intended.
		while (evictCount <= evictLim && (le = pollFirstFreeEntry(size)) != null) {
			// Free the pointer
			_gpuContext.getMemoryManager().guardedCudaFree(le.getGPUPointer());
			if (DMLScript.STATISTICS)
				LineageCacheStatistics.incrementGpuDel();
			evictCount++;
		}
	}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
Expand All @@ -39,7 +41,7 @@ public class GPULineageCacheEvictionTest extends AutomatedTestBase{

protected static final String TEST_DIR = "functions/lineage/";
protected static final String TEST_NAME = "GPUCacheEviction";
protected static final int TEST_VARIANTS = 5;
protected static final int TEST_VARIANTS = 6;
protected String TEST_CLASS_DIR = TEST_DIR + GPULineageCacheEvictionTest.class.getSimpleName() + "/";

@BeforeClass
Expand Down Expand Up @@ -80,6 +82,11 @@ public void TransferLearningVGG() { //transfer learning and reuse
testLineageTraceExec(TEST_NAME+"5");
}

// Runs script variant 6 through the shared lineage test driver. For this
// variant the driver additionally asserts the number of executed _evict instructions.
@Test
public void TransferLearning3Models() { //transfer learning and reuse (AlexNet,VGG,ResNet)
testLineageTraceExec(TEST_NAME+"6");
}


private void testLineageTraceExec(String testname) {
System.out.println("------------ BEGIN " + testname + "------------");
Expand Down Expand Up @@ -117,6 +124,13 @@ private void testLineageTraceExec(String testname) {

//compare results
TestUtils.compareMatrices(R_orig, R_reused, 1e-6, "Origin", "Reused");

//Match _evict count
if (testname.equalsIgnoreCase(TEST_NAME+"6")) {
long exp_numev = 3;
long numev = Statistics.getCPHeavyHitterCount("_evict");
Assert.assertTrue("Violated Prefetch instruction count: "+numev, numev == exp_numev);
}
}
}

Loading

0 comments on commit 33149f8

Please sign in to comment.