feat(shuffles): Locality aware pre shuffle merge (#3505)
Only merge map outputs in the pre-shuffle merge if they are on the same node. Uses
the `ray.experimental.get_object_locations` API to retrieve the node ID
for each map output object, and
`ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy` to
schedule each merge task on the node that holds its inputs.
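
A minimal standalone sketch of the two Ray APIs involved (illustrative only, not
Daft's actual code; the map outputs and `merge` task below are placeholders):

```python
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init(ignore_reinit_error=True)

# Placeholder "map outputs"; in Daft these are the materialized map partitions.
map_outputs = [ray.put(list(range(1000))) for _ in range(4)]

# Maps each ObjectRef to location info, e.g. {"node_ids": [...], "object_size": ...}.
locations = ray.experimental.get_object_locations(map_outputs)
node_id = (locations[map_outputs[0]].get("node_ids") or [None])[0]


@ray.remote
def merge(*parts):
    # Placeholder for the real merge instruction.
    return sum(len(p) for p in parts)


if node_id is not None:
    # soft=True lets Ray fall back to another node if the preferred one is unavailable.
    strategy = NodeAffinitySchedulingStrategy(node_id, soft=True)
    print(ray.get(merge.options(scheduling_strategy=strategy).remote(*map_outputs)))
```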

The benefit is that the maps + merges finish slightly faster.

On a 3000 x 3000 shuffle with 300 MB partitions:
- Locality aware: 12m 32s
- Unaware: 13m 9s

On a 1000 x 1000 shuffle with 100 MB partitions:
- Locality aware: 29s
- Unaware: 45s

Increasing the partition count by 3x does not materially change the
improvement from locality awareness. Looking at the traces for the
3000 x 3000 shuffle:

Locality aware:
<img width="1257" alt="Screenshot 2024-12-10 at 12 15 14 PM"
src="https://github.com/user-attachments/assets/bb50f804-6d29-46ff-a344-a94f8d10835a">

Unaware:
<img width="1262" alt="Screenshot 2024-12-10 at 12 12 23 PM"
src="https://github.com/user-attachments/assets/9d191ab1-1521-491a-9047-f678adf7d711">

The reduces still account for most of the runtime, but the maps + merges
take less time in locality-aware mode: 2.6 min vs 3.1 min.

---------

Co-authored-by: EC2 Default User <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
3 people authored Dec 12, 2024
1 parent 2557dba commit 6ae4e77
Showing 4 changed files with 93 additions and 35 deletions.
7 changes: 7 additions & 0 deletions daft/execution/execution_step.py
@@ -54,6 +54,9 @@ class PartitionTask(Generic[PartitionT]):
# Indicates if the PartitionTask is "done" or not
is_done: bool = False

# Desired node_id to schedule this task on
node_id: str | None = None

_id: int = field(default_factory=lambda: next(ID_GEN))

def id(self) -> str:
@@ -108,6 +111,7 @@ def __init__(
partial_metadatas: list[PartialPartitionMetadata] | None,
resource_request: ResourceRequest = ResourceRequest(),
actor_pool_id: str | None = None,
node_id: str | None = None,
) -> None:
self.inputs = inputs
if partial_metadatas is not None:
@@ -118,6 +122,7 @@ def __init__(
self.instructions: list[Instruction] = list()
self.num_results = len(inputs)
self.actor_pool_id = actor_pool_id
self.node_id = node_id

def add_instruction(
self,
@@ -156,6 +161,7 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa
resource_request=resource_request_final_cpu,
partial_metadatas=self.partial_metadatas,
actor_pool_id=self.actor_pool_id,
node_id=self.node_id,
)

def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]:
@@ -177,6 +183,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart
resource_request=resource_request_final_cpu,
partial_metadatas=self.partial_metadatas,
actor_pool_id=self.actor_pool_id,
node_id=self.node_id,
)

def __str__(self) -> str:
106 changes: 73 additions & 33 deletions daft/execution/shuffles/pre_shuffle_merge.py
@@ -1,6 +1,9 @@
import logging
from collections import defaultdict
from typing import Dict

import ray.experimental # noqa: TID253

from daft.daft import ResourceRequest
from daft.execution import execution_step
from daft.execution.execution_step import (
@@ -38,7 +41,7 @@ def pre_shuffle_merge(
no_more_input = False

while True:
# Get and sort materialized maps by size.
# Get materialized maps by size
materialized_maps = sorted(
[
(
@@ -56,42 +59,79 @@ def pre_shuffle_merge(
done_with_input = no_more_input and len(materialized_maps) == len(in_flight_maps)

if enough_maps or done_with_input:
# Initialize the first merge group
merge_groups = []
current_group = [materialized_maps[0][0]]
current_size = materialized_maps[0][1]

# Group remaining maps based on memory threshold
for partition, size in materialized_maps[1:]:
if current_size + size > pre_shuffle_merge_threshold:
merge_groups.append(current_group)
current_group = [partition]
current_size = size
else:
current_group.append(partition)
current_size += size
# Get location information for all materialized partitions
partitions = [m[0].result().partition() for m in materialized_maps]
location_map = ray.experimental.get_object_locations(partitions)

# Group partitions by node
node_groups = defaultdict(list)
unknown_location_group = [] # Special group for partitions without known location

# Add the last group if it exists and is either:
# 1. Contains more than 1 partition
# 2. Is the last group and we're done with input
# 3. The partition exceeds the memory threshold
if current_group:
if len(current_group) > 1 or done_with_input or current_size > pre_shuffle_merge_threshold:
merge_groups.append(current_group)
for partition, size in materialized_maps:
partition_ref = partition.partition()
location_info = location_map.get(partition_ref, {})

if not location_info or "node_ids" not in location_info or not location_info["node_ids"]:
unknown_location_group.append((partition, size))
else:
# TODO: Handle multiple locations, with a strategy to select more optimal nodes, e.g. based on memory
node_id = location_info["node_ids"][0] # Use first node if multiple locations exist
node_groups[node_id].append((partition, size))

# Function to create merge groups for a list of partitions
def create_merge_groups(partitions_list):
if not partitions_list:
return []

groups = []
current_group = [partitions_list[0][0]]
current_size = partitions_list[0][1]

for partition, size in partitions_list[1:]:
if current_size + size > pre_shuffle_merge_threshold:
groups.append(current_group)
current_group = [partition]
current_size = size
else:
current_group.append(partition)
current_size += size

# Add the last group if it exists and is either:
# 1. Contains more than 1 partition
# 2. Is the last group and we're done with input
# 3. The partition exceeds the memory threshold
should_add_last_group = (
len(current_group) > 1 or done_with_input or current_size > pre_shuffle_merge_threshold
)
if current_group and should_add_last_group:
groups.append(current_group)

return groups

# Process each node's partitions and unknown location partitions
merge_groups = {}

# Process node-specific groups
for node_id, node_partitions in node_groups.items():
merge_groups[node_id] = create_merge_groups(node_partitions)

# Process unknown location group
merge_groups[None] = create_merge_groups(unknown_location_group)

# Create merge steps and remove processed maps
for group in merge_groups:
for node_id, groups in merge_groups.items():
# Remove processed maps from in_flight_maps
for partition in group:
del in_flight_maps[partition.id()]

total_size = sum(m.partition_metadata().size_bytes or 0 for m in group)
merge_step = PartitionTaskBuilder[PartitionT](
inputs=[p.partition() for p in group],
partial_metadatas=[m.partition_metadata() for m in group],
resource_request=ResourceRequest(memory_bytes=total_size),
).add_instruction(instruction=execution_step.ReduceMerge())
yield merge_step
for group in groups:
for partition in group:
del in_flight_maps[partition.id()]
total_size = sum(m.partition_metadata().size_bytes or 0 for m in group)
merge_step = PartitionTaskBuilder[PartitionT](
inputs=[p.partition() for p in group],
partial_metadatas=[m.partition_metadata() for m in group],
resource_request=ResourceRequest(memory_bytes=total_size),
node_id=node_id,
).add_instruction(instruction=execution_step.ReduceMerge())
yield merge_step

# Process next map task if available
try:
11 changes: 9 additions & 2 deletions daft/runners/ray_runner.py
@@ -14,6 +14,7 @@
# import times. If this changes, we first need to make the daft.lazy_import.LazyImport class
# serializable before importing pa from daft.dependencies.
import pyarrow as pa # noqa: TID253
import ray.experimental # noqa: TID253

from daft.arrow_utils import ensure_array
from daft.context import execution_config_ctx, get_context
@@ -516,7 +517,7 @@ def fanout_pipeline(


@ray_tracing.ray_remote_traced
@ray.remote(scheduling_strategy="SPREAD")
@ray.remote
def reduce_pipeline(
task_context: PartitionTaskContext,
daft_execution_config: PyDaftExecutionConfig,
@@ -533,7 +534,7 @@ def reduce_pipeline(


@ray_tracing.ray_remote_traced
@ray.remote(scheduling_strategy="SPREAD")
@ray.remote
def reduce_and_fanout(
task_context: PartitionTaskContext,
daft_execution_config: PyDaftExecutionConfig,
Expand Down Expand Up @@ -1012,6 +1013,12 @@ def _build_partitions(
if task.instructions and isinstance(task.instructions[-1], FanoutInstruction)
else reduce_pipeline
)
if task.node_id is not None:
ray_options["scheduling_strategy"] = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
task.node_id, soft=True
)
else:
ray_options["scheduling_strategy"] = "SPREAD"
build_remote = build_remote.options(**ray_options).with_tracing(runner_tracer, task)
[metadatas_ref, *partitions] = build_remote.remote(
PartitionTaskContext(job_id=job_id, task_id=task.id(), stage_id=task.stage_id),
4 changes: 4 additions & 0 deletions src/common/daft-config/src/lib.rs
@@ -110,6 +110,10 @@ impl DaftExecutionConfig {
{
cfg.enable_ray_tracing = true;
}
let shuffle_algorithm_env_var_name = "DAFT_SHUFFLE_ALGORITHM";
if let Ok(val) = std::env::var(shuffle_algorithm_env_var_name) {
cfg.shuffle_algorithm = val;
}
cfg
}
}
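
As a usage note, the new `DAFT_SHUFFLE_ALGORITHM` environment variable is read when
Daft builds its execution config, so it should be set before running a query. A small
sketch, assuming `pre_shuffle_merge` is an accepted value (check the config docs):

```python
import os

# Assumption: "pre_shuffle_merge" is an accepted value for DAFT_SHUFFLE_ALGORITHM.
# Set it before Daft constructs its execution config (e.g. at the top of the script).
os.environ["DAFT_SHUFFLE_ALGORITHM"] = "pre_shuffle_merge"
```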
