[PERF] Iceberg Partition Pruning (#1688)
* Implements Partition Transforms which map source fields to partition
fields
* Implements Partition Filtering when creating scan tasks
* Implements Predicate to Partition Filter rewriting (illustrated in the sketch below)
* Allow Iceberg Scan to leverage partition filtering
* Implements EmptyScan which kicks in whenever we have no files to scan
* Fixes a bug where a ScanTask with both a predicate and a limit reported an
incorrect length
* Fixes a bug in `df.num_partitions()` where the logical plan was not optimized
before computing the number of partitions
samster25 authored Dec 22, 2023
1 parent b4f3ae1 commit 94bb370
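
As a rough, hypothetical illustration of the Predicate to Partition Filter rewriting described above (not the actual Rust implementation in this PR): a predicate on a source column such as `ts >= 2023-12-25` can be rewritten against the day-transformed partition value recorded for each data file, so files whose partition value cannot satisfy the predicate are skipped before any scan task is created. The helper names and file-to-partition mapping below are made up for illustration.

from datetime import date

def day_transform(d: date) -> int:
    # Iceberg's day() transform: days since the Unix epoch
    return (d - date(1970, 1, 1)).days

def file_matches(partition_day: int, lower_bound: date) -> bool:
    # `ts >= lower_bound` rewritten as `day(ts) >= day(lower_bound)`
    return partition_day >= day_transform(lower_bound)

# hypothetical mapping of data files to their day partition values
files = {
    "a.parquet": day_transform(date(2023, 12, 24)),
    "b.parquet": day_transform(date(2023, 12, 26)),
}
to_scan = [f for f, pday in files.items() if file_matches(pday, date(2023, 12, 25))]
# only "b.parquet" becomes a scan task
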
Showing 29 changed files with 821 additions and 188 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


26 changes: 23 additions & 3 deletions daft/daft.pyi
@@ -552,7 +552,8 @@ class ScanTask:
storage_config: StorageConfig,
size_bytes: int | None,
pushdowns: Pushdowns | None,
) -> ScanTask:
partition_values: PyTable | None,
) -> ScanTask | None:
"""
Create a Catalog Scan Task
"""
@@ -585,17 +586,36 @@ class PartitionField:
Partitioning Field of a Scan Source such as Hive or Iceberg
"""

field: PyField

def __init__(
self, field: PyField, source_field: PyField | None = None, transform: PyExpr | None = None
self, field: PyField, source_field: PyField | None = None, transform: PartitionTransform | None = None
) -> None: ...

class PartitionTransform:
"""
Partitioning Transform from a Data Catalog source field to a Partitioning Column
"""

@staticmethod
def identity() -> PartitionTransform: ...
@staticmethod
def year() -> PartitionTransform: ...
@staticmethod
def month() -> PartitionTransform: ...
@staticmethod
def day() -> PartitionTransform: ...
@staticmethod
def hour() -> PartitionTransform: ...

class Pushdowns:
"""
Pushdowns from the query optimizer that can optimize scanning data sources.
"""

columns: list[str] | None
filters: list[str] | None
filters: PyExpr | None
partition_filters: PyExpr | None
limit: int | None

def read_parquet(
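
For reference, a minimal sketch of how a scan source could declare a partition field with the new `PartitionTransform` constructors via `make_partition_field` (the field names below are illustrative; the actual Iceberg mapping is in `daft/iceberg/iceberg_scan.py` further down):

from daft.daft import PartitionTransform
from daft.datatype import DataType
from daft.io.scan import make_partition_field
from daft.logical.schema import Field

source_field = Field.create("ts", DataType.date())       # source column in the table schema
result_field = Field.create("ts_day", DataType.int32())  # derived partition column
pfield = make_partition_field(result_field, source_field, transform=PartitionTransform.day())
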
3 changes: 2 additions & 1 deletion daft/dataframe/dataframe.py
@@ -132,7 +132,8 @@ def explain(self, show_optimized: bool = False, simple=False) -> None:

def num_partitions(self) -> int:
daft_execution_config = get_context().daft_execution_config
return self.__builder.to_physical_plan_scheduler(daft_execution_config).num_partitions()
# We need to run the optimizer since that could change the number of partitions
return self.__builder.optimize().to_physical_plan_scheduler(daft_execution_config).num_partitions()

@DataframePublicAPI
def schema(self) -> Schema:
1 change: 0 additions & 1 deletion daft/execution/physical_plan.py
@@ -453,7 +453,6 @@ def global_limit(
# since we will never take more than the remaining limit anyway.
child_plan = local_limit(child_plan=child_plan, limit=remaining_rows)
started = False

while True:
# Check if any inputs finished executing.
# Apply and deduct the rolling global limit.
32 changes: 31 additions & 1 deletion daft/execution/rust_physical_plan_shim.py
@@ -42,6 +42,17 @@ def scan_with_tasks(
yield scan_step


def empty_scan(
schema: Schema,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
"""yield a plan to create an empty Partition"""
scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction(
instruction=EmptyScan(schema=schema),
resource_request=ResourceRequest(memory_bytes=0),
)
yield scan_step


@dataclass(frozen=True)
class ScanWithTask(execution_step.SingleOutputInstruction):
scan_task: ScanTask
@@ -51,7 +62,8 @@ def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:

def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
assert len(inputs) == 0
return [MicroPartition._from_scan_task(self.scan_task)]
table = MicroPartition._from_scan_task(self.scan_task)
return [table]

def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
assert len(input_metadatas) == 0
@@ -64,6 +76,24 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
]


@dataclass(frozen=True)
class EmptyScan(execution_step.SingleOutputInstruction):
schema: Schema

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return [MicroPartition.empty(self.schema)]

def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
assert len(input_metadatas) == 0

return [
PartialPartitionMetadata(
num_rows=0,
size_bytes=0,
)
]


def tabular_scan(
schema: PySchema,
columns_to_read: list[str] | None,
58 changes: 41 additions & 17 deletions daft/iceberg/iceberg_scan.py
@@ -1,26 +1,32 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterator

import pyarrow as pa
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.partitioning import PartitionField as IcebergPartitionField
from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import Table
from pyiceberg.typedef import Record

import daft
from daft.daft import (
FileFormatConfig,
ParquetSourceConfig,
PartitionTransform,
Pushdowns,
ScanTask,
StorageConfig,
)
from daft.datatype import DataType
from daft.expressions.expressions import col
from daft.io.scan import PartitionField, ScanOperator, make_partition_field
from daft.logical.schema import Field, Schema

logger = logging.getLogger(__name__)


def _iceberg_partition_field_to_daft_partition_field(
iceberg_schema: IcebergSchema, pfield: IcebergPartitionField
@@ -37,7 +43,6 @@ def _iceberg_partition_field_to_daft_partition_field(
arrow_result_type = schema_to_pyarrow(iceberg_result_type)
daft_result_type = DataType.from_arrow_type(arrow_result_type)
result_field = Field.create(name, daft_result_type)

from pyiceberg.transforms import (
DayTransform,
HourTransform,
@@ -46,25 +51,20 @@ def _iceberg_partition_field_to_daft_partition_field(
YearTransform,
)

expr = None
tfm = None
if isinstance(transform, IdentityTransform):
expr = col(source_name)
if source_name != name:
expr = expr.alias(name)
tfm = PartitionTransform.identity()
elif isinstance(transform, YearTransform):
expr = col(source_name).dt.year().alias(name)
tfm = PartitionTransform.year()
elif isinstance(transform, MonthTransform):
expr = col(source_name).dt.month().alias(name)
tfm = PartitionTransform.month()
elif isinstance(transform, DayTransform):
expr = col(source_name).dt.day().alias(name)
tfm = PartitionTransform.day()
elif isinstance(transform, HourTransform):
warnings.warn(
"HourTransform not implemented, Please make a comment: https://github.com/Eventual-Inc/Daft/issues/1606"
)
tfm = PartitionTransform.hour()
else:
warnings.warn(f"{transform} not implemented, Please make an issue!")

return make_partition_field(result_field, daft_field, transform=expr)
return make_partition_field(result_field, daft_field, transform=tfm)


def iceberg_partition_spec_to_fields(iceberg_schema: IcebergSchema, spec: IcebergPartitionSpec) -> list[PartitionField]:
@@ -83,15 +83,37 @@ def __init__(self, iceberg_table: Table, storage_config: StorageConfig) -> None:
def schema(self) -> Schema:
return self._schema

def display_name(self) -> str:
return f"IcebergScanOperator({'.'.join(self._table.name())})"

def partitioning_keys(self) -> list[PartitionField]:
return self._partition_keys

def _iceberg_record_to_partition_spec(self, record: Record) -> daft.table.Table | None:
arrays = dict()
assert len(record._position_to_field_name) == len(self._partition_keys)
for name, value, pfield in zip(record._position_to_field_name, record.record_fields(), self._partition_keys):
field = Field._from_pyfield(pfield.field)
field_name = field.name
field_dtype = field.dtype
arrow_type = field_dtype.to_arrow_dtype()
assert name == field_name
arrays[name] = daft.Series.from_arrow(pa.array([value], type=arrow_type), name=name).cast(field_dtype)
if len(arrays) > 0:
return daft.table.Table.from_pydict(arrays)
else:
return None

def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
limit = pushdowns.limit
iceberg_tasks = self._table.scan(limit=limit).plan_files()

limit_files = limit is not None and pushdowns.filters is None
limit_files = limit is not None and pushdowns.filters is None and pushdowns.partition_filters is None

if len(self.partitioning_keys()) > 0 and pushdowns.partition_filters is None:
logger.warning(
f"{self.display_name()} has Partitioning Keys: {self.partitioning_keys()} but no partition filter was specified. This will result in a full table scan."
)
scan_tasks = []

if limit is not None:
@@ -114,9 +136,8 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
if len(task.delete_files) > 0:
raise NotImplementedError(f"Iceberg Merge-on-Read currently not supported, please make an issue!")

# TODO: Thread in PartitionSpec to each ScanTask: P1
# TODO: Thread in Statistics to each ScanTask: P2

pspec = self._iceberg_record_to_partition_spec(file.partition)
st = ScanTask.catalog_scan_task(
file=path,
file_format=file_format_config,
Expand All @@ -125,7 +146,10 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
storage_config=self._storage_config,
size_bytes=file.file_size_in_bytes,
pushdowns=pushdowns,
partition_values=pspec._table if pspec is not None else None,
)
if st is None:
continue
rows_left -= record_count
scan_tasks.append(st)
return iter(scan_tasks)
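
Putting the pieces together, a hedged end-to-end sketch of what partition pruning enables (the table and column names are hypothetical, and it assumes a pyiceberg table partitioned by an identity transform on `region`, read through the `daft.read_iceberg` entry point):

import daft
from daft import col

df = daft.read_iceberg(pyiceberg_table)          # pyiceberg_table: a pyiceberg Table
filtered = df.where(col("region") == "us-west")
# The optimizer pushes the predicate down as a partition filter, so only matching data
# files become ScanTasks; num_partitions() now reflects the optimized (pruned) plan.
print(filtered.num_partitions())
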
11 changes: 7 additions & 4 deletions daft/io/scan.py
@@ -3,18 +3,17 @@
import abc
from collections.abc import Iterator

from daft.daft import PartitionField, Pushdowns, ScanTask
from daft.expressions.expressions import Expression
from daft.daft import PartitionField, PartitionTransform, Pushdowns, ScanTask
from daft.logical.schema import Field, Schema


def make_partition_field(
field: Field, source_field: Field | None = None, transform: Expression | None = None
field: Field, source_field: Field | None = None, transform: PartitionTransform | None = None
) -> PartitionField:
return PartitionField(
field._field,
source_field._field if source_field is not None else None,
transform._expr if transform is not None else None,
transform,
)


@@ -23,6 +22,10 @@ class ScanOperator(abc.ABC):
def schema(self) -> Schema:
raise NotImplementedError()

@abc.abstractmethod
def display_name(self) -> str:
return self.__class__.__name__

@abc.abstractmethod
def partitioning_keys(self) -> list[PartitionField]:
raise NotImplementedError()
29 changes: 13 additions & 16 deletions src/daft-core/src/series/ops/partitioning.rs
@@ -1,17 +1,14 @@
use crate::datatypes::logical::TimestampArray;
use crate::datatypes::{Int32Array, Int64Array, TimeUnit};
use crate::series::array_impl::IntoSeries;
use crate::{
datatypes::{logical::DateArray, DataType},
series::Series,
};
use crate::{datatypes::DataType, series::Series};
use common_error::{DaftError, DaftResult};

impl Series {
pub fn partitioning_years(&self) -> DaftResult<Self> {
let epoch_year = Int32Array::from(("1970", vec![1970])).into_series();

match self.data_type() {
let value = match self.data_type() {
DataType::Date | DataType::Timestamp(_, None) => {
let years_since_ce = self.dt_year()?;
&years_since_ce - &epoch_year
@@ -25,13 +22,14 @@ impl Series {
"Can only run partitioning_years() operation on temporal types, got {}",
self.data_type()
))),
}
}?;
value.cast(&DataType::Int32)
}

pub fn partitioning_months(&self) -> DaftResult<Self> {
let months_in_year = Int32Array::from(("months", vec![12])).into_series();
let month_of_epoch = Int32Array::from(("months", vec![1])).into_series();
match self.data_type() {
let value = match self.data_type() {
DataType::Date | DataType::Timestamp(_, None) => {
let years_since_1970 = self.partitioning_years()?;
let months_of_this_year = self.dt_month()?;
@@ -47,24 +45,22 @@ impl Series {
"Can only run partitioning_years() operation on temporal types, got {}",
self.data_type()
))),
}
}?;
value.cast(&DataType::Int32)
}

pub fn partitioning_days(&self) -> DaftResult<Self> {
match self.data_type() {
DataType::Date => {
let downcasted = self.downcast::<DateArray>()?;
downcasted.cast(&DataType::Int32)
}
DataType::Date => Ok(self.clone()),
DataType::Timestamp(_, None) => {
let ts_array = self.downcast::<TimestampArray>()?;
ts_array.date()?.cast(&DataType::Int32)
Ok(ts_array.date()?.into_series())
}

DataType::Timestamp(tu, Some(_)) => {
let array = self.cast(&DataType::Timestamp(*tu, None))?;
let ts_array = array.downcast::<TimestampArray>()?;
ts_array.date()?.cast(&DataType::Int32)
Ok(ts_array.date()?.into_series())
}

_ => Err(DaftError::ComputeError(format!(
@@ -75,7 +71,7 @@
}

pub fn partitioning_hours(&self) -> DaftResult<Self> {
match self.data_type() {
let value = match self.data_type() {
DataType::Timestamp(unit, _) => {
let ts_array = self.downcast::<TimestampArray>()?;
let physical = &ts_array.physical;
Expand All @@ -93,6 +89,7 @@ impl Series {
"Can only run partitioning_hours() operation on timestamp types, got {}",
self.data_type()
))),
}
}?;
value.cast(&DataType::Int32)
}
}