Skip to content

Commit

Permalink
Build a single post-order iterator for faster optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
sukritkalra committed Dec 23, 2023
1 parent 3e0f9c9 commit f06909c
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 92 deletions.
16 changes: 13 additions & 3 deletions schedulers/tetrisched/include/tetrisched/OptimizationPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _TETRISCHED_OPTIMIZATION_PASSES_HPP_

#include <cmath>
#include <deque>
#include <string>

#include "tetrisched/Expression.hpp"
Expand Down Expand Up @@ -48,19 +49,28 @@ class OptimizationPass {
using OptimizationPassPtr = std::shared_ptr<OptimizationPass>;

class CriticalPathOptimizationPass : public OptimizationPass {
typedef std::deque<ExpressionPtr> ExpressionPostOrderTraversal;

private:
/// A map from an Expression to the valid time bounds for it.
std::unordered_map<ExpressionPtr, ExpressionTimeBounds>
expressionTimeBoundMap;

/// A helper method that flattens the Expression graph rooted at `expression`
/// into one traversal that computeTimeBounds, pushDownTimeBounds and
/// purgeNodes all consume.
/// The implementation appends each Expression before the children reached
/// through it, so iterating the result from rbegin() to rend() yields the
/// children-before-parent (post-order) walk, while begin() to end() yields
/// the parent-first walk. No de-duplication is done while building the
/// traversal, so an Expression reachable through several parents may appear
/// more than once.
ExpressionPostOrderTraversal computePostOrderTraversal(
ExpressionPtr expression);

/// A helper method to recursively compute the time bounds for an Expression.
void computeTimeBounds(ExpressionPtr expression);
void computeTimeBounds(
const ExpressionPostOrderTraversal& postOrderTraversal);

/// A helper method to push down the time bounds into the Expression tree.
void pushDownTimeBounds(ExpressionPtr expression);
void pushDownTimeBounds(
const ExpressionPostOrderTraversal& postOrderTraversal);

/// A helper method to purge the nodes that do not fit their time bounds.
void purgeNodes(ExpressionPtr expression);
void purgeNodes(const ExpressionPostOrderTraversal& postOrderTraversal);

public:
/// Instantiate the Critical Path optimization pass.
Expand Down
158 changes: 69 additions & 89 deletions schedulers/tetrisched/src/OptimizationPasses.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "tetrisched/OptimizationPasses.hpp"

#include <chrono>
#include <queue>
#include <stack>

namespace tetrisched {
Expand All @@ -19,32 +18,42 @@ CriticalPathOptimizationPass::CriticalPathOptimizationPass()
: OptimizationPass("CriticalPathOptimizationPass",
OptimizationPassType::PRE_TRANSLATION_PASS) {}

void CriticalPathOptimizationPass::computeTimeBounds(ExpressionPtr expression) {
CriticalPathOptimizationPass::ExpressionPostOrderTraversal
CriticalPathOptimizationPass::computePostOrderTraversal(
ExpressionPtr expression) {
TETRISCHED_SCOPE_TIMER(
"CriticalPathOptimizationPass::computePostOrderTraversal");
/* Do a Post-Order Traversal of the DAG. */
std::stack<ExpressionPtr> firstStack;
std::stack<ExpressionPtr> secondStack;
std::deque<ExpressionPtr> postOrderTraversal;
firstStack.push(expression);

while (!firstStack.empty()) {
// Move the expression to the second stack.
auto currentExpression = firstStack.top();
firstStack.pop();
secondStack.push(currentExpression);
postOrderTraversal.push_back(currentExpression);

// Add the children to the first stack.
// Add the children of the expression to the first stack.
auto expressionChildren = currentExpression->getChildren();
for (auto child = expressionChildren.rbegin();
child != expressionChildren.rend(); ++child) {
firstStack.push(*child);
}
}

// A Post-Order Traversal will now be the order in which
// the expressions are popped from the second stack.
return postOrderTraversal;
}

void CriticalPathOptimizationPass::computeTimeBounds(
const ExpressionPostOrderTraversal &postOrderTraversal) {
// Iterate over the post-order traversal and compute the time bounds.
std::unordered_set<std::string> visitedExpressions;
while (!secondStack.empty()) {
auto currentExpression = secondStack.top();
secondStack.pop();
visitedExpressions.reserve(postOrderTraversal.size());
for (auto currentExpressionIt = postOrderTraversal.rbegin();
currentExpressionIt != postOrderTraversal.rend();
++currentExpressionIt) {
auto currentExpression = *currentExpressionIt;
if (visitedExpressions.find(currentExpression->getId()) ==
visitedExpressions.end()) {
visitedExpressions.insert(currentExpression->getId());
Expand Down Expand Up @@ -129,32 +138,12 @@ void CriticalPathOptimizationPass::computeTimeBounds(ExpressionPtr expression) {
}

void CriticalPathOptimizationPass::pushDownTimeBounds(
ExpressionPtr expression) {
/* Do a Reverse Post-Order Traversal of the DAG. */
std::stack<ExpressionPtr> traversalStack;
std::queue<ExpressionPtr> traversalQueue;
traversalStack.push(expression);

while (!traversalStack.empty()) {
// Move the expression to the second stack.
auto currentExpression = traversalStack.top();
traversalStack.pop();
traversalQueue.push(currentExpression);

// Add the children to the first stack.
auto expressionChildren = currentExpression->getChildren();
for (auto child = expressionChildren.rbegin();
child != expressionChildren.rend(); ++child) {
traversalStack.push(*child);
}
}

// A Reverse Post-Order Traversal will now be the order in which
// the expressions are popped from the queue.
while (!traversalQueue.empty()) {
auto currentExpression = traversalQueue.front();
traversalQueue.pop();

const ExpressionPostOrderTraversal &postOrderTraversal) {
// Iterate over the post-order traversal in the reverse order and push down
// the time bounds.
for (auto currentExpressionIt = postOrderTraversal.begin();
currentExpressionIt != postOrderTraversal.end(); ++currentExpressionIt) {
auto currentExpression = *currentExpressionIt;
if (expressionTimeBoundMap.find(currentExpression) ==
expressionTimeBoundMap.end()) {
throw exceptions::RuntimeException(
Expand Down Expand Up @@ -310,39 +299,18 @@ void CriticalPathOptimizationPass::pushDownTimeBounds(
}
}

void CriticalPathOptimizationPass::purgeNodes(ExpressionPtr expression) {
/* Do a Post-Order Traversal of the DAG. */
std::stack<ExpressionPtr> firstStack;
std::stack<ExpressionPtr> secondStack;
firstStack.push(expression);

while (!firstStack.empty()) {
// Move the expression to the second stack.
auto currentExpression = firstStack.top();
firstStack.pop();
secondStack.push(currentExpression);

// Add the children to the first stack.
auto expressionChildren = currentExpression->getChildren();
for (auto child = expressionChildren.rbegin();
child != expressionChildren.rend(); ++child) {
firstStack.push(*child);
}
}

// A Post-Order Traversal will now be the order in which
// the expressions are popped from the second stack.
void CriticalPathOptimizationPass::purgeNodes(
const ExpressionPostOrderTraversal &postOrderTraversal) {
// Iterate over the post-order traversal and purge the nodes that cannot be
// satisfied.
std::unordered_set<ExpressionPtr> visitedExpressions;
visitedExpressions.reserve(secondStack.size());
visitedExpressions.reserve(postOrderTraversal.size());
std::unordered_set<ExpressionPtr> purgedExpressions;
purgedExpressions.reserve(secondStack.size());
while (!secondStack.empty()) {
auto currentExpression = secondStack.top();

TETRISCHED_SCOPE_TIMER("CriticalPathOptimizationPass::purgeNodes::purge" +
currentExpression->getTypeString())

secondStack.pop();
purgedExpressions.reserve(postOrderTraversal.size());
for (auto currentExpressionIt = postOrderTraversal.rbegin();
currentExpressionIt != postOrderTraversal.rend();
++currentExpressionIt) {
auto currentExpression = *currentExpressionIt;
if (visitedExpressions.find(currentExpression) ==
visitedExpressions.end()) {
visitedExpressions.insert(currentExpression);
Expand All @@ -361,8 +329,6 @@ void CriticalPathOptimizationPass::purgeNodes(ExpressionPtr expression) {
}

if (currentExpression->getNumChildren() == 0) {
TETRISCHED_SCOPE_TIMER(
"CriticalPathOptimizationPass::purgeNodes::boundsCheckPurgeLeaf");
// If this is a leaf node, we check if it can be purged.
auto newTimeBounds = newTimeBoundsLocation->second;
auto originalTimeBounds = currentExpression->getTimeBounds();
Expand Down Expand Up @@ -396,9 +362,9 @@ void CriticalPathOptimizationPass::purgeNodes(ExpressionPtr expression) {
ExpressionType::EXPR_WINDOWED_CHOOSE ||
currentExpression->getType() ==
ExpressionType::EXPR_MALLEABLE_CHOOSE) {
// Both WindowedChoose and MalleableChoose can generate various options.
// We only purge them if all the options are invalid. Otherwise, we just
// tighten the bounds.
// Both WindowedChoose and MalleableChoose can generate various
// options. We only purge them if all the options are invalid.
// Otherwise, we just tighten the bounds.
if (newTimeBounds.startTimeRange.first >
newTimeBounds.endTimeRange.second) {
// The expression is being asked to start after it can finish at the
Expand Down Expand Up @@ -447,33 +413,44 @@ void CriticalPathOptimizationPass::purgeNodes(ExpressionPtr expression) {
}

void CriticalPathOptimizationPass::runPass(
ExpressionPtr strlExpression, CapacityConstraintMapPtr capacityConstraints,
std::optional<std::string> debugFile) {
ExpressionPtr strlExpression,
CapacityConstraintMapPtr /* capacityConstraints */,
std::optional<std::string> /* debugFile */) {
/* Preprocessing: We first compute the post-order traversal of the
Expression graph since all subsequent steps use it. */
auto postOrderTraversal = computePostOrderTraversal(strlExpression);

// We reserve enough space in the map to avoid rehashing.
expressionTimeBoundMap.reserve(postOrderTraversal.size());

/* Phase 1: We first do a bottom-up traversal of the tree to compute
a tight bound for each node in the STRL tree. */
{
TETRISCHED_SCOPE_TIMER(
"CriticalPathOptimizationPass::runPass::computeTimeBounds");
computeTimeBounds(strlExpression);
computeTimeBounds(postOrderTraversal);
}

/* Phase 2: The previous phase computes the tight bounds but does not
push them down necessarily. In this phase, we push the bounds down. */
{
TETRISCHED_SCOPE_TIMER(
"CriticalPathOptimizationPass::runPass::pushDownTimeBounds");
pushDownTimeBounds(strlExpression);
pushDownTimeBounds(postOrderTraversal);
}

/* Phase 3: The bounds have been pushed down now, we can do a bottom-up
traversal and start purging nodes that cannot be satisfied. */
{
TETRISCHED_SCOPE_TIMER("CriticalPathOptimizationPass::runPass::purgeNodes");
purgeNodes(strlExpression);
purgeNodes(postOrderTraversal);
}
}

void CriticalPathOptimizationPass::clean() { expressionTimeBoundMap.clear(); }
/// Discards every entry recorded in expressionTimeBoundMap.
void CriticalPathOptimizationPass::clean() {
// NOTE(review): presumably times this scope under the given label; the macro
// is defined outside this file, so confirm its semantics there.
TETRISCHED_SCOPE_TIMER("CriticalPathOptimizationPass::clean");
expressionTimeBoundMap.clear();
}

/* Methods for DiscretizationSelectorOptimizationPass */
DiscretizationSelectorOptimizationPass::DiscretizationSelectorOptimizationPass()
Expand Down Expand Up @@ -604,15 +581,16 @@ void DiscretizationSelectorOptimizationPass::runPass(
}
// std::cout << "\t" << "[" << minOccupancyTime << "]" <<
// "[DiscretizationSelectorOptimizationPass] "
// << i + minOccupancyTime << ": " << occupancyRequests[i]
// << i + minOccupancyTime << ": " <<
// occupancyRequests[i]
// << std::endl;
}

double autoMaxOccupancyThreshold = maxOccupancyThreshold * maxOccupancyVal;
// std::cout << "** [DiscretizationSelectorOptimizationPass] Max
// Discretization Value" << maxOccupancyVal << " Threshold Decided is: " <<
// autoMaxOccupancyThreshold << " threshold val: " << maxOccupancyThreshold <<
// std::endl;
// autoMaxOccupancyThreshold << " threshold val: " << maxOccupancyThreshold
// << std::endl;

// finding the right discretization

Expand Down Expand Up @@ -645,8 +623,8 @@ void DiscretizationSelectorOptimizationPass::runPass(
auto nextPlanAhead =
std::min(i + predictedDiscretization, occupancyRequests.size());

// find the right discretization such that average occupancy for that period
// predicts same discretization
// find the right discretization such that average occupancy for that
// period predicts same discretization
while (nextPlanAhead >= (i + 1)) {
double averageOccupancy = 0;
int count = 0;
Expand Down Expand Up @@ -678,7 +656,8 @@ void DiscretizationSelectorOptimizationPass::runPass(
// std::cout <<
// "[DiscretizationSelectorOptimizationPass] Dynamic Discretization
// between "
// << minOccupancyTime << " and " << maxOccupancyTime << ": "<< std::endl;
// << minOccupancyTime << " and " << maxOccupancyTime << ": "<<
// std::endl;

// for (auto &[discretizationTimeRange, granularity] :
// timeRangeToGranularities) {
Expand All @@ -687,7 +666,8 @@ void DiscretizationSelectorOptimizationPass::runPass(
// << "[DiscretizationSelectorOptimizationPassDiscreteTime] "
// << discretizationTimeRange.first << " - " <<
// discretizationTimeRange.second << " : " << granularity << "
// Occuapncy: " << occupancyRequests[discretizationTimeRange.first
// Occuapncy: " <<
// occupancyRequests[discretizationTimeRange.first
// - minOccupancyTime] << std::endl;
// }

Expand Down Expand Up @@ -720,15 +700,15 @@ void DiscretizationSelectorOptimizationPass::runPass(
}
if (ncksWithinTimeRange.size() > 1) {
// if more than one nck found within the time range, remove it as only
// one nck within granularity is sufficient. Nck with minimum start time
// is kept within this time range
// one nck within granularity is sufficient. Nck with minimum start
// time is kept within this time range
for (auto redundantNckExpr : ncksWithinTimeRange) {
if (redundantNckExpr->getId() != minStartTimeNckExpr->getId()) {
maxNckExpr->removeChild(redundantNckExpr);
// std::cout << "[DiscretizationSelectorOptimizationPassRemoveNck]
// Removing NCK: " + redundantNckExpr->getName() + " From Max: " +
// maxNckExpr->getName() << "Time Range: [" << startTime << ", " <<
// endTime << "]";
// maxNckExpr->getName() << "Time Range: [" << startTime << ", "
// << endTime << "]";
}
}
}
Expand Down

0 comments on commit f06909c

Please sign in to comment.