Skip to content

Commit

Permalink
Update the OptimizationPass API for Post-Translation passes.
Browse files Browse the repository at this point in the history
  • Loading branch information
sukritkalra committed Nov 6, 2023
1 parent 29145e7 commit 80d60b6
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 16 deletions.
47 changes: 42 additions & 5 deletions schedulers/tetrisched/include/tetrisched/OptimizationPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,33 @@
#include "tetrisched/Expression.hpp"

namespace tetrisched {
enum OptimizationPassType {
/// A `PRE_TRANSLATION_PASS` is a pass that is run before the translation
/// of the STRL expression into a solver model.
PRE_TRANSLATION_PASS = 0,
/// A `POST_TRANSLATION_PASS` is a pass that is run after the translation
/// of the STRL expression into a solver model.
POST_TRANSLATION_PASS = 1,
};

/// An `OptimizationPass` is a base class for all Optimization passes that
/// run on the STRL tree.
class OptimizationPass {
/// A representative name of the optimization pass.
std::string name;
/// The type of the optimization pass.
OptimizationPassType type;

public:
/// Construct the base OptimizationPass class.
OptimizationPass(std::string name);
OptimizationPass(std::string name, OptimizationPassType type);

/// Get the type of the optimization pass.
OptimizationPassType getType() const;

/// Run the pass on the given STRL expression.
virtual void runPass(ExpressionPtr strlExpression) = 0;
virtual void runPass(ExpressionPtr strlExpression,
CapacityConstraintMap& capacityConstraints) = 0;
};
using OptimizationPassPtr = std::shared_ptr<OptimizationPass>;

Expand All @@ -38,7 +55,22 @@ class CriticalPathOptimizationPass : public OptimizationPass {
CriticalPathOptimizationPass();

/// Run the Critical Path optimization pass on the given STRL expression.
void runPass(ExpressionPtr strlExpression) override;
void runPass(ExpressionPtr strlExpression,
CapacityConstraintMap& capacityConstraints) override;
};

/// A `CapacityConstraintMapPurgingOptimizationPass` is an optimization pass
/// that aims to remove the capacity constraints that are not needed because
/// they are trivially satisfied by the Expression tree.
class CapacityConstraintMapPurgingOptimizationPass : public OptimizationPass {
public:
/// Instantiate the CapacityConstraintMapPurgingOptimizationPass.
CapacityConstraintMapPurgingOptimizationPass();

/// Run the CapacityConstraintMapPurgingOptimizationPass on the given STRL
/// expression.
void runPass(ExpressionPtr strlExpression,
CapacityConstraintMap& capacityConstraints) override;
};

class OptimizationPassRunner {
Expand All @@ -50,8 +82,13 @@ class OptimizationPassRunner {
/// Initialize the OptimizationPassRunner.
OptimizationPassRunner();

/// Run the registered optimization passes on the given STRL expression.
void runPasses(ExpressionPtr strlExpression);
/// Run the pre-translation optimization passes on the given STRL expression.
void runPreTranslationPasses(ExpressionPtr strlExpression,
CapacityConstraintMap& capacityConstraints);

/// Run the post-translation optimization passes on the given STRL expression.
void runPostTranslationPasses(ExpressionPtr strlExpression,
CapacityConstraintMap& capacityConstraints);
};
} // namespace tetrisched
#endif // _TETRISCHED_OPTIMIZATION_PASSES_HPP_
42 changes: 37 additions & 5 deletions schedulers/tetrisched/src/OptimizationPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
namespace tetrisched {

/* Methods for OptimizationPass */
OptimizationPass::OptimizationPass(std::string name) : name(name) {}
OptimizationPass::OptimizationPass(std::string name, OptimizationPassType type)
: name(name), type(type) {}

OptimizationPassType OptimizationPass::getType() const { return type; }

/* Methods for CriticalPathOptimizationPass */
CriticalPathOptimizationPass::CriticalPathOptimizationPass()
: OptimizationPass("CriticalPathOptimizationPass") {}
: OptimizationPass("CriticalPathOptimizationPass",
OptimizationPassType::PRE_TRANSLATION_PASS) {}

void CriticalPathOptimizationPass::computeTimeBounds(ExpressionPtr expression) {
/* Do a Post-Order Traversal of the DAG. */
Expand Down Expand Up @@ -344,7 +348,8 @@ void CriticalPathOptimizationPass::purgeNodes(ExpressionPtr expression) {
}
}

void CriticalPathOptimizationPass::runPass(ExpressionPtr strlExpression) {
void CriticalPathOptimizationPass::runPass(
ExpressionPtr strlExpression, CapacityConstraintMap& capacityConstraints) {
/* Phase 1: We first do a bottom-up traversal of the tree to compute
a tight bound for each node in the STRL tree. */
computeTimeBounds(strlExpression);
Expand All @@ -358,16 +363,43 @@ void CriticalPathOptimizationPass::runPass(ExpressionPtr strlExpression) {
purgeNodes(strlExpression);
}

/* Methods for CapacityConstraintMapPurgingOptimizationPass */
CapacityConstraintMapPurgingOptimizationPass::
CapacityConstraintMapPurgingOptimizationPass()
: OptimizationPass("CapacityConstraintMapPurgingOptimizationPass",
OptimizationPassType::POST_TRANSLATION_PASS) {}

void CapacityConstraintMapPurgingOptimizationPass::runPass(
ExpressionPtr strlExpression, CapacityConstraintMap& capacityConstraints) {
throw tetrisched::exceptions::RuntimeException("Not implemented yet!");
}

/* Methods for OptimizationPassRunner */
OptimizationPassRunner::OptimizationPassRunner() {
// Register the Critical Path optimization pass.
registeredPasses.push_back(std::make_shared<CriticalPathOptimizationPass>());
// Register the CapacityConstraintMapPurging optimization pass.
registeredPasses.push_back(
std::make_shared<CapacityConstraintMapPurgingOptimizationPass>());
}

void OptimizationPassRunner::runPasses(ExpressionPtr strlExpression) {
void OptimizationPassRunner::runPreTranslationPasses(
ExpressionPtr strlExpression, CapacityConstraintMap& capacityConstraints) {
// Run the registered optimization passes on the given STRL expression.
for (auto& pass : registeredPasses) {
pass->runPass(strlExpression);
if (pass->getType() == OptimizationPassType::PRE_TRANSLATION_PASS) {
pass->runPass(strlExpression, capacityConstraints);
}
}
}

void OptimizationPassRunner::runPostTranslationPasses(
ExpressionPtr strlExpression, CapacityConstraintMap& capacityConstraints) {
// Run the registered optimization passes on the given STRL expression.
for (auto& pass : registeredPasses) {
if (pass->getType() == OptimizationPassType::POST_TRANSLATION_PASS) {
pass->runPass(strlExpression, capacityConstraints);
}
}
}
} // namespace tetrisched
17 changes: 12 additions & 5 deletions schedulers/tetrisched/src/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,24 @@ void Scheduler::registerSTRL(ExpressionPtr expression,
// Save the expression.
this->expression = expression;

// Run the OptimizationPasses on this expression.
if (optimize) {
optimizationPasses.runPasses(expression);
}

// Create the CapacityConstraintMap for the STRL tree to add constraints to.
CapacityConstraintMap capacityConstraintMap(discretization);

// Run the Pre-Translation OptimizationPasses on this expression.
if (optimize) {
optimizationPasses.runPreTranslationPasses(expression,
capacityConstraintMap);
}

// Parse the ExpressionTree to populate the solver model.
auto _ = expression->parse(solverModel, availablePartitions,
capacityConstraintMap, currentTime);

// Run the Post-Translation OptimizationPasses on this expression.
if (optimize) {
optimizationPasses.runPostTranslationPasses(expression,
capacityConstraintMap);
}
}

void Scheduler::schedule(Time currentTime) {
Expand Down
3 changes: 2 additions & 1 deletion schedulers/tetrisched/test/test_optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TEST(OptimizationTest, TestBasicCriticalPathOptimizationPass) {
lessThanExpression->addChild(maxExpression_2);
lessThanExpression->exportToDot("PreOptimizationPass.dot");

optimizationPass.runPass(lessThanExpression);
tetrisched::CapacityConstraintMap capacityConstraintMap(1);
optimizationPass.runPass(lessThanExpression, capacityConstraintMap);
lessThanExpression->exportToDot("PostOptimizationPass.dot");
}

0 comments on commit 80d60b6

Please sign in to comment.