diff --git a/data/tpch_loader.py b/data/tpch_loader.py index 97c2522a..8d04198a 100644 --- a/data/tpch_loader.py +++ b/data/tpch_loader.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Callable, Tuple from pathlib import Path +from enum import Enum import absl import numpy as np @@ -28,6 +29,12 @@ from .base_workload_loader import BaseWorkloadLoader +class TpchQueryDifficulty(Enum): + easy = {1, 3, 4, 6, 12, 14, 17, 19, 22} + medium = {10, 11, 13, 15, 16, 18, 20} + hard = {2, 7, 8, 9, 21} + + class TpchLoader: """Construct TPC-H task graph from a query profile @@ -286,64 +293,104 @@ def __init__(self, flags: "absl.flags") -> None: # Instantiate tpch loader self._tpch_loader = TpchLoader(path=flags.tpch_query_dag_spec, flags=flags) - # Gather release times - release_policy = self.__make_release_policy() - release_times = release_policy.get_release_times( - completion_time=EventTime(self._flags.loop_timeout, EventTime.Unit.US) - ) + # Intialize [(query_num, release_time)] + self._query_nums_and_release_times = [] + if len(flags.override_num_invocations) > 0: + # One each for easy, medium, and hard + assert len(flags.override_num_invocations) == len(TpchQueryDifficulty) + assert len(flags.override_poisson_arrival_rates) == len( + flags.override_num_invocations + ) - # Sample queries to be released - query_nums = [ - self._rng.randint(1, self._tpch_loader.num_queries) - for _ in range(self._flags.override_num_invocation) - ] + # only works with poisson distribution + assert flags.override_release_policy == "poisson" + + for i, part in enumerate(TpchQueryDifficulty): + print(flags.override_poisson_arrival_rates[i]) + release_policy = self.__make_release_policy( + policy_type=flags.override_release_policy, + arrival_rate=float(flags.override_poisson_arrival_rates[i]), + num_invocations=int(flags.override_num_invocations[i]), + ) + release_times = release_policy.get_release_times( + completion_time=EventTime( + self._flags.loop_timeout, EventTime.Unit.US + ) + ) + query_nums = [ + self._rng.choice(list(part.value)) + for _ in range(int(flags.override_num_invocations[i])) + ] + self._query_nums_and_release_times.extend( + list(zip(query_nums, release_times)) + ) + + self._query_nums_and_release_times.sort(key=lambda x: x[1]) + else: + release_policy = self.__make_release_policy() + release_times = release_policy.get_release_times( + completion_time=EventTime(self._flags.loop_timeout, EventTime.Unit.US) + ) + query_nums = [ + self._rng.randint(1, self._tpch_loader.num_queries) + for _ in range(self._flags.override_num_invocation) + ] + self._query_nums_and_release_times.extend( + list(zip(query_nums, release_times)) + ) - self._query_nums_and_release_times = list(zip(query_nums, release_times)) self._current_release_pointer = 0 # Initialize workload self._workload = Workload.empty(flags) - def __make_release_policy(self): + def __make_release_policy( + self, policy_type=None, arrival_rate=None, num_invocations=None + ): + if policy_type is None: + policy_type = self._flags.override_release_policy + if arrival_rate is None: + arrival_rate = self._flags.override_poisson_arrival_rate + if num_invocations is None: + num_invocations = self._flags.override_num_invocation + release_policy_args = {} - if self._flags.override_release_policy == "periodic": + if policy_type == "periodic": release_policy_args = { "period": EventTime( self._flags.override_arrival_period, EventTime.Unit.US ), } - elif self._flags.override_release_policy == "fixed": + elif policy_type == "fixed": release_policy_args = { "period": EventTime( self._flags.override_arrival_period, EventTime.Unit.US ), - "num_invocations": self._flags.override_num_invocation, + "num_invocations": num_invocations, } - elif self._flags.override_release_policy == "poisson": + elif policy_type == "poisson": release_policy_args = { - "rate": self._flags.override_poisson_arrival_rate, - "num_invocations": self._flags.override_num_invocation, + "rate": arrival_rate, + "num_invocations": num_invocations, } - elif self._flags.override_release_policy == "gamma": + elif policy_type == "gamma": release_policy_args = { - "rate": self._flags.override_poisson_arrival_rate, - "num_invocations": self._flags.override_num_invocation, + "rate": arrival_rate, + "num_invocations": num_invocations, "coefficient": self._flags.override_gamma_coefficient, } - elif self._flags.override_release_policy == "fixed_gamma": + elif policy_type == "fixed_gamma": release_policy_args = { - "variable_arrival_rate": self._flags.override_poisson_arrival_rate, + "variable_arrival_rate": arrival_rate, "base_arrival_rate": self._flags.override_base_arrival_rate, - "num_invocations": self._flags.override_num_invocation, + "num_invocations": num_invocations, "coefficient": self._flags.override_gamma_coefficient, } else: - raise NotImplementedError( - f"Release policy {self._flags.override_release_policy} not implemented." - ) + raise NotImplementedError(f"Release policy {policy_type} not implemented.") return make_release_policy( - self._flags.override_release_policy, + policy_type, release_policy_args, self._rng, self._rng_seed,