Skip to content

Commit

Permalink
add support for tpch query partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
1ntEgr8 committed Dec 4, 2024
1 parent f4bbe6a commit b958a8a
Showing 1 changed file with 75 additions and 28 deletions.
103 changes: 75 additions & 28 deletions data/tpch_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Any, Dict, List, Optional, Callable, Tuple
from pathlib import Path
from enum import Enum

import absl
import numpy as np
Expand All @@ -28,6 +29,12 @@
from .base_workload_loader import BaseWorkloadLoader


class TpchQueryDifficulty(Enum):
easy = {1, 3, 4, 6, 12, 14, 17, 19, 22}
medium = {10, 11, 13, 15, 16, 18, 20}
hard = {2, 7, 8, 9, 21}


class TpchLoader:
"""Construct TPC-H task graph from a query profile
Expand Down Expand Up @@ -286,64 +293,104 @@ def __init__(self, flags: "absl.flags") -> None:
# Instantiate tpch loader
self._tpch_loader = TpchLoader(path=flags.tpch_query_dag_spec, flags=flags)

# Gather release times
release_policy = self.__make_release_policy()
release_times = release_policy.get_release_times(
completion_time=EventTime(self._flags.loop_timeout, EventTime.Unit.US)
)
# Intialize [(query_num, release_time)]
self._query_nums_and_release_times = []
if len(flags.override_num_invocations) > 0:
# One each for easy, medium, and hard
assert len(flags.override_num_invocations) == len(TpchQueryDifficulty)
assert len(flags.override_poisson_arrival_rates) == len(
flags.override_num_invocations
)

# Sample queries to be released
query_nums = [
self._rng.randint(1, self._tpch_loader.num_queries)
for _ in range(self._flags.override_num_invocation)
]
# only works with poisson distribution
assert flags.override_release_policy == "poisson"

for i, part in enumerate(TpchQueryDifficulty):
print(flags.override_poisson_arrival_rates[i])
release_policy = self.__make_release_policy(
policy_type=flags.override_release_policy,
arrival_rate=float(flags.override_poisson_arrival_rates[i]),
num_invocations=int(flags.override_num_invocations[i]),
)
release_times = release_policy.get_release_times(
completion_time=EventTime(
self._flags.loop_timeout, EventTime.Unit.US
)
)
query_nums = [
self._rng.choice(list(part.value))
for _ in range(int(flags.override_num_invocations[i]))
]
self._query_nums_and_release_times.extend(
list(zip(query_nums, release_times))
)

self._query_nums_and_release_times.sort(key=lambda x: x[1])
else:
release_policy = self.__make_release_policy()
release_times = release_policy.get_release_times(
completion_time=EventTime(self._flags.loop_timeout, EventTime.Unit.US)
)
query_nums = [
self._rng.randint(1, self._tpch_loader.num_queries)
for _ in range(self._flags.override_num_invocation)
]
self._query_nums_and_release_times.extend(
list(zip(query_nums, release_times))
)

self._query_nums_and_release_times = list(zip(query_nums, release_times))
self._current_release_pointer = 0

# Initialize workload
self._workload = Workload.empty(flags)

def __make_release_policy(self):
def __make_release_policy(
self, policy_type=None, arrival_rate=None, num_invocations=None
):
if policy_type is None:
policy_type = self._flags.override_release_policy
if arrival_rate is None:
arrival_rate = self._flags.override_poisson_arrival_rate
if num_invocations is None:
num_invocations = self._flags.override_num_invocation

release_policy_args = {}
if self._flags.override_release_policy == "periodic":
if policy_type == "periodic":
release_policy_args = {
"period": EventTime(
self._flags.override_arrival_period, EventTime.Unit.US
),
}
elif self._flags.override_release_policy == "fixed":
elif policy_type == "fixed":
release_policy_args = {
"period": EventTime(
self._flags.override_arrival_period, EventTime.Unit.US
),
"num_invocations": self._flags.override_num_invocation,
"num_invocations": num_invocations,
}
elif self._flags.override_release_policy == "poisson":
elif policy_type == "poisson":
release_policy_args = {
"rate": self._flags.override_poisson_arrival_rate,
"num_invocations": self._flags.override_num_invocation,
"rate": arrival_rate,
"num_invocations": num_invocations,
}
elif self._flags.override_release_policy == "gamma":
elif policy_type == "gamma":
release_policy_args = {
"rate": self._flags.override_poisson_arrival_rate,
"num_invocations": self._flags.override_num_invocation,
"rate": arrival_rate,
"num_invocations": num_invocations,
"coefficient": self._flags.override_gamma_coefficient,
}
elif self._flags.override_release_policy == "fixed_gamma":
elif policy_type == "fixed_gamma":
release_policy_args = {
"variable_arrival_rate": self._flags.override_poisson_arrival_rate,
"variable_arrival_rate": arrival_rate,
"base_arrival_rate": self._flags.override_base_arrival_rate,
"num_invocations": self._flags.override_num_invocation,
"num_invocations": num_invocations,
"coefficient": self._flags.override_gamma_coefficient,
}
else:
raise NotImplementedError(
f"Release policy {self._flags.override_release_policy} not implemented."
)
raise NotImplementedError(f"Release policy {policy_type} not implemented.")

return make_release_policy(
self._flags.override_release_policy,
policy_type,
release_policy_args,
self._rng,
self._rng_seed,
Expand Down

0 comments on commit b958a8a

Please sign in to comment.