Skip to content

Commit

Permalink
feat: Limit number of sources in merged scan task (#3695)
Browse files Browse the repository at this point in the history
Don't merge more than `N` (default 10) sources in a scan task. This is
so that we don't over-merge, given that our filter-selectivity
estimates might be off.

---------

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Jan 30, 2025
1 parent 78beff4 commit 29cab30
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 6 deletions.
3 changes: 3 additions & 0 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def set_execution_config(
config: PyDaftExecutionConfig | None = None,
scan_tasks_min_size_bytes: int | None = None,
scan_tasks_max_size_bytes: int | None = None,
max_sources_per_scan_task: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
parquet_split_row_groups_max_files: int | None = None,
sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
Expand Down Expand Up @@ -368,6 +369,7 @@ def set_execution_config(
scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
fewer partitions. (Defaults to 384 MiB)
max_sources_per_scan_task: Maximum number of sources in a single ScanTask. (Defaults to 10)
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
parquet_split_row_groups_max_files: Maximum number of files to read in which the row group splitting should happen. (Defaults to 10)
Expand Down Expand Up @@ -406,6 +408,7 @@ def set_execution_config(
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
max_sources_per_scan_task=max_sources_per_scan_task,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
Expand Down
3 changes: 3 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,7 @@ class PyDaftExecutionConfig:
self,
scan_tasks_min_size_bytes: int | None = None,
scan_tasks_max_size_bytes: int | None = None,
max_sources_per_scan_task: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
parquet_split_row_groups_max_files: int | None = None,
sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
Expand Down Expand Up @@ -1747,6 +1748,8 @@ class PyDaftExecutionConfig:
@property
def scan_tasks_max_size_bytes(self) -> int: ...
@property
def max_sources_per_scan_task(self) -> int: ...
@property
def broadcast_join_size_bytes_threshold(self) -> int: ...
@property
def sort_merge_join_sort_with_aligned_boundaries(self) -> bool: ...
Expand Down
2 changes: 2 additions & 0 deletions src/common/daft-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl DaftPlanningConfig {
pub struct DaftExecutionConfig {
pub scan_tasks_min_size_bytes: usize,
pub scan_tasks_max_size_bytes: usize,
pub max_sources_per_scan_task: usize,
pub broadcast_join_size_bytes_threshold: usize,
pub sort_merge_join_sort_with_aligned_boundaries: bool,
pub hash_join_partition_size_leniency: f64,
Expand Down Expand Up @@ -69,6 +70,7 @@ impl Default for DaftExecutionConfig {
Self {
scan_tasks_min_size_bytes: 96 * 1024 * 1024, // 96MB
scan_tasks_max_size_bytes: 384 * 1024 * 1024, // 384MB
max_sources_per_scan_task: 10,
broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB
sort_merge_join_sort_with_aligned_boundaries: false,
hash_join_partition_size_leniency: 0.5,
Expand Down
10 changes: 10 additions & 0 deletions src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl PyDaftExecutionConfig {
#[pyo3(signature = (
scan_tasks_min_size_bytes=None,
scan_tasks_max_size_bytes=None,
max_sources_per_scan_task=None,
broadcast_join_size_bytes_threshold=None,
parquet_split_row_groups_max_files=None,
sort_merge_join_sort_with_aligned_boundaries=None,
Expand Down Expand Up @@ -105,6 +106,7 @@ impl PyDaftExecutionConfig {
&self,
scan_tasks_min_size_bytes: Option<usize>,
scan_tasks_max_size_bytes: Option<usize>,
max_sources_per_scan_task: Option<usize>,
broadcast_join_size_bytes_threshold: Option<usize>,
parquet_split_row_groups_max_files: Option<usize>,
sort_merge_join_sort_with_aligned_boundaries: Option<bool>,
Expand Down Expand Up @@ -136,6 +138,9 @@ impl PyDaftExecutionConfig {
if let Some(scan_tasks_min_size_bytes) = scan_tasks_min_size_bytes {
config.scan_tasks_min_size_bytes = scan_tasks_min_size_bytes;
}
if let Some(max_sources_per_scan_task) = max_sources_per_scan_task {
config.max_sources_per_scan_task = max_sources_per_scan_task;
}
if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold {
config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold;
}
Expand Down Expand Up @@ -236,6 +241,11 @@ impl PyDaftExecutionConfig {
Ok(self.config.scan_tasks_max_size_bytes)
}

#[getter]
fn get_max_sources_per_scan_task(&self) -> PyResult<usize> {
    // Expose the configured upper bound on how many sources a single
    // merged ScanTask may contain (see DaftExecutionConfig, default 10).
    let max_sources = self.config.max_sources_per_scan_task;
    Ok(max_sources)
}

#[getter]
fn get_broadcast_join_size_bytes_threshold(&self) -> PyResult<usize> {
Ok(self.config.broadcast_join_size_bytes_threshold)
Expand Down
18 changes: 12 additions & 6 deletions src/daft-scan/src/scan_task_iters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a
/// * `scan_tasks`: A Boxed Iterator of ScanTaskRefs to perform merging on
/// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed
/// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask
/// * `max_source_count`: Maximum number of ScanTasks to merge
#[must_use]
fn merge_by_sizes<'a>(
scan_tasks: BoxScanTaskIter<'a>,
Expand Down Expand Up @@ -57,6 +58,7 @@ fn merge_by_sizes<'a>(
target_upper_bound_size_bytes: (limit_bytes * 1.5) as usize,
target_lower_bound_size_bytes: (limit_bytes / 2.) as usize,
accumulator: None,
max_source_count: cfg.max_sources_per_scan_task,
}) as BoxScanTaskIter;
}
}
Expand All @@ -69,6 +71,7 @@ fn merge_by_sizes<'a>(
target_upper_bound_size_bytes: cfg.scan_tasks_max_size_bytes,
target_lower_bound_size_bytes: cfg.scan_tasks_min_size_bytes,
accumulator: None,
max_source_count: cfg.max_sources_per_scan_task,
}) as BoxScanTaskIter
}
}
Expand All @@ -83,6 +86,9 @@ struct MergeByFileSize<'a> {

// Current element being accumulated on
accumulator: Option<ScanTaskRef>,

// Maximum number of files in a merged ScanTask
max_source_count: usize,
}

impl<'a> MergeByFileSize<'a> {
Expand All @@ -92,11 +98,11 @@ impl<'a> MergeByFileSize<'a> {
/// in estimated bytes, as well as other factors including any limit pushdowns.
fn accumulator_ready(&self) -> bool {
// Emit the accumulator as soon as it is bigger than the specified `target_lower_bound_size_bytes`
if let Some(acc) = &self.accumulator
&& let Some(acc_bytes) = acc.estimate_in_memory_size_bytes(Some(self.cfg))
&& acc_bytes >= self.target_lower_bound_size_bytes
{
true
if let Some(acc) = &self.accumulator {
acc.sources.len() >= self.max_source_count
|| acc
.estimate_in_memory_size_bytes(Some(self.cfg))
.map_or(false, |bytes| bytes >= self.target_lower_bound_size_bytes)
} else {
false
}
Expand Down Expand Up @@ -143,7 +149,7 @@ impl<'a> Iterator for MergeByFileSize<'a> {
};
}

// Emit accumulator if ready
// Emit accumulator if ready or if merge count limit is reached
if self.accumulator_ready() {
return self.accumulator.take().map(Ok);
}
Expand Down
12 changes: 12 additions & 0 deletions tests/io/test_merge_scan_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,15 @@ def test_merge_scan_task_limit_override(csv_files):
):
df = daft.read_csv(str(csv_files)).limit(1)
assert df.num_partitions() == 3, "Should have 3 partitions [(CSV1, CSV2, CSV3)] since we have a limit 1"


def test_merge_scan_task_up_to_max_sources(csv_files):
    """Merging respects max_sources_per_scan_task even when byte limits would allow more.

    With min/max scan-task sizes small enough that all three CSVs could be
    considered for merging, capping sources per task at 2 must still split
    the read into two partitions.
    """
    config_ctx = daft.execution_config_ctx(
        scan_tasks_min_size_bytes=30,
        scan_tasks_max_size_bytes=30,
        max_sources_per_scan_task=2,
    )
    with config_ctx:
        df = daft.read_csv(str(csv_files))
        partition_count = df.num_partitions()
        assert (
            partition_count == 2
        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the third CSV is too large to merge with the first two, and max_sources_per_scan_task is set to 2"

0 comments on commit 29cab30

Please sign in to comment.