diff --git a/.github/ci-scripts/format_env_vars.py b/.github/ci-scripts/format_env_vars.py deleted file mode 100644 index 870c007dc8..0000000000 --- a/.github/ci-scripts/format_env_vars.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Given a comma-separated string of environment variables, parse them into a dictionary. - -Example: - env_str = "a=1,b=2" - result = parse_env_var_str(env_str) - # returns {"a":1,"b":2} -""" - -import argparse -import json - - -def parse_env_var_str(env_var_str: str) -> dict: - iter = map( - lambda s: s.strip().split("="), - filter(lambda s: s, env_var_str.split(",")), - ) - return {k: v for k, v in iter} - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--enable-ray-tracing", action="store_true") - parser.add_argument("--env-vars", required=True) - args = parser.parse_args() - - env_vars = parse_env_var_str(args.env_vars) - if args.enable_ray_tracing: - env_vars["DAFT_ENABLE_RAY_TRACING"] = "1" - ray_env_vars = { - "env_vars": env_vars, - } - print(json.dumps(ray_env_vars)) - - -if __name__ == "__main__": - main() diff --git a/.github/ci-scripts/job_runner.py b/.github/ci-scripts/job_runner.py index c36226c1ab..2fe2cf448f 100644 --- a/.github/ci-scripts/job_runner.py +++ b/.github/ci-scripts/job_runner.py @@ -15,6 +15,10 @@ from ray.job_submission import JobStatus, JobSubmissionClient +# We impose a 5min timeout here +# If any job does *not* finish in 5min, then we cancel it and mark the question as a "DNF" (did-not-finish). +TIMEOUT_S = 60 * 5 + def parse_env_var_str(env_var_str: str) -> dict: iter = map( @@ -29,13 +33,17 @@ async def print_logs(logs): print(lines, end="") -async def wait_on_job(logs, timeout_s): - await asyncio.wait_for(print_logs(logs), timeout=timeout_s) +async def wait_on_job(logs, timeout_s) -> bool: + try: + await asyncio.wait_for(print_logs(logs), timeout=timeout_s) + return False + except asyncio.exceptions.TimeoutError: + return True @dataclass class Result: - query: int + arguments: str duration: timedelta error_msg: Optional[str] @@ -45,7 +53,6 @@ def submit_job( entrypoint_script: str, entrypoint_args: str, env_vars: str, - enable_ray_tracing: bool, ): if "GHA_OUTPUT_DIR" not in os.environ: raise RuntimeError("Output directory environment variable not found; don't know where to store outputs") @@ -53,8 +60,6 @@ def submit_job( output_dir.mkdir(exist_ok=True, parents=True) env_vars_dict = parse_env_var_str(env_vars) - if enable_ray_tracing: - env_vars_dict["DAFT_ENABLE_RAY_TRACING"] = "1" client = JobSubmissionClient(address="http://localhost:8265") @@ -66,7 +71,7 @@ def submit_job( results = [] - for index, args in enumerate(list_of_entrypoint_args): + for args in list_of_entrypoint_args: entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}" print(f"{entrypoint=}") start = datetime.now() @@ -78,18 +83,20 @@ def submit_job( }, ) - asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30)) + timed_out = asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=TIMEOUT_S)) status = client.get_job_status(job_id) - assert status.is_terminal(), "Job should have terminated" end = datetime.now() duration = end - start error_msg = None if status != JobStatus.SUCCEEDED: - job_info = client.get_job_info(job_id) - error_msg = job_info.message + if timed_out: + error_msg = f"Job exceeded {TIMEOUT_S} second(s)" + else: + job_info = client.get_job_info(job_id) + error_msg = job_info.message - result = Result(query=index, duration=duration, error_msg=error_msg) + result = Result(arguments=args, 
duration=duration, error_msg=error_msg) results.append(result) output_file = output_dir / "out.csv" @@ -106,7 +113,6 @@ def submit_job( parser.add_argument("--entrypoint-script", type=str, required=True) parser.add_argument("--entrypoint-args", type=str, required=True) parser.add_argument("--env-vars", type=str, required=True) - parser.add_argument("--enable-ray-tracing", action="store_true") args = parser.parse_args() @@ -122,5 +128,4 @@ def submit_job( entrypoint_script=args.entrypoint_script, entrypoint_args=args.entrypoint_args, env_vars=args.env_vars, - enable_ray_tracing=args.enable_ray_tracing, ) diff --git a/.github/ci-scripts/templatize_ray_config.py b/.github/ci-scripts/templatize_ray_config.py index 1608cf8dee..5828b99be3 100644 --- a/.github/ci-scripts/templatize_ray_config.py +++ b/.github/ci-scripts/templatize_ray_config.py @@ -60,6 +60,20 @@ class Metadata(BaseModel, extra="allow"): sudo chmod 777 /tmp fi""", ), + "benchmarking-arm": Profile( + instance_type="i8g.4xlarge", + image_id="ami-0d4eea77bb23270f4", + node_count=8, + ssh_user="ubuntu", + volume_mount=""" | + findmnt /tmp 1> /dev/null + code=$? + if [ $code -ne 0 ]; then + sudo mkfs.ext4 /dev/nvme0n1 + sudo mount -t ext4 /dev/nvme0n1 /tmp + sudo chmod 777 /tmp + fi""", + ), } @@ -71,7 +85,7 @@ class Metadata(BaseModel, extra="allow"): parser.add_argument("--daft-wheel-url") parser.add_argument("--daft-version") parser.add_argument("--python-version", required=True) - parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86"]) + parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86", "benchmarking-arm"]) parser.add_argument("--working-dir", required=True) parser.add_argument("--entrypoint-script", required=True) args = parser.parse_args() diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml index 7bb35ac765..1a726547af 100644 --- a/.github/workflows/run-cluster.yaml +++ b/.github/workflows/run-cluster.yaml @@ -20,6 +20,7 @@ on: description: Cluster profile type: choice options: + - benchmarking-arm - medium-x86 - debug_xs-x86 required: false @@ -49,7 +50,7 @@ jobs: uses: ./.github/workflows/build-commit.yaml if: ${{ inputs.daft_version == '' && inputs.daft_wheel_url == '' }} with: - arch: x86 + arch: ${{ (inputs.cluster_profile == 'debug_xs-x86' || inputs.cluster_profile == 'medium-x86') && 'x86' || 'arm' }} python_version: ${{ inputs.python_version }} secrets: ACTIONS_AWS_ROLE_ARN: ${{ secrets.ACTIONS_AWS_ROLE_ARN }} @@ -131,13 +132,12 @@ jobs: --entrypoint-script='${{ inputs.entrypoint_script }}' \ --entrypoint-args='${{ inputs.entrypoint_args }}' \ --env-vars='${{ inputs.env_vars }}' \ - --enable-ray-tracing - name: Download log files from ray cluster if: always() run: | source .venv/bin/activate - ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs - find ray-daft-logs -depth -name '*:*' -exec bash -c ' + ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-logs + find ray-logs -depth -name '*:*' -exec bash -c ' for filepath; do dir=$(dirname "$filepath") base=$(basename "$filepath") @@ -172,5 +172,5 @@ jobs: if: always() uses: actions/upload-artifact@v4 with: - name: ray-daft-logs - path: ray-daft-logs + name: ray-logs + path: ray-logs diff --git a/.gitignore b/.gitignore index bb5369738d..e21b99245a 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ log/ # helix editor .helix + +# uv +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 
16021912bd..e9f85b7763 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,8 @@ repos: (?x)^( tutorials/.*\.ipynb| docs/.*\.ipynb| - docs/source/user_guide/fotw/data/ + docs/source/user_guide/fotw/data/| + .*\.jsonl )$ args: - --autofix diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0113c8cedb..6f1d09e063 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,6 +36,7 @@ To set up your development environment: 1. `make build`: recompile your code after modifying any Rust code in `src/` 2. `make test`: run tests 3. `DAFT_RUNNER=ray make test`: set the runner to the Ray runner and run tests (DAFT_RUNNER defaults to `py`) +4. `make docs`: build and serve docs ### Developing with Ray diff --git a/Cargo.lock b/Cargo.lock index f62fd7e570..db6b9e5e3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1518,7 +1518,9 @@ version = "0.3.0-dev0" dependencies = [ "aws-credential-types", "chrono", + "common-error", "common-py-serde", + "derivative", "derive_more", "pyo3", "secrecy", @@ -1941,6 +1943,7 @@ dependencies = [ "daft-minhash", "daft-parquet", "daft-physical-plan", + "daft-ray-execution", "daft-scan", "daft-scheduler", "daft-sql", @@ -2005,13 +2008,16 @@ version = "0.3.0-dev0" dependencies = [ "arrow2", "async-stream", - "common-daft-config", + "common-error", "common-file-formats", + "common-runtime", + "daft-catalog", "daft-core", "daft-dsl", "daft-local-execution", "daft-logical-plan", "daft-micropartition", + "daft-ray-execution", "daft-scan", "daft-schema", "daft-sql", @@ -2020,8 +2026,10 @@ dependencies = [ "eyre", "futures", "itertools 0.11.0", + "once_cell", "pyo3", "spark-connect", + "textwrap", "tokio", "tonic", "tracing", @@ -2126,10 +2134,8 @@ dependencies = [ "derive_more", "indexmap 2.7.0", "itertools 0.11.0", - "log", "pyo3", "serde", - "typed-builder 0.20.0", "typetag", ] @@ -2142,7 +2148,6 @@ dependencies = [ "bytes", "common-error", "common-hashable-float-wrapper", - "common-io-config", "common-runtime", "daft-core", "daft-dsl", @@ -2373,6 +2378,8 @@ dependencies = [ "serde", "snafu", "test-log", + "tokio", + "typed-builder 0.20.0", "uuid 1.11.0", ] @@ -2472,6 +2479,16 @@ dependencies = [ "serde", ] +[[package]] +name = "daft-ray-execution" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "daft-logical-plan", + "daft-micropartition", + "pyo3", +] + [[package]] name = "daft-scan" version = "0.3.0-dev0" @@ -2571,6 +2588,7 @@ dependencies = [ "common-io-config", "common-runtime", "daft-algebra", + "daft-catalog", "daft-core", "daft-dsl", "daft-functions", @@ -5268,6 +5286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e484fd2c8b4cb67ab05a318f1fd6fa8f199fcc30819f08f07d200809dba26c15" dependencies = [ "cfg-if", + "chrono", "indexmap 2.7.0", "indoc", "inventory", diff --git a/Cargo.toml b/Cargo.toml index eae95f1264..5f8ad3ba37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ daft-micropartition = {path = "src/daft-micropartition", default-features = fals daft-minhash = {path = "src/daft-minhash", default-features = false} daft-parquet = {path = "src/daft-parquet", default-features = false} daft-physical-plan = {path = "src/daft-physical-plan", default-features = false} +daft-ray-execution = {path = "src/daft-ray-execution", default-features = false} daft-scan = {path = "src/daft-scan", default-features = false} daft-scheduler = {path = "src/daft-scheduler", default-features = false} daft-sql = {path = "src/daft-sql", default-features = false} @@ -56,7 +57,6 @@ python = [ 
"common-system-info/python", "daft-catalog-python-catalog/python", "daft-catalog/python", - "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -172,7 +172,8 @@ members = [ "src/parquet2", # "src/spark-connect-script", "src/generated/spark-connect", - "src/common/partitioning" + "src/common/partitioning", + "src/daft-ray-execution" ] [workspace.dependencies] @@ -200,6 +201,7 @@ daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} daft-logical-plan = {path = "src/daft-logical-plan"} daft-micropartition = {path = "src/daft-micropartition"} +daft-ray-execution = {path = "src/daft-ray-execution"} daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} daft-sql = {path = "src/daft-sql"} @@ -277,7 +279,7 @@ features = ['async'] path = "src/parquet2" [workspace.dependencies.pyo3] -features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap"] +features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap", "chrono"] version = "0.23.3" [workspace.dependencies.pyo3-log] diff --git a/Makefile b/Makefile index e6def0bf93..c21528a7eb 100644 --- a/Makefile +++ b/Makefile @@ -70,6 +70,10 @@ test: .venv build ## Run tests dsdgen: .venv ## Generate TPC-DS data $(VENV_BIN)/python benchmarking/tpcds/datagen.py --scale-factor=$(SCALE_FACTOR) --tpcds-gen-folder=$(OUTPUT_DIR) +.PHONY: docs +docs: .venv ## Serve docs + uv run --with-requirements requirements-docs.txt mkdocs serve + .PHONY: clean clean: rm -rf $(VENV) diff --git a/README.rst b/README.rst index c6e3152259..aad9d86b77 100644 --- a/README.rst +++ b/README.rst @@ -142,7 +142,7 @@ Daft has an Apache 2.0 license - please see the LICENSE file. .. |Benchmark Image| image:: https://github-production-user-asset-6210df.s3.amazonaws.com/2550285/243524430-338e427d-f049-40b3-b555-4059d6be7bfd.png :alt: Benchmarks for SF100 TPCH -.. |Banner| image:: https://github.com/user-attachments/assets/ac676800-b799-454e-a6e0-9a58974a4154 +.. |Banner| image:: https://github.com/user-attachments/assets/da7a2a93-9464-4c8d-b5bd-759731610356 :target: https://www.getdaft.io :alt: Daft dataframes can load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying diff --git a/benchmarking/tpcds/ray_entrypoint.py b/benchmarking/tpcds/ray_entrypoint.py index 0d1ced05bc..26668ca089 100644 --- a/benchmarking/tpcds/ray_entrypoint.py +++ b/benchmarking/tpcds/ray_entrypoint.py @@ -78,8 +78,8 @@ def run( { "question": question, "scale-factor": scale_factor, - "planning-time": explain_delta, - "execution-time": execute_delta, + "planning-time": str(explain_delta), + "execution-time": str(execute_delta), } ) f.write(stats) diff --git a/daft/catalog/__init__.py b/daft/catalog/__init__.py index 5e74fd4c08..b6ac2d26fa 100644 --- a/daft/catalog/__init__.py +++ b/daft/catalog/__init__.py @@ -27,7 +27,7 @@ ```python df = daft.from_pydict({"foo": [1, 2, 3]}) -daft.catalog.register_named_table( +daft.catalog.register_table( "my_table", df, ) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 0f70c0d606..86e3c3b2ca 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -461,6 +461,10 @@ class S3Config: """Creates an S3Config, retrieving credentials and configurations from the current environment.""" ... + def provide_cached_credentials(self) -> S3Credentials | None: + """Wrapper around call to `S3Config.credentials_provider` to cache credentials until expiry.""" + ... 
+ class S3Credentials: key_id: str access_key: str @@ -965,6 +969,7 @@ class PyExpr: def is_null(self) -> PyExpr: ... def not_null(self) -> PyExpr: ... def fill_null(self, fill_value: PyExpr) -> PyExpr: ... + def eq_null_safe(self, other: PyExpr) -> PyExpr: ... def is_in(self, other: list[PyExpr]) -> PyExpr: ... def between(self, lower: PyExpr, upper: PyExpr) -> PyExpr: ... def name(self) -> str: ... @@ -1200,6 +1205,13 @@ def utf8_normalize( expr: PyExpr, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool ) -> PyExpr: ... +# --- +# expr.binary namespace +# --- +def binary_length(expr: PyExpr) -> PyExpr: ... +def binary_concat(left: PyExpr, right: PyExpr) -> PyExpr: ... +def binary_slice(expr: PyExpr, start: PyExpr, length: PyExpr | None = None) -> PyExpr: ... + class PyCatalog: @staticmethod def new() -> PyCatalog: ... @@ -1683,10 +1695,7 @@ class LogicalPlanBuilder: def repr_mermaid(self, options: MermaidOptions) -> str: ... class NativeExecutor: - @staticmethod - def from_logical_plan_builder( - logical_plan_builder: LogicalPlanBuilder, - ) -> NativeExecutor: ... + def __init__(self) -> None: ... def run( self, psets: dict[str, list[PartitionT]], cfg: PyDaftExecutionConfig, results_buffer_size: int | None ) -> Iterator[PyMicroPartition]: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 2735549f2b..d128fb2262 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -552,7 +552,7 @@ def write_parquet( self, root_dir: Union[str, pathlib.Path], compression: str = "snappy", - write_mode: Literal["append", "overwrite"] = "append", + write_mode: Literal["append", "overwrite", "overwrite-partitions"] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": @@ -566,7 +566,7 @@ def write_parquet( Args: root_dir (str): root file path to write parquet files to. compression (str, optional): compression algorithm. Defaults to "snappy". - write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data. Defaults to "append". + write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace the contents of the root directory with new data. `overwrite-partitions` will replace only the contents in the partitions that are being written to. Defaults to "append". partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None. io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage. @@ -576,8 +576,12 @@ def write_parquet( .. NOTE:: This call is **blocking** and will execute the DataFrame when called """ - if write_mode not in ["append", "overwrite"]: - raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported") + if write_mode not in ["append", "overwrite", "overwrite-partitions"]: + raise ValueError( + f"Only support `append`, `overwrite`, or `overwrite-partitions` mode. 
{write_mode} is unsupported" + ) + if write_mode == "overwrite-partitions" and partition_cols is None: + raise ValueError("Partition columns must be specified to use `overwrite-partitions` mode.") io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config @@ -598,7 +602,9 @@ def write_parquet( assert write_df._result is not None if write_mode == "overwrite": - overwrite_files(write_df, root_dir, io_config) + overwrite_files(write_df, root_dir, io_config, False) + elif write_mode == "overwrite-partitions": + overwrite_files(write_df, root_dir, io_config, True) if len(write_df) > 0: # Populate and return a new disconnected DataFrame @@ -624,7 +630,7 @@ def write_parquet( def write_csv( self, root_dir: Union[str, pathlib.Path], - write_mode: Literal["append", "overwrite"] = "append", + write_mode: Literal["append", "overwrite", "overwrite-partitions"] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": @@ -637,15 +643,19 @@ def write_csv( Args: root_dir (str): root file path to write parquet files to. - write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data. Defaults to "append". + write_mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace the contents of the root directory with new data. `overwrite-partitions` will replace only the contents in the partitions that are being written to. Defaults to "append". partition_cols (Optional[List[ColumnInputType]], optional): How to subpartition each partition further. Defaults to None. io_config (Optional[IOConfig], optional): configurations to use when interacting with remote storage. Returns: DataFrame: The filenames that were written out as strings. """ - if write_mode not in ["append", "overwrite"]: - raise ValueError(f"Only support `append` or `overwrite` mode. {write_mode} is unsupported") + if write_mode not in ["append", "overwrite", "overwrite-partitions"]: + raise ValueError( + f"Only support `append`, `overwrite`, or `overwrite-partitions` mode. {write_mode} is unsupported" + ) + if write_mode == "overwrite-partitions" and partition_cols is None: + raise ValueError("Partition columns must be specified to use `overwrite-partitions` mode.") io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config @@ -665,7 +675,9 @@ def write_csv( assert write_df._result is not None if write_mode == "overwrite": - overwrite_files(write_df, root_dir, io_config) + overwrite_files(write_df, root_dir, io_config, False) + elif write_mode == "overwrite-partitions": + overwrite_files(write_df, root_dir, io_config, True) if len(write_df) > 0: # Populate and return a new disconnected DataFrame diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 41f2d4c775..9670ff42b7 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -51,6 +51,9 @@ class PartitionTask(Generic[PartitionT]): # This is used when a specific executor (e.g. 
an Actor pool) must be provisioned and used for the task actor_pool_id: str | None + # Indicates that the metadata of the result partition should be cached when the task is done + cache_metadata_on_done: bool = True + # Indicates if the PartitionTask is "done" or not is_done: bool = False @@ -70,11 +73,17 @@ def set_done(self): """Sets the PartitionTask as done.""" assert not self.is_done, "Cannot set PartitionTask as done more than once" self.is_done = True + if self.cache_metadata_on_done: + self.cache_metadata() def cancel(self) -> None: """If possible, cancel the execution of this PartitionTask.""" raise NotImplementedError() + def cache_metadata(self) -> None: + """Cache the metadata of the result partition.""" + raise NotImplementedError() + def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: """Set the result of this Task. For use by the Task executor. @@ -140,7 +149,9 @@ def is_empty(self) -> bool: """Whether this partition task is guaranteed to result in an empty partition.""" return len(self.partial_metadatas) > 0 and all(meta.num_rows == 0 for meta in self.partial_metadatas) - def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPartitionTask[PartitionT]: + def finalize_partition_task_single_output( + self, stage_id: int, cache_metadata_on_done: bool = True + ) -> SingleOutputPartitionTask[PartitionT]: """Create a SingleOutputPartitionTask from this PartitionTaskBuilder. Returns a "frozen" version of this PartitionTask that cannot have instructions added. @@ -162,9 +173,12 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa partial_metadatas=self.partial_metadatas, actor_pool_id=self.actor_pool_id, node_id=self.node_id, + cache_metadata_on_done=cache_metadata_on_done, ) - def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]: + def finalize_partition_task_multi_output( + self, stage_id: int, cache_metadata_on_done: bool = True + ) -> MultiOutputPartitionTask[PartitionT]: """Create a MultiOutputPartitionTask from this PartitionTaskBuilder. Same as finalize_partition_task_single_output, except the output of this PartitionTask is a list of partitions. @@ -184,6 +198,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart partial_metadatas=self.partial_metadatas, actor_pool_id=self.actor_pool_id, node_id=self.node_id, + cache_metadata_on_done=cache_metadata_on_done, ) def __str__(self) -> str: @@ -201,6 +216,7 @@ class SingleOutputPartitionTask(PartitionTask[PartitionT]): # When available, the partition created from running the PartitionTask. _result: None | MaterializedResult[PartitionT] = None + _partition_metadata: None | PartitionMetadata = None def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: assert self._result is None, f"Cannot set result twice. Result is already {self._result}" @@ -220,13 +236,22 @@ def partition(self) -> PartitionT: """Get the PartitionT resulting from running this PartitionTask.""" return self.result().partition() + def cache_metadata(self) -> None: + assert self._result is not None, "Cannot cache metadata without a result" + if self._partition_metadata is not None: + return + + [partial_metadata] = self.partial_metadatas + self._partition_metadata = self.result().metadata().merge_with_partial(partial_metadata) + def partition_metadata(self) -> PartitionMetadata: """Get the metadata of the result partition. (Avoids retrieving the actual partition itself if possible.) 
""" - [partial_metadata] = self.partial_metadatas - return self.result().metadata().merge_with_partial(partial_metadata) + self.cache_metadata() + assert self._partition_metadata is not None + return self._partition_metadata def micropartition(self) -> MicroPartition: """Get the raw vPartition of the result.""" @@ -249,6 +274,7 @@ class MultiOutputPartitionTask(PartitionTask[PartitionT]): # When available, the partitions created from running the PartitionTask. _results: None | list[MaterializedResult[PartitionT]] = None + _partition_metadatas: None | list[PartitionMetadata] = None def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: assert self._results is None, f"Cannot set result twice. Result is already {self._results}" @@ -264,16 +290,24 @@ def partitions(self) -> list[PartitionT]: assert self._results is not None return [result.partition() for result in self._results] + def cache_metadata(self) -> None: + assert self._results is not None, "Cannot cache metadata without a result" + if self._partition_metadatas is not None: + return + + self._partition_metadatas = [ + result.metadata().merge_with_partial(partial_metadata) + for result, partial_metadata in zip(self._results, self.partial_metadatas) + ] + def partition_metadatas(self) -> list[PartitionMetadata]: """Get the metadata of the result partitions. (Avoids retrieving the actual partition itself if possible.) """ - assert self._results is not None - return [ - result.metadata().merge_with_partial(partial_metadata) - for result, partial_metadata in zip(self._results, self.partial_metadatas) - ] + self.cache_metadata() + assert self._partition_metadatas is not None + return self._partition_metadatas def micropartition(self, index: int) -> MicroPartition: """Get the raw vPartition of the result.""" diff --git a/daft/execution/native_executor.py b/daft/execution/native_executor.py index 1958c6b90f..333db5fc4c 100644 --- a/daft/execution/native_executor.py +++ b/daft/execution/native_executor.py @@ -5,10 +5,10 @@ from daft.daft import ( NativeExecutor as _NativeExecutor, ) -from daft.daft import PyDaftExecutionConfig from daft.table import MicroPartition if TYPE_CHECKING: + from daft.daft import PyDaftExecutionConfig from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import ( LocalMaterializedResult, @@ -18,16 +18,12 @@ class NativeExecutor: - def __init__(self, executor: _NativeExecutor): - self._executor = executor - - @classmethod - def from_logical_plan_builder(cls, builder: LogicalPlanBuilder) -> NativeExecutor: - executor = _NativeExecutor.from_logical_plan_builder(builder._builder) - return cls(executor) + def __init__(self): + self._executor = _NativeExecutor() def run( self, + builder: LogicalPlanBuilder, psets: dict[str, list[MaterializedResult[PartitionT]]], daft_execution_config: PyDaftExecutionConfig, results_buffer_size: int | None, @@ -39,5 +35,5 @@ def run( } return ( LocalMaterializedResult(MicroPartition._from_pymicropartition(part)) - for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size) + for part in self._executor.run(builder._builder, psets_mp, daft_execution_config, results_buffer_size) ) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index be3f6739c3..7d1ae3cbe2 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -1681,7 +1681,7 @@ def _best_effort_next_step( return (None, False) else: if isinstance(step, PartitionTaskBuilder): - step = 
step.finalize_partition_task_single_output(stage_id=stage_id) + step = step.finalize_partition_task_single_output(stage_id=stage_id, cache_metadata_on_done=False) return (step, True) elif isinstance(step, PartitionTask): return (step, False) @@ -1771,7 +1771,7 @@ def __iter__(self) -> MaterializedPhysicalPlan: try: step = next(self.child_plan) if isinstance(step, PartitionTaskBuilder): - step = step.finalize_partition_task_single_output(stage_id=stage_id) + step = step.finalize_partition_task_single_output(stage_id=stage_id, cache_metadata_on_done=False) self.materializations.append(step) num_final_yielded += 1 logger.debug("[plan-%s] YIELDING final task (%s so far)", stage_id, num_final_yielded) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 0e68c2cbdf..82f8b8304e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -273,6 +273,15 @@ def json(self) -> ExpressionJsonNamespace: """Access methods that work on columns of json.""" return ExpressionJsonNamespace.from_expression(self) + @property + def binary(self) -> ExpressionBinaryNamespace: + """Access binary string operations for this expression. + + Returns: + ExpressionBinaryNamespace: A namespace containing binary string operations + """ + return ExpressionBinaryNamespace.from_expression(self) + @staticmethod def _from_pyexpr(pyexpr: _PyExpr) -> Expression: expr = Expression.__new__(Expression) @@ -457,6 +466,23 @@ def __eq__(self, other: Expression) -> Expression: # type: ignore expr = Expression._to_expression(other) return Expression._from_pyexpr(self._expr == expr._expr) + def eq_null_safe(self, other: Expression) -> Expression: + """Performs a null-safe equality comparison between two expressions. + + Unlike regular equality (==), null-safe equality (<=> or IS NOT DISTINCT FROM): + - Returns True when comparing NULL <=> NULL + - Returns False when comparing NULL <=> any_value + - Behaves like regular equality for non-NULL values + + Args: + other: The expression to compare with + + Returns: + Expression: A boolean expression indicating if the values are equal + """ + expr = Expression._to_expression(other) + return Expression._from_pyexpr(self._expr.eq_null_safe(expr._expr)) + def __ne__(self, other: Expression) -> Expression: # type: ignore """Compares if an expression is not equal to another (``e1 != e2``).""" expr = Expression._to_expression(other) @@ -1414,6 +1440,7 @@ def upload( multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime() # If the user specifies a single location via a string, we should upload to a single folder. Otherwise, # if the user gave an expression, we assume that each row has a specific url to upload to. + # Consider moving the check for is_single_folder to a lower IR. is_single_folder = isinstance(location, str) io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config) return Expression._from_pyexpr( @@ -3554,3 +3581,100 @@ class ExpressionEmbeddingNamespace(ExpressionNamespace): def cosine_distance(self, other: Expression) -> Expression: """Compute the cosine distance between two embeddings.""" return Expression._from_pyexpr(native.cosine_distance(self._expr, other._expr)) + + +class ExpressionBinaryNamespace(ExpressionNamespace): + def length(self) -> Expression: + """Retrieves the length for a binary string column. 
+ + Example: + >>> import daft + >>> df = daft.from_pydict({"x": [b"foo", b"bar", b"baz"]}) + >>> df = df.select(df["x"].binary.length()) + >>> df.show() + ╭────────╮ + │ x │ + │ --- │ + │ UInt64 │ + ╞════════╡ + │ 3 │ + ├╌╌╌╌╌╌╌╌┤ + │ 3 │ + ├╌╌╌╌╌╌╌╌┤ + │ 3 │ + ╰────────╯ + + (Showing first 3 of 3 rows) + + Returns: + Expression: an UInt64 expression with the length of each binary string in bytes + """ + return Expression._from_pyexpr(native.binary_length(self._expr)) + + def concat(self, other: Expression) -> Expression: + r"""Concatenates two binary strings. + + Example: + >>> import daft + >>> df = daft.from_pydict( + ... {"a": [b"Hello", b"\\xff\\xfe", b"", b"World"], "b": [b" World", b"\\x00", b"empty", b"!"]} + ... ) + >>> df = df.select(df["a"].binary.concat(df["b"])) + >>> df.show() + ╭────────────────────╮ + │ a │ + │ --- │ + │ Binary │ + ╞════════════════════╡ + │ b"Hello World" │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b"\\xff\\xfe\\x00" │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b"empty" │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b"World!" │ + ╰────────────────────╯ + + (Showing first 4 of 4 rows) + + Args: + other: The binary string to concatenate with, can be either an Expression or a bytes literal + + Returns: + Expression: A binary expression containing the concatenated strings + """ + other_expr = Expression._to_expression(other) + return Expression._from_pyexpr(native.binary_concat(self._expr, other_expr._expr)) + + def slice(self, start: Expression | int, length: Expression | int | None = None) -> Expression: + r"""Returns a slice of each binary string. + + Example: + >>> import daft + >>> df = daft.from_pydict({"x": [b"Hello World", b"\xff\xfe\x00", b"empty"]}) + >>> df = df.select(df["x"].binary.slice(1, 3)) + >>> df.show() + ╭─────────────╮ + │ x │ + │ --- │ + │ Binary │ + ╞═════════════╡ + │ b"ell" │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b"\xfe\x00" │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b"mpt" │ + ╰─────────────╯ + + (Showing first 3 of 3 rows) + + Args: + start: The starting position (0-based) of the slice. + length: The length of the slice. If None, returns all characters from start to the end. + + Returns: + A new expression representing the slice. 
+ """ + start_expr = Expression._to_expression(start) + length_expr = Expression._to_expression(length) + return Expression._from_pyexpr(native.binary_slice(self._expr, start_expr._expr, length_expr._expr)) diff --git a/daft/filesystem.py b/daft/filesystem.py index 88ed7a9985..309ffb09ac 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -6,7 +6,8 @@ import pathlib import sys import urllib.parse -from typing import TYPE_CHECKING, Any, Literal +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any from daft.convert import from_pydict from daft.daft import FileFormat, FileInfos, IOConfig, io_glob @@ -19,7 +20,14 @@ logger = logging.getLogger(__name__) -_CACHED_FSES: dict[tuple[str, IOConfig | None], pafs.FileSystem] = {} + +@dataclasses.dataclass(frozen=True) +class PyArrowFSWithExpiry: + fs: pafs.FileSystem + expiry: datetime | None + + +_CACHED_FSES: dict[tuple[str, IOConfig | None], PyArrowFSWithExpiry] = {} def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> pafs.FileSystem | None: @@ -29,22 +37,20 @@ def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> pafs.FileSy """ global _CACHED_FSES - return _CACHED_FSES.get((protocol, io_config)) + if (protocol, io_config) in _CACHED_FSES: + fs = _CACHED_FSES[(protocol, io_config)] + if fs.expiry is None or fs.expiry > datetime.now(timezone.utc): + return fs.fs -def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> None: - """Put pyarrow filesystem in cache under provided protocol.""" - global _CACHED_FSES + return None - _CACHED_FSES[(protocol, io_config)] = fs +def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | None, expiry: datetime | None) -> None: + """Put pyarrow filesystem in cache under provided protocol.""" + global _CACHED_FSES -@dataclasses.dataclass(frozen=True) -class ListingInfo: - path: str - size: int - type: Literal["file", "directory"] - rows: int | None = None + _CACHED_FSES[(protocol, io_config)] = PyArrowFSWithExpiry(fs, expiry) def get_filesystem(protocol: str, **kwargs) -> fsspec.AbstractFileSystem: @@ -154,10 +160,10 @@ def _resolve_paths_and_filesystem( if resolved_filesystem is None: # Resolve path and filesystem for the first path. # We use this first resolved filesystem for validation on all other paths. - resolved_path, resolved_filesystem = _infer_filesystem(paths[0], io_config) + resolved_path, resolved_filesystem, expiry = _infer_filesystem(paths[0], io_config) # Put resolved filesystem in cache under these paths' canonical protocol. - _put_fs_in_cache(protocol, resolved_filesystem, io_config) + _put_fs_in_cache(protocol, resolved_filesystem, io_config, expiry) else: resolved_path = _validate_filesystem(paths[0], resolved_filesystem, io_config) @@ -175,7 +181,7 @@ def _resolve_paths_and_filesystem( def _validate_filesystem(path: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> str: - resolved_path, inferred_fs = _infer_filesystem(path, io_config) + resolved_path, inferred_fs, _ = _infer_filesystem(path, io_config) if not isinstance(fs, type(inferred_fs)): raise RuntimeError( f"Cannot read multiple paths with different inferred PyArrow filesystems. Expected: {fs} but received: {inferred_fs}" @@ -186,8 +192,8 @@ def _validate_filesystem(path: str, fs: pafs.FileSystem, io_config: IOConfig | N def _infer_filesystem( path: str, io_config: IOConfig | None, -) -> tuple[str, pafs.FileSystem]: - """Resolves and normalizes the provided path and infers it's filesystem. 
+) -> tuple[str, pafs.FileSystem, datetime | None]: + """Resolves and normalizes the provided path and infers its filesystem and expiry. Also ensures that the inferred filesystem is compatible with the passedfilesystem, if provided. @@ -225,9 +231,17 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): except ImportError: pass # Config does not exist in pyarrow 7.0.0 + expiry = None + if (s3_creds := s3_config.provide_cached_credentials()) is not None: + _set_if_not_none(translated_kwargs, "access_key", s3_creds.key_id) + _set_if_not_none(translated_kwargs, "secret_key", s3_creds.access_key) + _set_if_not_none(translated_kwargs, "session_token", s3_creds.session_token) + + expiry = s3_creds.expiry + resolved_filesystem = pafs.S3FileSystem(**translated_kwargs) resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path)) - return resolved_path, resolved_filesystem + return resolved_path, resolved_filesystem, expiry ### # Local @@ -235,7 +249,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): elif protocol == "file": resolved_filesystem = pafs.LocalFileSystem() resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path)) - return resolved_path, resolved_filesystem + return resolved_path, resolved_filesystem, None ### # GCS @@ -257,7 +271,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): resolved_filesystem = GcsFileSystem(**translated_kwargs) resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path)) - return resolved_path, resolved_filesystem + return resolved_path, resolved_filesystem, None ### # HTTP: Use FSSpec as a fallback @@ -267,7 +281,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): fsspec_fs = fsspec_fs_cls() resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs) resolved_path = resolved_filesystem.normalize_path(resolved_path) - return resolved_path, resolved_filesystem + return resolved_path, resolved_filesystem, None ### # Azure: Use FSSpec as a fallback @@ -290,7 +304,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): fsspec_fs = fsspec_fs_cls() resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs) resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(resolved_path)) - return resolved_path, resolved_filesystem + return resolved_path, resolved_filesystem, None else: raise NotImplementedError(f"Cannot infer PyArrow filesystem for protocol {protocol}: please file an issue!") @@ -313,7 +327,7 @@ def glob_path_with_stats( file_format: FileFormat | None, io_config: IOConfig | None, ) -> FileInfos: - """Glob a path, returning a list ListingInfo.""" + """Glob a path, returning a FileInfos.""" files = io_glob(path, io_config=io_config) filepaths_to_infos = {f["path"]: {"size": f["size"], "type": f["type"]} for f in files} @@ -354,19 +368,41 @@ def overwrite_files( manifest: DataFrame, root_dir: str | pathlib.Path, io_config: IOConfig | None, + overwrite_partitions: bool, ) -> None: [resolved_path], fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config) - file_selector = pafs.FileSelector(resolved_path, recursive=True) - try: - paths = [info.path for info in fs.get_file_info(file_selector) if info.type == pafs.FileType.File] - except FileNotFoundError: - # The root directory does not exist, so there are no files to delete. 
- return - - all_file_paths_df = from_pydict({"path": paths}) assert manifest._result is not None written_file_paths = manifest._result._get_merged_micropartition().get_column("path") + + all_file_paths = [] + if overwrite_partitions: + # Get all files in ONLY the directories that were written to. + + written_dirs = set(str(pathlib.Path(path).parent) for path in written_file_paths.to_pylist()) + for dir in written_dirs: + file_selector = pafs.FileSelector(dir, recursive=True) + try: + all_file_paths.extend( + [info.path for info in fs.get_file_info(file_selector) if info.type == pafs.FileType.File] + ) + except FileNotFoundError: + continue + else: + # Get all files in the root directory. + + file_selector = pafs.FileSelector(resolved_path, recursive=True) + try: + all_file_paths.extend( + [info.path for info in fs.get_file_info(file_selector) if info.type == pafs.FileType.File] + ) + except FileNotFoundError: + # The root directory does not exist, so there are no files to delete. + return + + all_file_paths_df = from_pydict({"path": all_file_paths}) + + # Find the files that were not written to in this run and delete them. to_delete = all_file_paths_df.where(~(col("path").is_in(lit(written_file_paths)))) # TODO: Look into parallelizing this diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py index c3ea30aaa9..a627dadb92 100644 --- a/daft/io/_iceberg.py +++ b/daft/io/_iceberg.py @@ -1,6 +1,6 @@ # isort: dont-add-import: from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from daft import context from daft.api_annotations import PublicAPI @@ -53,7 +53,7 @@ def get_first_property_value(*property_names: str) -> Optional[Any]: @PublicAPI def read_iceberg( - table: "pyiceberg.table.Table", + table: Union[str, "pyiceberg.table.Table"], snapshot_id: Optional[int] = None, io_config: Optional["IOConfig"] = None, ) -> DataFrame: @@ -75,15 +75,21 @@ def read_iceberg( official project for Python. Args: - table (pyiceberg.table.Table): `PyIceberg Table `__ created using the PyIceberg library + table (str or pyiceberg.table.Table): `PyIceberg Table `__ created using the PyIceberg library snapshot_id (int, optional): Snapshot ID of the table to query io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored. Returns: DataFrame: a DataFrame with the schema converted from the specified Iceberg table """ + import pyiceberg + from daft.iceberg.iceberg_scan import IcebergScanOperator + # support for read_iceberg('path/to/metadata.json') + if isinstance(table, str): + table = pyiceberg.table.StaticTable.from_metadata(metadata_location=table) + io_config = ( _convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config ) diff --git a/daft/runners/native_runner.py b/daft/runners/native_runner.py index c7e5ce8034..a03e14c93a 100644 --- a/daft/runners/native_runner.py +++ b/daft/runners/native_runner.py @@ -75,8 +75,9 @@ def run_iter( # Optimize the logical plan. 
builder = builder.optimize() - executor = NativeExecutor.from_logical_plan_builder(builder) + executor = NativeExecutor() results_gen = executor.run( + builder, {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}, daft_execution_config, results_buffer_size, diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 450bc4eb57..48be64921b 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -380,8 +380,9 @@ def run_iter( if daft_execution_config.enable_native_executor: logger.info("Using native executor") - executor = NativeExecutor.from_logical_plan_builder(builder) + executor = NativeExecutor() results_gen = executor.run( + builder, {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}, daft_execution_config, results_buffer_size, diff --git a/docs-v2/advanced/distributed.md b/docs-v2/advanced/distributed.md deleted file mode 100644 index 9f78b4e0f9..0000000000 --- a/docs-v2/advanced/distributed.md +++ /dev/null @@ -1,72 +0,0 @@ -# Distributed Computing - -By default, Daft runs using your local machine's resources and your operations are thus limited by the CPUs, memory and GPUs available to you in your single local development machine. - -However, Daft has strong integrations with [Ray](https://www.ray.io) which is a distributed computing framework for distributing computations across a cluster of machines. Here is a snippet showing how you can connect Daft to a Ray cluster: - -=== "🐍 Python" - - ```python - import daft - - daft.context.set_runner_ray() - ``` - -By default, if no address is specified Daft will spin up a Ray cluster locally on your machine. If you are running Daft on a powerful machine (such as an AWS P3 machine which is equipped with multiple GPUs) this is already very useful because Daft can parallelize its execution of computation across your CPUs and GPUs. However, if instead you already have your own Ray cluster running remotely, you can connect Daft to it by supplying an address: - -=== "🐍 Python" - - ```python - daft.context.set_runner_ray(address="ray://url-to-mycluster") - ``` - -For more information about the `address` keyword argument, please see the [Ray documentation on initialization](https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html). - - -If you want to start a single node ray cluster on your local machine, you can do the following: - -```bash -> pip install ray[default] -> ray start --head --port=6379 -``` - -This should output something like: - -``` -Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. - -Local node IP: 127.0.0.1 - --------------------- -Ray runtime started. --------------------- - -... -``` - -You can take the IP address and port and pass it to Daft: - -=== "🐍 Python" - - ```python - >>> import daft - >>> daft.context.set_runner_ray("127.0.0.1:6379") - DaftContext(_daft_execution_config=, _daft_planning_config=, _runner_config=_RayRunnerConfig(address='127.0.0.1:6379', max_task_backlog=None), _disallow_set_runner=True, _runner=None) - >>> df = daft.from_pydict({ - ... 'text': ['hello', 'world'] - ... }) - 2024-07-29 15:49:26,610 INFO worker.py:1567 -- Connecting to existing Ray cluster at address: 127.0.0.1:6379... - 2024-07-29 15:49:26,622 INFO worker.py:1752 -- Connected to Ray cluster. 
- >>> print(df) - ╭───────╮ - │ text │ - │ --- │ - │ Utf8 │ - ╞═══════╡ - │ hello │ - ├╌╌╌╌╌╌╌┤ - │ world │ - ╰───────╯ - - (Showing first 2 of 2 rows) - ``` diff --git a/docs-v2/core_concepts.md b/docs-v2/core_concepts.md index e748cfc4cf..86435df7a1 100644 --- a/docs-v2/core_concepts.md +++ b/docs-v2/core_concepts.md @@ -2330,6 +2330,6 @@ Let’s turn the bytes into human-readable images using [`image.decode()`](https - [:material-memory: **Managing Memory Usage**](advanced/memory.md) - [:fontawesome-solid-equals: **Partitioning**](advanced/partitioning.md) -- [:material-distribute-vertical-center: **Distributed Computing**](advanced/distributed.md) +- [:material-distribute-vertical-center: **Distributed Computing**](distributed.md) diff --git a/docs-v2/core_concepts/aggregations.md b/docs-v2/core_concepts/aggregations.md deleted file mode 100644 index 4bb835f59d..0000000000 --- a/docs-v2/core_concepts/aggregations.md +++ /dev/null @@ -1,111 +0,0 @@ -# Aggregations and Grouping - -Some operations such as the sum or the average of a column are called **aggregations**. Aggregations are operations that reduce the number of rows in a column. - -## Global Aggregations - -An aggregation can be applied on an entire DataFrame, for example to get the mean on a specific column: - -=== "🐍 Python" - ``` python - import daft - - df = daft.from_pydict({ - "class": ["a", "a", "b", "b"], - "score": [10, 20., 30., 40], - }) - - df.mean("score").show() - ``` - -``` {title="Output"} - -╭─────────╮ -│ score │ -│ --- │ -│ Float64 │ -╞═════════╡ -│ 25 │ -╰─────────╯ - -(Showing first 1 of 1 rows) -``` - -For a full list of available Dataframe aggregations, see [Aggregations](https://www.getdaft.io/projects/docs/en/stable/api_docs/dataframe.html#df-aggregations). - -Aggregations can also be mixed and matched across columns, via the `agg` method: - -=== "🐍 Python" - ``` python - df.agg( - df["score"].mean().alias("mean_score"), - df["score"].max().alias("max_score"), - df["class"].count().alias("class_count"), - ).show() - ``` - -``` {title="Output"} - -╭────────────┬───────────┬─────────────╮ -│ mean_score ┆ max_score ┆ class_count │ -│ --- ┆ --- ┆ --- │ -│ Float64 ┆ Float64 ┆ UInt64 │ -╞════════════╪═══════════╪═════════════╡ -│ 25 ┆ 40 ┆ 4 │ -╰────────────┴───────────┴─────────────╯ - -(Showing first 1 of 1 rows) -``` - -For a full list of available aggregation expressions, see [Aggregation Expressions](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#api-aggregation-expression) - -## Grouped Aggregations - -Aggregations can also be called on a "Grouped DataFrame". For the above example, perhaps we want to get the mean "score" not for the entire DataFrame, but for each "class". 
- -Let's run the mean of column "score" again, but this time grouped by "class": - -=== "🐍 Python" - ``` python - df.groupby("class").mean("score").show() - ``` - -``` {title="Output"} - -╭───────┬─────────╮ -│ class ┆ score │ -│ --- ┆ --- │ -│ Utf8 ┆ Float64 │ -╞═══════╪═════════╡ -│ a ┆ 15 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ -│ b ┆ 35 │ -╰───────┴─────────╯ - -(Showing first 2 of 2 rows) -``` - -To run multiple aggregations on a Grouped DataFrame, you can use the `agg` method: - -=== "🐍 Python" - ``` python - df.groupby("class").agg( - df["score"].mean().alias("mean_score"), - df["score"].max().alias("max_score"), - ).show() - ``` - -``` {title="Output"} - -╭───────┬────────────┬───────────╮ -│ class ┆ mean_score ┆ max_score │ -│ --- ┆ --- ┆ --- │ -│ Utf8 ┆ Float64 ┆ Float64 │ -╞═══════╪════════════╪═══════════╡ -│ a ┆ 15 ┆ 20 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤ -│ b ┆ 35 ┆ 40 │ -╰───────┴────────────┴───────────╯ - -(Showing first 2 of 2 rows) -``` diff --git a/docs-v2/core_concepts/dataframe.md b/docs-v2/core_concepts/dataframe.md deleted file mode 100644 index 9ee1bb3320..0000000000 --- a/docs-v2/core_concepts/dataframe.md +++ /dev/null @@ -1,654 +0,0 @@ -# DataFrame - -!!! failure "todo(docs): Check that this page makes sense. Can we have a 1-1 mapping of "Common data operations that you would perform on DataFrames are: ..." to its respective section?" - -!!! failure "todo(docs): I reused some of these sections in the Quickstart (create df, execute df and view data, select rows, select columns) but the examples in the quickstart are different. Should we still keep those sections on this page?" - - -If you are coming from other DataFrame libraries such as Pandas or Polars, here are some key differences about Daft DataFrames: - -1. **Distributed:** When running in a distributed cluster, Daft splits your data into smaller "chunks" called *Partitions*. This allows Daft to process your data in parallel across multiple machines, leveraging more resources to work with large datasets. - -2. **Lazy:** When you write operations on a DataFrame, Daft doesn't execute them immediately. Instead, it creates a plan (called a query plan) of what needs to be done. This plan is optimized and only executed when you specifically request the results, which can lead to more efficient computations. - -3. **Multimodal:** Unlike traditional tables that usually contain simple data types like numbers and text, Daft DataFrames can handle complex data types in its columns. This includes things like images, audio files, or even custom Python objects. - -For a full comparison between Daft and other DataFrame Libraries, see [DataFrame Comparison](../resources/dataframe_comparison.md). - -Common data operations that you would perform on DataFrames are: - -1. [**Filtering rows:**](dataframe.md/#selecting-rows) Use [`df.where(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where) to keep only the rows that meet certain conditions. -2. **Creating new columns:** Use [`df.with_column(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.with_column.html#daft.DataFrame.with_column) to add a new column based on calculations from existing ones. -3. 
[**Joining DataFrames:**](dataframe.md/#combining-dataframes) Use [`df.join(other_df, ...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.join.html#daft.DataFrame.join) to combine two DataFrames based on common columns. -4. [**Sorting:**](dataframe.md#reordering-rows) Use [`df.sort(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.sort.html#daft.DataFrame.sort) to arrange your data based on values in one or more columns. -5. **Grouping and aggregating:** Use [`df.groupby(...).agg(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.groupby.html#daft.DataFrame.groupby) to summarize your data by groups. - -## Creating a DataFrame - -!!! tip "See Also" - - [Reading/Writing Data](read_write.md) - a more in-depth guide on various options for reading/writing data to/from Daft DataFrames from in-memory data (Python, Arrow), files (Parquet, CSV, JSON), SQL Databases and Data Catalogs - -Let's create our first Dataframe from a Python dictionary of columns. - -=== "🐍 Python" - - ```python - import daft - - df = daft.from_pydict({ - "A": [1, 2, 3, 4], - "B": [1.5, 2.5, 3.5, 4.5], - "C": [True, True, False, False], - "D": [None, None, None, None], - }) - ``` - -Examine your Dataframe by printing it: - -``` -df -``` - -``` {title="Output"} - -╭───────┬─────────┬─────────┬──────╮ -│ A ┆ B ┆ C ┆ D │ -│ --- ┆ --- ┆ --- ┆ --- │ -│ Int64 ┆ Float64 ┆ Boolean ┆ Null │ -╞═══════╪═════════╪═════════╪══════╡ -│ 1 ┆ 1.5 ┆ true ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 2 ┆ 2.5 ┆ true ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 3 ┆ 3.5 ┆ false ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 4 ┆ 4.5 ┆ false ┆ None │ -╰───────┴─────────┴─────────┴──────╯ - -(Showing first 4 of 4 rows) -``` - -Congratulations - you just created your first DataFrame! It has 4 columns, "A", "B", "C", and "D". Let's try to select only the "A", "B", and "C" columns: - -=== "🐍 Python" - ``` python - df = df.select("A", "B", "C") - df - ``` - -=== "⚙️ SQL" - ```python - df = daft.sql("SELECT A, B, C FROM df") - df - ``` - -``` {title="Output"} - -╭───────┬─────────┬─────────╮ -│ A ┆ B ┆ C │ -│ --- ┆ --- ┆ --- │ -│ Int64 ┆ Float64 ┆ Boolean │ -╰───────┴─────────┴─────────╯ - -(No data to display: Dataframe not materialized) -``` - -But wait - why is it printing the message `(No data to display: Dataframe not materialized)` and where are the rows of each column? - -## Executing DataFrame and Viewing Data - -The reason that our DataFrame currently does not display its rows is that Daft DataFrames are **lazy**. This just means that Daft DataFrames will defer all its work until you tell it to execute. - -In this case, Daft is just deferring the work required to read the data and select columns, however in practice this laziness can be very useful for helping Daft optimize your queries before execution! - -!!! info "Info" - - When you call methods on a Daft Dataframe, it defers the work by adding to an internal "plan". You can examine the current plan of a DataFrame by calling [`df.explain()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.explain.html#daft.DataFrame.explain)! - - Passing the `show_all=True` argument will show you the plan after Daft applies its query optimizations and the physical (lower-level) plan. 
- - ``` - Plan Output - - == Unoptimized Logical Plan == - - * Project: col(A), col(B), col(C) - | - * Source: - | Number of partitions = 1 - | Output schema = A#Int64, B#Float64, C#Boolean, D#Null - - - == Optimized Logical Plan == - - * Project: col(A), col(B), col(C) - | - * Source: - | Number of partitions = 1 - | Output schema = A#Int64, B#Float64, C#Boolean, D#Null - - - == Physical Plan == - - * Project: col(A), col(B), col(C) - | Clustering spec = { Num partitions = 1 } - | - * InMemoryScan: - | Schema = A#Int64, B#Float64, C#Boolean, D#Null, - | Size bytes = 65, - | Clustering spec = { Num partitions = 1 } - ``` - -We can tell Daft to execute our DataFrame and store the results in-memory using [`df.collect()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.collect.html#daft.DataFrame.collect): - -=== "🐍 Python" - ``` python - df.collect() - df - ``` - -``` {title="Output"} -╭───────┬─────────┬─────────┬──────╮ -│ A ┆ B ┆ C ┆ D │ -│ --- ┆ --- ┆ --- ┆ --- │ -│ Int64 ┆ Float64 ┆ Boolean ┆ Null │ -╞═══════╪═════════╪═════════╪══════╡ -│ 1 ┆ 1.5 ┆ true ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 2 ┆ 2.5 ┆ true ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 3 ┆ 3.5 ┆ false ┆ None │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ -│ 4 ┆ 4.5 ┆ false ┆ None │ -╰───────┴─────────┴─────────┴──────╯ - -(Showing first 4 of 4 rows) -``` - -Now your DataFrame object `df` is **materialized** - Daft has executed all the steps required to compute the results, and has cached the results in memory so that it can display this preview. - -Any subsequent operations on `df` will avoid recomputations, and just use this materialized result! - -### When should I materialize my DataFrame? - -If you "eagerly" call [`df.collect()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.collect.html#daft.DataFrame.collect) immediately on every DataFrame, you may run into issues: - -1. If data is too large at any step, materializing all of it may cause memory issues -2. Optimizations are not possible since we cannot "predict future operations" - -However, data science is all about experimentation and trying different things on the same data. This means that materialization is crucial when working interactively with DataFrames, since it speeds up all subsequent experimentation on that DataFrame. - -We suggest materializing DataFrames using [`df.collect()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.collect.html#daft.DataFrame.collect) when they contain expensive operations (e.g. 
sorts or expensive function calls) and have to be called multiple times by downstream code: - -=== "🐍 Python" - ``` python - df = df.sort("A") # expensive sort - df.collect() # materialize the DataFrame - - # All subsequent work on df avoids recomputing previous steps - df.sum("B").show() - df.mean("B").show() - df.with_column("try_this", df["A"] + 1).show(5) - ``` - -=== "⚙️ SQL" - ```python - df = daft.sql("SELECT * FROM df ORDER BY A") - df.collect() - - # All subsequent work on df avoids recomputing previous steps - daft.sql("SELECT sum(B) FROM df").show() - daft.sql("SELECT mean(B) FROM df").show() - daft.sql("SELECT *, (A + 1) AS try_this FROM df").show(5) - ``` - -``` {title="Output"} - -╭─────────╮ -│ B │ -│ --- │ -│ Float64 │ -╞═════════╡ -│ 12 │ -╰─────────╯ - -(Showing first 1 of 1 rows) - -╭─────────╮ -│ B │ -│ --- │ -│ Float64 │ -╞═════════╡ -│ 3 │ -╰─────────╯ - -(Showing first 1 of 1 rows) - -╭───────┬─────────┬─────────┬──────────╮ -│ A ┆ B ┆ C ┆ try_this │ -│ --- ┆ --- ┆ --- ┆ --- │ -│ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ -╞═══════╪═════════╪═════════╪══════════╡ -│ 1 ┆ 1.5 ┆ true ┆ 2 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ -│ 2 ┆ 2.5 ┆ true ┆ 3 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ -│ 3 ┆ 3.5 ┆ false ┆ 4 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ -│ 4 ┆ 4.5 ┆ false ┆ 5 │ -╰───────┴─────────┴─────────┴──────────╯ - -(Showing first 4 of 4 rows) -``` - -In many other cases however, there are better options than materializing your entire DataFrame with [`df.collect()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.collect.html#daft.DataFrame.collect): - -1. **Peeking with df.show(N)**: If you only want to "peek" at the first few rows of your data for visualization purposes, you can use [`df.show(N)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.show.html#daft.DataFrame.show), which processes and shows only the first `N` rows. -2. **Writing to disk**: The `df.write_*` methods will process and write your data to disk per-partition, avoiding materializing it all in memory at once. -3. **Pruning data**: You can materialize your DataFrame after performing a [`df.limit()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.limit.html#daft.DataFrame.limit), [`df.where()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where) or [`df.select()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.select.html#daft.DataFrame.select) operation which processes your data or prune it down to a smaller size. - -## Schemas and Types - -Notice also that when we printed our DataFrame, Daft displayed its **schema**. Each column of your DataFrame has a **name** and a **type**, and all data in that column will adhere to that type! - -Daft can display your DataFrame's schema without materializing it. Under the hood, it performs intelligent sampling of your data to determine the appropriate schema, and if you make any modifications to your DataFrame it can infer the resulting types based on the operation. - -!!! note "Note" - - Under the hood, Daft represents data in the [Apache Arrow](https://arrow.apache.org/) format, which allows it to efficiently represent and work on data using high-performance kernels which are written in Rust. 
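As a quick sketch of this (a self-contained example; `df.schema()` and `df.column_names` are assumed here to be the schema accessors), you can inspect a DataFrame's column names and types without materializing any rows:

=== "🐍 Python"
    ``` python
    import daft

    df = daft.from_pydict({"A": [1, 2, 3], "B": [1.5, 2.5, 3.5]})

    # Neither call triggers execution; the schema is known up-front
    print(df.schema())      # assumed accessor: column names and types
    print(df.column_names)  # assumed accessor: ["A", "B"]
    ```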
- -## Running Computation with Expressions - -To run computations on data in our DataFrame, we use Expressions. - -The following statement will [`df.show()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.show.html#daft.DataFrame.show) a DataFrame that has only one column - the column `A` from our original DataFrame but with every row incremented by 1. - -=== "🐍 Python" - ``` python - df.select(df["A"] + 1).show() - ``` - -=== "⚙️ SQL" - ```python - daft.sql("SELECT A + 1 FROM df").show() - ``` - -``` {title="Output"} - -╭───────╮ -│ A │ -│ --- │ -│ Int64 │ -╞═══════╡ -│ 2 │ -├╌╌╌╌╌╌╌┤ -│ 3 │ -├╌╌╌╌╌╌╌┤ -│ 4 │ -├╌╌╌╌╌╌╌┤ -│ 5 │ -╰───────╯ - -(Showing first 4 of 4 rows) -``` - -!!! info "Info" - - A common pattern is to create a new columns using [`DataFrame.with_column`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.with_column.html): - - === "🐍 Python" - ``` python - # Creates a new column named "foo" which takes on values - # of column "A" incremented by 1 - df = df.with_column("foo", df["A"] + 1) - df.show() - ``` - - === "⚙️ SQL" - ```python - # Creates a new column named "foo" which takes on values - # of column "A" incremented by 1 - df = daft.sql("SELECT *, A + 1 AS foo FROM df") - df.show() - ``` - -``` {title="Output"} - -╭───────┬─────────┬─────────┬───────╮ -│ A ┆ B ┆ C ┆ foo │ -│ --- ┆ --- ┆ --- ┆ --- │ -│ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ -╞═══════╪═════════╪═════════╪═══════╡ -│ 1 ┆ 1.5 ┆ true ┆ 2 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 2 ┆ 2.5 ┆ true ┆ 3 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 3 ┆ 3.5 ┆ false ┆ 4 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 4 ┆ 4.5 ┆ false ┆ 5 │ -╰───────┴─────────┴─────────┴───────╯ - -(Showing first 4 of 4 rows) -``` - -Congratulations, you have just written your first **Expression**: `df["A"] + 1`! Expressions are a powerful way of describing computation on columns. For more details, check out the next section on [Expressions](expressions.md). - - - -## Selecting Rows - -We can limit the rows to the first ``N`` rows using [`df.limit(N)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.limit.html#daft.DataFrame.limit): - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "A": [1, 2, 3, 4, 5], - "B": [6, 7, 8, 9, 10], - }) - - df.limit(3).show() - ``` - -``` {title="Output"} - -+---------+---------+ -| A | B | -| Int64 | Int64 | -+=========+=========+ -| 1 | 6 | -+---------+---------+ -| 2 | 7 | -+---------+---------+ -| 3 | 8 | -+---------+---------+ -(Showing first 3 rows) -``` - -We can also filter rows using [`df.where()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where), which takes an input a Logical Expression predicate: - -=== "🐍 Python" - ``` python - df.where(df["A"] > 3).show() - ``` - -``` {title="Output"} - -+---------+---------+ -| A | B | -| Int64 | Int64 | -+=========+=========+ -| 4 | 9 | -+---------+---------+ -| 5 | 10 | -+---------+---------+ -(Showing first 2 rows) -``` - -## Selecting Columns - -Select specific columns in a DataFrame using [`df.select()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.select.html#daft.DataFrame.select), which also takes Expressions as an input. 
- -=== "🐍 Python" - ``` python - import daft - - df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - - df.select("A").show() - ``` - -``` {title="Output"} - -+---------+ -| A | -| Int64 | -+=========+ -| 1 | -+---------+ -| 2 | -+---------+ -| 3 | -+---------+ -(Showing first 3 rows) -``` - -A useful alias for [`df.select()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.select.html#daft.DataFrame.select) is indexing a DataFrame with a list of column names or Expressions: - -=== "🐍 Python" - ``` python - df[["A", "B"]].show() - ``` - -``` {title="Output"} - -+---------+---------+ -| A | B | -| Int64 | Int64 | -+=========+=========+ -| 1 | 4 | -+---------+---------+ -| 2 | 5 | -+---------+---------+ -| 3 | 6 | -+---------+---------+ -(Showing first 3 rows) -``` - -Sometimes, it may be useful to exclude certain columns from a DataFrame. This can be done with [`df.exclude()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.exclude.html#daft.DataFrame.exclude): - -=== "🐍 Python" - ``` python - df.exclude("A").show() - ``` - -```{title="Output"} - -+---------+ -| B | -| Int64 | -+=========+ -| 4 | -+---------+ -| 5 | -+---------+ -| 6 | -+---------+ -(Showing first 3 rows) -``` - -Adding a new column can be achieved with [`df.with_column()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.with_column.html#daft.DataFrame.with_column): - -=== "🐍 Python" - ``` python - df.with_column("C", df["A"] + df["B"]).show() - ``` - -``` {title="Output"} - -+---------+---------+---------+ -| A | B | C | -| Int64 | Int64 | Int64 | -+=========+=========+=========+ -| 1 | 4 | 5 | -+---------+---------+---------+ -| 2 | 5 | 7 | -+---------+---------+---------+ -| 3 | 6 | 9 | -+---------+---------+---------+ -(Showing first 3 rows) -``` - -### Selecting Columns Using Wildcards - -We can select multiple columns at once using wildcards. The expression [`.col(*)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.col.html#daft.col) selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - df.select(col("*") * 3).show() - ``` - -``` {title="Output"} -╭───────┬───────╮ -│ A ┆ B │ -│ --- ┆ --- │ -│ Int64 ┆ Int64 │ -╞═══════╪═══════╡ -│ 3 ┆ 12 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 6 ┆ 15 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 9 ┆ 18 │ -╰───────┴───────╯ -``` - -We can also select multiple columns within structs using `col("struct.*")`: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "A": [ - {"B": 1, "C": 2}, - {"B": 3, "C": 4} - ] - }) - df.select(col("A.*")).show() - ``` - -``` {title="Output"} - -╭───────┬───────╮ -│ B ┆ C │ -│ --- ┆ --- │ -│ Int64 ┆ Int64 │ -╞═══════╪═══════╡ -│ 1 ┆ 2 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 3 ┆ 4 │ -╰───────┴───────╯ -``` - -Under the hood, wildcards work by finding all of the columns that match, then copying the expression several times and replacing the wildcard. This means that there are some caveats: - -* Only one wildcard is allowed per expression tree. This means that `col("*") + col("*")` and similar expressions do not work. -* Be conscious about duplicated column names. Any code like `df.select(col("*"), col("*") + 3)` will not work because the wildcards expand into the same column names. 
- - For the same reason, `col("A") + col("*")` will not work because the name on the left-hand side is inherited, meaning all the output columns are named `A`, causing an error if there is more than one. - However, `col("*") + col("A")` will work fine. - -## Combining DataFrames - -Two DataFrames can be column-wise joined using [`df.join()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.join.html#daft.DataFrame.join). - -This requires a "join key", which can be supplied as the `on` argument if both DataFrames have the same name for their key columns, or the `left_on` and `right_on` argument if the key column has different names in each DataFrame. - -Daft also supports multi-column joins if you have a join key comprising of multiple columns! - -=== "🐍 Python" - ``` python - df1 = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - df2 = daft.from_pydict({"A": [1, 2, 3], "C": [7, 8, 9]}) - - df1.join(df2, on="A").show() - ``` - -``` {title="Output"} - -+---------+---------+---------+ -| A | B | C | -| Int64 | Int64 | Int64 | -+=========+=========+=========+ -| 1 | 4 | 7 | -+---------+---------+---------+ -| 2 | 5 | 8 | -+---------+---------+---------+ -| 3 | 6 | 9 | -+---------+---------+---------+ -(Showing first 3 rows) -``` - -## Reordering Rows - -Rows in a DataFrame can be reordered based on some column using [`df.sort()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.sort.html#daft.DataFrame.sort). Daft also supports multi-column sorts for sorting on multiple columns at once. - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "A": [1, 2, 3], - "B": [6, 7, 8], - }) - - df.sort("A", desc=True).show() - ``` - -```{title="Output"} - -+---------+---------+ -| A | B | -| Int64 | Int64 | -+=========+=========+ -| 3 | 8 | -+---------+---------+ -| 2 | 7 | -+---------+---------+ -| 1 | 6 | -+---------+---------+ -(Showing first 3 rows) -``` - -## Exploding Columns - -The [`df.explode()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.explode.html#daft.DataFrame.explode) method can be used to explode a column containing a list of values into multiple rows. All other rows will be **duplicated**. - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "A": [1, 2, 3], - "B": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - }) - - df.explode("B").show() - ``` - -``` {title="Output"} - -+---------+---------+ -| A | B | -| Int64 | Int64 | -+=========+=========+ -| 1 | 1 | -+---------+---------+ -| 1 | 2 | -+---------+---------+ -| 1 | 3 | -+---------+---------+ -| 2 | 4 | -+---------+---------+ -| 2 | 5 | -+---------+---------+ -| 2 | 6 | -+---------+---------+ -| 3 | 7 | -+---------+---------+ -| 3 | 8 | -+---------+---------+ -(Showing first 8 rows) -``` - - - diff --git a/docs-v2/core_concepts/datatypes.md b/docs-v2/core_concepts/datatypes.md deleted file mode 100644 index f932623806..0000000000 --- a/docs-v2/core_concepts/datatypes.md +++ /dev/null @@ -1,96 +0,0 @@ -# DataTypes - -All columns in a Daft DataFrame have a DataType (also often abbreviated as `dtype`). - -All elements of a column are of the same dtype, or they can be the special Null value (indicating a missing value). - -Daft provides simple DataTypes that are ubiquituous in many DataFrames such as numbers, strings and dates - all the way up to more complex types like tensors and images. - -!!! 
tip "Tip" - - For a full overview on all the DataTypes that Daft supports, see the [DataType API Reference](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html). - - -## Numeric DataTypes - -Numeric DataTypes allows Daft to represent numbers. These numbers can differ in terms of the number of bits used to represent them (8, 16, 32 or 64 bits) and the semantic meaning of those bits -(float vs integer vs unsigned integers). - -Examples: - -1. [`DataType.int8()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.int8): represents an 8-bit signed integer (-128 to 127) -2. [`DataType.float32()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.float32): represents a 32-bit float (a float number with about 7 decimal digits of precision) - -Columns/expressions with these datatypes can be operated on with many numeric expressions such as `+` and `*`. - -See also: [Numeric Expressions](https://www.getdaft.io/projects/docs/en/stable/user_guide/expressions.html#userguide-numeric-expressions) - -## Logical DataTypes - -The [`Boolean`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.bool) DataType represents values which are boolean values: `True`, `False` or `Null`. - -Columns/expressions with this dtype can be operated on using logical expressions such as ``&`` and [`.if_else()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.if_else.html#daft.Expression.if_else). - -See also: [Logical Expressions](https://www.getdaft.io/projects/docs/en/stable/user_guide/expressions.html#userguide-logical-expressions) - -## String Types - -Daft has string types, which represent a variable-length string of characters. - -As a convenience method, string types also support the `+` Expression, which has been overloaded to support concatenation of elements between two [`DataType.string()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.string) columns. - -1. [`DataType.string()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.string): represents a string of UTF-8 characters -2. [`DataType.binary()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.binary): represents a string of bytes - -See also: [String Expressions](https://www.getdaft.io/projects/docs/en/stable/user_guide/expressions.html#userguide-string-expressions) - -## Temporal - -Temporal dtypes represent data that have to do with time. - -Examples: - -1. [`DataType.date()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.date): represents a Date (year, month and day) -2. [`DataType.timestamp()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.timestamp): represents a Timestamp (particular instance in time) - -See also: [Temporal Expressions](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#api-expressions-temporal) - -## Nested - -Nested DataTypes wrap other DataTypes, allowing you to compose types into complex data structures. - -Examples: - -1. [`DataType.list(child_dtype)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.list): represents a list where each element is of the child `dtype` -2. 
[`DataType.struct({"field_name": child_dtype})`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.struct): represents a structure that has children `dtype`s, each mapped to a field name - -## Python - -The [`DataType.python()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.python) dtype represent items that are Python objects. - -!!! warning "Warning" - - Daft does not impose any invariants about what *Python types* these objects are. To Daft, these are just generic Python objects! - -Python is AWESOME because it's so flexible, but it's also slow and memory inefficient! Thus we recommend: - -1. **Cast early!**: Casting your Python data into native Daft DataTypes if possible - this results in much more efficient downstream data serialization and computation. -2. **Use Python UDFs**: If there is no suitable Daft representation for your Python objects, use Python UDFs to process your Python data and extract the relevant data to be returned as native Daft DataTypes! - -!!! note "Note" - - If you work with Python classes for a generalizable use-case (e.g. documents, protobufs), it may be that these types are good candidates for "promotion" into a native Daft type! Please get in touch with the Daft team and we would love to work together on building your type into canonical Daft types. - -## Complex Types - -Daft supports many more interesting complex DataTypes, for example: - -* [`DataType.tensor()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.tensor): Multi-dimensional (potentially uniformly-shaped) tensors of data -* [`DataType.embedding()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.embedding): Lower-dimensional vector representation of data (e.g. words) -* [`DataType.image()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.image): NHWC images - -Daft abstracts away the in-memory representation of your data and provides kernels for many common operations on top of these data types. For supported image operations see the [image expressions API reference](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#api-expressions-images). - -For more complex algorithms, you can also drop into a Python UDF to process this data using your custom Python libraries. - -Please add suggestions for new DataTypes to our [Github Discussions page](https://github.com/Eventual-Inc/Daft/discussions)! diff --git a/docs-v2/core_concepts/expressions.md b/docs-v2/core_concepts/expressions.md deleted file mode 100644 index 81ba19bfc6..0000000000 --- a/docs-v2/core_concepts/expressions.md +++ /dev/null @@ -1,744 +0,0 @@ -# Expressions - -Expressions are how you can express computations that should be run over columns of data. - -## Creating Expressions - -### Referring to a column in a DataFrame - -Most commonly you will be creating expressions by using the [`daft.col`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.col.html#daft.col) function. - -=== "🐍 Python" - ``` python - # Refers to column "A" - daft.col("A") - ``` - -=== "⚙️ SQL" - ```python - daft.sql_expr("A") - ``` - -``` {title="Output"} - -col(A) -``` - -The above code creates an Expression that refers to a column named `"A"`. - -### Using SQL - -Daft can also parse valid SQL as expressions. 
- -=== "⚙️ SQL" - ```python - daft.sql_expr("A + 1") - ``` -``` {title="Output"} - -col(A) + lit(1) -``` - -The above code will create an expression representing "the column named 'x' incremented by 1". For many APIs, [`sql_expr`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql_expr) will actually be applied for you as syntactic sugar! - -### Literals - -You may find yourself needing to hardcode a "single value" oftentimes as an expression. Daft provides a [`lit()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.lit.html) helper to do so: - -=== "🐍 Python" - ``` python - from daft import lit - - # Refers to an expression which always evaluates to 42 - lit(42) - ``` - -=== "⚙️ SQL" - ```python - # Refers to an expression which always evaluates to 42 - daft.sql_expr("42") - ``` - -```{title="Output"} - -lit(42) -``` -This special :func:`~daft.expressions.lit` expression we just created evaluates always to the value ``42``. - -### Wildcard Expressions - -You can create expressions on multiple columns at once using a wildcard. The expression [`col("*")`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.col.html#daft.col)) selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: - -=== "🐍 Python" - ``` python - import daft - from daft import col - - df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - df.select(col("*") * 3).show() - ``` - -``` {title="Output"} - -╭───────┬───────╮ -│ A ┆ B │ -│ --- ┆ --- │ -│ Int64 ┆ Int64 │ -╞═══════╪═══════╡ -│ 3 ┆ 12 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 6 ┆ 15 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 9 ┆ 18 │ -╰───────┴───────╯ -``` - -Wildcards also work very well for accessing all members of a struct column: - -=== "🐍 Python" - ``` python - - import daft - from daft import col - - df = daft.from_pydict({ - "person": [ - {"name": "Alice", "age": 30}, - {"name": "Bob", "age": 25}, - {"name": "Charlie", "age": 35} - ] - }) - - # Access all fields of the 'person' struct - df.select(col("person.*")).show() - ``` - -=== "⚙️ SQL" - ```python - import daft - - df = daft.from_pydict({ - "person": [ - {"name": "Alice", "age": 30}, - {"name": "Bob", "age": 25}, - {"name": "Charlie", "age": 35} - ] - }) - - # Access all fields of the 'person' struct using SQL - daft.sql("SELECT person.* FROM df").show() - ``` - -``` {title="Output"} - -╭──────────┬───────╮ -│ name ┆ age │ -│ --- ┆ --- │ -│ String ┆ Int64 │ -╞══════════╪═══════╡ -│ Alice ┆ 30 │ -├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ Bob ┆ 25 │ -├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ Charlie ┆ 35 │ -╰──────────┴───────╯ -``` - -In this example, we use the wildcard `*` to access all fields of the `person` struct column. This is equivalent to selecting each field individually (`person.name`, `person.age`), but is more concise and flexible, especially when dealing with structs that have many fields. - - - -## Composing Expressions - -### Numeric Expressions - -Since column "A" is an integer, we can run numeric computation such as addition, division and checking its value. Here are some examples where we create new columns using the results of such computations: - -=== "🐍 Python" - ``` python - # Add 1 to each element in column "A" - df = df.with_column("A_add_one", df["A"] + 1) - - # Divide each element in column A by 2 - df = df.with_column("A_divide_two", df["A"] / 2.) 
- - # Check if each element in column A is more than 1 - df = df.with_column("A_gt_1", df["A"] > 1) - - df.collect() - ``` - -=== "⚙️ SQL" - ```python - df = daft.sql(""" - SELECT - *, - A + 1 AS A_add_one, - A / 2.0 AS A_divide_two, - A > 1 AS A_gt_1 - FROM df - """) - df.collect() - ``` - -```{title="Output"} - -+---------+-------------+----------------+-----------+ -| A | A_add_one | A_divide_two | A_gt_1 | -| Int64 | Int64 | Float64 | Boolean | -+=========+=============+================+===========+ -| 1 | 2 | 0.5 | false | -+---------+-------------+----------------+-----------+ -| 2 | 3 | 1 | true | -+---------+-------------+----------------+-----------+ -| 3 | 4 | 1.5 | true | -+---------+-------------+----------------+-----------+ -(Showing first 3 of 3 rows) -``` - -Notice that the returned types of these operations are also well-typed according to their input types. For example, calling ``df["A"] > 1`` returns a column of type [`Boolean`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.bool). - -Both the [`Float`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.float32) and [`Int`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType.int16) types are numeric types, and inherit many of the same arithmetic Expression operations. You may find the full list of numeric operations in the [Expressions API Reference](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#numeric). - -### String Expressions - -Daft also lets you have columns of strings in a DataFrame. Let's take a look! - -=== "🐍 Python" - ``` python - df = daft.from_pydict({"B": ["foo", "bar", "baz"]}) - df.show() - ``` - -``` {title="Output"} - -+--------+ -| B | -| Utf8 | -+========+ -| foo | -+--------+ -| bar | -+--------+ -| baz | -+--------+ -(Showing first 3 rows) -``` - -Unlike the numeric types, the string type does not support arithmetic operations such as `*` and `/`. The one exception to this is the `+` operator, which is overridden to concatenate two string expressions as is commonly done in Python. Let's try that! - -=== "🐍 Python" - ``` python - df = df.with_column("B2", df["B"] + "foo") - df.show() - ``` - -=== "⚙️ SQL" - ```python - df = daft.sql("SELECT *, B + 'foo' AS B2 FROM df") - df.show() - ``` - -``` {title="Output"} - -+--------+--------+ -| B | B2 | -| Utf8 | Utf8 | -+========+========+ -| foo | foofoo | -+--------+--------+ -| bar | barfoo | -+--------+--------+ -| baz | bazfoo | -+--------+--------+ -(Showing first 3 rows) -``` - -There are also many string operators that are accessed through a separate [`.str.*`](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#strings) "method namespace". 
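As a minimal sketch (reusing the `df` from above, and assuming `.str.upper()` is one of the methods exposed in that namespace):

=== "🐍 Python"
    ``` python
    # Uppercase every string in column "B" without replacing the DataFrame
    df.with_column("B_upper", df["B"].str.upper()).show()
    ```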
- -For example, to check if each element in column "B" contains the substring "a", we can use the [`.str.contains`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.str.contains.html#daft.Expression.str.contains) method: - -=== "🐍 Python" - ``` python - df = df.with_column("B2_contains_B", df["B2"].str.contains(df["B"])) - df.show() - ``` - -=== "⚙️ SQL" - ```python - df = daft.sql("SELECT *, contains(B2, B) AS B2_contains_B FROM df") - df.show() - ``` - -``` {title="Output"} - -+--------+--------+-----------------+ -| B | B2 | B2_contains_B | -| Utf8 | Utf8 | Boolean | -+========+========+=================+ -| foo | foofoo | true | -+--------+--------+-----------------+ -| bar | barfoo | true | -+--------+--------+-----------------+ -| baz | bazfoo | true | -+--------+--------+-----------------+ -(Showing first 3 rows) -``` - -You may find a full list of string operations in the [Expressions API Reference](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html). - -### URL Expressions - -One special case of a String column you may find yourself working with is a column of URL strings. - -Daft provides the [`.url.*`](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html) method namespace with functionality for working with URL strings. For example, to download data from URLs: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "urls": [ - "https://www.google.com", - "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", - ], - }) - df = df.with_column("data", df["urls"].url.download()) - df.collect() - ``` - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({ - "urls": [ - "https://www.google.com", - "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", - ], - }) - df = daft.sql(""" - SELECT - urls, - url_download(urls) AS data - FROM df - """) - df.collect() - ``` - -``` {title="Output"} - -+----------------------+----------------------+ -| urls | data | -| Utf8 | Binary | -+======================+======================+ -| https://www.google.c | b' df["B"]).if_else(df["A"], df["B"]), - ) - - df.collect() - ``` - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) - - df = daft.sql(""" - SELECT - A, - B, - CASE - WHEN A > B THEN A - ELSE B - END AS A_if_bigger_else_B - FROM df - """) - - df.collect() - ``` - -```{title="Output"} - -╭───────┬───────┬────────────────────╮ -│ A ┆ B ┆ A_if_bigger_else_B │ -│ --- ┆ --- ┆ --- │ -│ Int64 ┆ Int64 ┆ Int64 │ -╞═══════╪═══════╪════════════════════╡ -│ 1 ┆ 0 ┆ 1 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2 ┆ 2 ┆ 2 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 3 ┆ 4 ┆ 4 │ -╰───────┴───────┴────────────────────╯ - -(Showing first 3 of 3 rows) -``` - -This is a useful expression for cleaning your data! - - -### Temporal Expressions - -Daft provides rich support for working with temporal data types like Timestamp and Duration. 
Let's explore some common temporal operations: - -#### Basic Temporal Operations - -You can perform arithmetic operations with timestamps and durations, such as adding a duration to a timestamp or calculating the duration between two timestamps: - -=== "🐍 Python" - ``` python - import datetime - - df = daft.from_pydict({ - "timestamp": [ - datetime.datetime(2021, 1, 1, 0, 1, 1), - datetime.datetime(2021, 1, 1, 0, 1, 59), - datetime.datetime(2021, 1, 1, 0, 2, 0), - ] - }) - - # Add 10 seconds to each timestamp - df = df.with_column( - "plus_10_seconds", - df["timestamp"] + datetime.timedelta(seconds=10) - ) - - df.show() - ``` - -=== "⚙️ SQL" - ```python - import datetime - - df = daft.from_pydict({ - "timestamp": [ - datetime.datetime(2021, 1, 1, 0, 1, 1), - datetime.datetime(2021, 1, 1, 0, 1, 59), - datetime.datetime(2021, 1, 1, 0, 2, 0), - ] - }) - - # Add 10 seconds to each timestamp and calculate duration between timestamps - df = daft.sql(""" - SELECT - timestamp, - timestamp + INTERVAL '10 seconds' as plus_10_seconds, - FROM df - """) - - df.show() - ``` - -``` {title="Output"} - -╭───────────────────────────────┬───────────────────────────────╮ -│ timestamp ┆ plus_10_seconds │ -│ --- ┆ --- │ -│ Timestamp(Microseconds, None) ┆ Timestamp(Microseconds, None) │ -╞═══════════════════════════════╪═══════════════════════════════╡ -│ 2021-01-01 00:01:01 ┆ 2021-01-01 00:01:11 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-01 00:01:59 ┆ 2021-01-01 00:02:09 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-01 00:02:00 ┆ 2021-01-01 00:02:10 │ -╰───────────────────────────────┴───────────────────────────────╯ -``` - -#### Temporal Component Extraction - -The [`.dt.*`](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html#temporal) method namespace provides extraction methods for the components of a timestamp, such as year, month, day, hour, minute, and second: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "timestamp": [ - datetime.datetime(2021, 1, 1, 0, 1, 1), - datetime.datetime(2021, 1, 1, 0, 1, 59), - datetime.datetime(2021, 1, 1, 0, 2, 0), - ] - }) - - # Extract year, month, day, hour, minute, and second from the timestamp - df = df.with_columns({ - "year": df["timestamp"].dt.year(), - "month": df["timestamp"].dt.month(), - "day": df["timestamp"].dt.day(), - "hour": df["timestamp"].dt.hour(), - "minute": df["timestamp"].dt.minute(), - "second": df["timestamp"].dt.second() - }) - - df.show() - ``` - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({ - "timestamp": [ - datetime.datetime(2021, 1, 1, 0, 1, 1), - datetime.datetime(2021, 1, 1, 0, 1, 59), - datetime.datetime(2021, 1, 1, 0, 2, 0), - ] - }) - - # Extract year, month, day, hour, minute, and second from the timestamp - df = daft.sql(""" - SELECT - timestamp, - year(timestamp) as year, - month(timestamp) as month, - day(timestamp) as day, - hour(timestamp) as hour, - minute(timestamp) as minute, - second(timestamp) as second - FROM df - """) - - df.show() - ``` - -``` {title="Output"} - -╭───────────────────────────────┬───────┬────────┬────────┬────────┬────────┬────────╮ -│ timestamp ┆ year ┆ month ┆ day ┆ hour ┆ minute ┆ second │ -│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ -│ Timestamp(Microseconds, None) ┆ Int32 ┆ UInt32 ┆ UInt32 ┆ UInt32 ┆ UInt32 ┆ UInt32 │ -╞═══════════════════════════════╪═══════╪════════╪════════╪════════╪════════╪════════╡ -│ 2021-01-01 00:01:01 ┆ 2021 ┆ 1 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ 
-├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤ -│ 2021-01-01 00:01:59 ┆ 2021 ┆ 1 ┆ 1 ┆ 0 ┆ 1 ┆ 59 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤ -│ 2021-01-01 00:02:00 ┆ 2021 ┆ 1 ┆ 1 ┆ 0 ┆ 2 ┆ 0 │ -╰───────────────────────────────┴───────┴────────┴────────┴────────┴────────┴────────╯ -``` - -#### Time Zone Operations - -You can parse strings as timestamps with time zones and convert between different time zones: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "timestamp_str": [ - "2021-01-01 00:00:00.123 +0800", - "2021-01-02 12:30:00.456 +0800" - ] - }) - - # Parse the timestamp string with time zone and convert to New York time - df = df.with_column( - "ny_time", - df["timestamp_str"].str.to_datetime( - "%Y-%m-%d %H:%M:%S%.3f %z", - timezone="America/New_York" - ) - ) - - df.show() - ``` - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({ - "timestamp_str": [ - "2021-01-01 00:00:00.123 +0800", - "2021-01-02 12:30:00.456 +0800" - ] - }) - - # Parse the timestamp string with time zone and convert to New York time - df = daft.sql(""" - SELECT - timestamp_str, - to_datetime(timestamp_str, '%Y-%m-%d %H:%M:%S%.3f %z', 'America/New_York') as ny_time - FROM df - """) - - df.show() - ``` - -``` {title="Output"} - -╭───────────────────────────────┬───────────────────────────────────────────────────╮ -│ timestamp_str ┆ ny_time │ -│ --- ┆ --- │ -│ Utf8 ┆ Timestamp(Milliseconds, Some("America/New_York")) │ -╞═══════════════════════════════╪═══════════════════════════════════════════════════╡ -│ 2021-01-01 00:00:00.123 +0800 ┆ 2020-12-31 11:00:00.123 EST │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-02 12:30:00.456 +0800 ┆ 2021-01-01 23:30:00.456 EST │ -╰───────────────────────────────┴───────────────────────────────────────────────────╯ -``` - -#### Temporal Truncation - -The [`.dt.truncate()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.dt.truncate.html#daft.Expression.dt.truncate) method allows you to truncate timestamps to specific time units. This can be useful for grouping data by time periods. 
For example, to truncate timestamps to the nearest hour: - -=== "🐍 Python" - ``` python - df = daft.from_pydict({ - "timestamp": [ - datetime.datetime(2021, 1, 7, 0, 1, 1), - datetime.datetime(2021, 1, 8, 0, 1, 59), - datetime.datetime(2021, 1, 9, 0, 30, 0), - datetime.datetime(2021, 1, 10, 1, 59, 59), - ] - }) - - # Truncate timestamps to the nearest hour - df = df.with_column( - "hour_start", - df["timestamp"].dt.truncate("1 hour") - ) - - df.show() - ``` - -``` {title="Output"} - -╭───────────────────────────────┬───────────────────────────────╮ -│ timestamp ┆ hour_start │ -│ --- ┆ --- │ -│ Timestamp(Microseconds, None) ┆ Timestamp(Microseconds, None) │ -╞═══════════════════════════════╪═══════════════════════════════╡ -│ 2021-01-07 00:01:01 ┆ 2021-01-07 00:00:00 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-08 00:01:59 ┆ 2021-01-08 00:00:00 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-09 00:30:00 ┆ 2021-01-09 00:00:00 │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ 2021-01-10 01:59:59 ┆ 2021-01-10 01:00:00 │ -╰───────────────────────────────┴───────────────────────────────╯ -``` diff --git a/docs-v2/core_concepts/read_write.md b/docs-v2/core_concepts/read_write.md deleted file mode 100644 index fa3aeef066..0000000000 --- a/docs-v2/core_concepts/read_write.md +++ /dev/null @@ -1,142 +0,0 @@ -# Reading/Writing Data - -!!! failure "todo(docs): Should this page also have sql examples?" - -Daft can read data from a variety of sources, and write data to many destinations. - -## Reading Data - -### From Files - -DataFrames can be loaded from file(s) on some filesystem, commonly your local filesystem or a remote cloud object store such as AWS S3. Additionally, Daft can read data from a variety of container file formats, including CSV, line-delimited JSON and Parquet. - -Daft supports file paths to a single file, a directory of files, and wildcards. It also supports paths to remote object storage such as AWS S3. -=== "🐍 Python" - ```python - import daft - - # You can read a single CSV file from your local filesystem - df = daft.read_csv("path/to/file.csv") - - # You can also read folders of CSV files, or include wildcards to select for patterns of file paths - df = daft.read_csv("path/to/*.csv") - - # Other formats such as parquet and line-delimited JSON are also supported - df = daft.read_parquet("path/to/*.parquet") - df = daft.read_json("path/to/*.json") - - # Remote filesystems such as AWS S3 are also supported, and can be specified with their protocols - df = daft.read_csv("s3://mybucket/path/to/*.csv") - ``` - -To learn more about each of these constructors, as well as the options that they support, consult the API documentation on [`creating DataFrames from files`](https://www.getdaft.io/projects/docs/en/stable/api_docs/creation.html#df-io-files). - -### From Data Catalogs - -If you use catalogs such as Apache Iceberg or Hive, you may wish to consult our user guide on integrations with Data Catalogs: [`Daft integration with Data Catalogs`](https://www.getdaft.io/projects/docs/en/stable/user_guide/integrations.html). - -### From File Paths - -Daft also provides an easy utility to create a DataFrame from globbing a path. You can use the [`daft.from_glob_path`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/io_functions/daft.from_glob_path.html#daft.from_glob_path) method which will read a DataFrame of globbed filepaths. 
- -=== "🐍 Python" - ``` python - df = daft.from_glob_path("s3://mybucket/path/to/images/*.jpeg") - - # +----------+------+-----+ - # | name | size | ... | - # +----------+------+-----+ - # ... - ``` - -This is especially useful for reading things such as a folder of images or documents into Daft. A common pattern is to then download data from these files into your DataFrame as bytes, using the [`.url.download()`](https://getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.url.download.html#daft.Expression.url.download) method. - - -### From Memory - -For testing, or small datasets that fit in memory, you may also create DataFrames using Python lists and dictionaries. - -=== "🐍 Python" - ``` python - # Create DataFrame using a dictionary of {column_name: list_of_values} - df = daft.from_pydict({"A": [1, 2, 3], "B": ["foo", "bar", "baz"]}) - - # Create DataFrame using a list of rows, where each row is a dictionary of {column_name: value} - df = daft.from_pylist([{"A": 1, "B": "foo"}, {"A": 2, "B": "bar"}, {"A": 3, "B": "baz"}]) - ``` - -To learn more, consult the API documentation on [`creating DataFrames from in-memory data structures`](https://www.getdaft.io/projects/docs/en/stable/api_docs/creation.html#df-io-in-memory). - -### From Databases - -Daft can also read data from a variety of databases, including PostgreSQL, MySQL, Trino, and SQLite using the [`daft.read_sql`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/io_functions/daft.read_sql.html#daft.read_sql) method. In order to partition the data, you can specify a partition column, which will allow Daft to read the data in parallel. - -=== "🐍 Python" - ``` python - # Read from a PostgreSQL database - uri = "postgresql://user:password@host:port/database" - df = daft.read_sql("SELECT * FROM my_table", uri) - - # Read with a partition column - df = daft.read_sql("SELECT * FROM my_table", partition_col="date", uri) - ``` - -To learn more, consult the [`SQL User Guide`](https://www.getdaft.io/projects/docs/en/stable/user_guide/integrations/sql.html) or the API documentation on [`daft.read_sql`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/io_functions/daft.read_sql.html#daft.read_sql). - -## Reading a column of URLs - -Daft provides a convenient way to read data from a column of URLs using the [`.url.download()`](https://getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.url.download.html#daft.Expression.url.download) method. This is particularly useful when you have a DataFrame with a column containing URLs pointing to external resources that you want to fetch and incorporate into your DataFrame. - -Here's an example of how to use this feature: - -=== "🐍 Python" - ```python - # Assume we have a DataFrame with a column named 'image_urls' - df = daft.from_pydict({ - "image_urls": [ - "https://example.com/image1.jpg", - "https://example.com/image2.jpg", - "https://example.com/image3.jpg" - ] - }) - - # Download the content from the URLs and create a new column 'image_data' - df = df.with_column("image_data", df["image_urls"].url.download()) - df.show() - ``` - -``` {title="Output"} - -+------------------------------------+------------------------------------+ -| image_urls | image_data | -| Utf8 | Binary | -+====================================+====================================+ -| https://example.com/image1.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' 
| -+------------------------------------+------------------------------------+ -| https://example.com/image2.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | -+------------------------------------+------------------------------------+ -| https://example.com/image3.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | -+------------------------------------+------------------------------------+ - -(Showing first 3 of 3 rows) -``` - -This approach allows you to efficiently download and process data from a large number of URLs in parallel, leveraging Daft's distributed computing capabilities. - -## Writing Data - -Writing data will execute your DataFrame and write the results out to the specified backend. The [`df.write_*(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/dataframe.html#df-write-data) methods are used to write DataFrames to files or other destinations. - -=== "🐍 Python" - ``` python - # Write to various file formats in a local folder - df.write_csv("path/to/folder/") - df.write_parquet("path/to/folder/") - - # Write DataFrame to a remote filesystem such as AWS S3 - df.write_csv("s3://mybucket/path/") - ``` - -!!! note "Note" - - Because Daft is a distributed DataFrame library, by default it will produce multiple files (one per partition) at your specified destination. Writing your dataframe is a **blocking** operation that executes your DataFrame. It will return a new `DataFrame` that contains the filepaths to the written data. diff --git a/docs-v2/core_concepts/sql.md b/docs-v2/core_concepts/sql.md deleted file mode 100644 index 55ebde486e..0000000000 --- a/docs-v2/core_concepts/sql.md +++ /dev/null @@ -1,224 +0,0 @@ -# SQL - -Daft supports Structured Query Language (SQL) as a way of constructing query plans (represented in Python as a [`daft.DataFrame`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame)) and expressions ([`daft.Expression`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame)). - -SQL is a human-readable way of constructing these query plans, and can often be more ergonomic than using DataFrames for writing queries. - -!!! tip "Daft's SQL support is new and is constantly being improved on!" - - Please give us feedback or submit an [issue](https://github.com/Eventual-Inc/Daft/issues) and we'd love to hear more about what you would like. - - -## Running SQL on DataFrames - -Daft's [`daft.sql`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql) function will automatically detect any [`daft.DataFrame`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame) objects in your current Python environment to let you query them easily by name. - -=== "⚙️ SQL" - ```python - # Note the variable name `my_special_df` - my_special_df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) - - # Use the SQL table name "my_special_df" to refer to the above DataFrame! - sql_df = daft.sql("SELECT A, B FROM my_special_df") - - sql_df.show() - ``` - -``` {title="Output"} - -╭───────┬───────╮ -│ A ┆ B │ -│ --- ┆ --- │ -│ Int64 ┆ Int64 │ -╞═══════╪═══════╡ -│ 1 ┆ 1 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 2 ┆ 2 │ -├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ -│ 3 ┆ 3 │ -╰───────┴───────╯ - -(Showing first 3 of 3 rows) -``` - -In the above example, we query the DataFrame called `"my_special_df"` by simply referring to it in the SQL command. 
This produces a new DataFrame `sql_df` which can natively integrate with the rest of your Daft query. - -## Reading data from SQL - -!!! warning "Warning" - - This feature is a WIP and will be coming soon! We will support reading common datasources directly from SQL: - - === "🐍 Python" - - ```python - daft.sql("SELECT * FROM read_parquet('s3://...')") - daft.sql("SELECT * FROM read_delta_lake('s3://...')") - ``` - - Today, a workaround for this is to construct your dataframe in Python first and use it from SQL instead: - - === "🐍 Python" - - ```python - df = daft.read_parquet("s3://...") - daft.sql("SELECT * FROM df") - ``` - - We appreciate your patience with us and hope to deliver this crucial feature soon! - -## SQL Expressions - -SQL has the concept of expressions as well. Here is an example of a simple addition expression, adding columns "a" and "b" in SQL to produce a new column C. - -We also present here the equivalent query for SQL and DataFrame. Notice how similar the concepts are! - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) - df = daft.sql("SELECT A + B as C FROM df") - df.show() - ``` - -=== "🐍 Python" - ``` python - expr = (daft.col("A") + daft.col("B")).alias("C") - - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) - df = df.select(expr) - df.show() - ``` - -``` {title="Output"} - -╭───────╮ -│ C │ -│ --- │ -│ Int64 │ -╞═══════╡ -│ 2 │ -├╌╌╌╌╌╌╌┤ -│ 4 │ -├╌╌╌╌╌╌╌┤ -│ 6 │ -╰───────╯ - -(Showing first 3 of 3 rows) -``` - -In the above query, both the SQL version of the query and the DataFrame version of the query produce the same result. - -Under the hood, they run the same Expression `col("A") + col("B")`! - -One really cool trick you can do is to use the [`daft.sql_expr`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql_expr) function as a helper to easily create Expressions. The following are equivalent: - -=== "⚙️ SQL" - ```python - sql_expr = daft.sql_expr("A + B as C") - print("SQL expression:", sql_expr) - ``` - -=== "🐍 Python" - ``` python - py_expr = (daft.col("A") + daft.col("B")).alias("C") - print("Python expression:", py_expr) - ``` - -``` {title="Output"} - -SQL expression: col(A) + col(B) as C -Python expression: col(A) + col(B) as C -``` - -This means that you can pretty much use SQL anywhere you use Python expressions, making Daft extremely versatile at mixing workflows which leverage both SQL and Python. - -As an example, consider the filter query below and compare the two equivalent Python and SQL queries: - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) - - # Daft automatically converts this string using `daft.sql_expr` - df = df.where("A < 2") - - df.show() - ``` - -=== "🐍 Python" - ``` python - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) - - # Using Daft's Python Expression API - df = df.where(df["A"] < 2) - - df.show() - ``` - -``` {title="Output"} - -╭───────┬───────╮ -│ A ┆ B │ -│ --- ┆ --- │ -│ Int64 ┆ Int64 │ -╞═══════╪═══════╡ -│ 1 ┆ 1 │ -╰───────┴───────╯ - -(Showing first 1 of 1 rows) -``` - -Pretty sweet! Of course, this support for running Expressions on your columns extends well beyond arithmetic as we'll see in the next section on SQL Functions. - -## SQL Functions - -SQL also has access to all of Daft's powerful [`daft.Expression`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame) functionality through SQL functions. 
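For instance, here is a small sketch that mirrors the earlier string examples (`contains` is assumed to be the SQL counterpart of the Python `.str.contains` method, and passing a string literal as its second argument is assumed to work the same way as passing a column):

=== "⚙️ SQL"
    ```python
    df = daft.from_pydict({"B": ["foo", "bar", "baz"]})

    # Assumed equivalent of the Python expression df["B"].str.contains("a")
    daft.sql("SELECT B, contains(B, 'a') AS B_has_a FROM df").show()
    ```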
- -However, unlike the Python Expression API which encourages method-chaining (e.g. `col("a").url.download().image.decode()`), in SQL you have to do function nesting instead (e.g. `"image_decode(url_download(a))"`). - -!!! note "Note" - - A full catalog of the available SQL Functions in Daft is available in the [`../api_docs/sql`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html). - - Note that it closely mirrors the Python API, with some function naming differences vs the available Python methods. - We also have some aliased functions for ANSI SQL-compliance or familiarity to users coming from other common SQL dialects such as PostgreSQL and SparkSQL to easily find their functionality. - -Here is an example of an equivalent function call in SQL vs Python: - -=== "⚙️ SQL" - ```python - df = daft.from_pydict({"urls": [ - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - ]}) - df = daft.sql("SELECT image_decode(url_download(urls)) FROM df") - df.show() - ``` - -=== "🐍 Python" - ``` python - df = daft.from_pydict({"urls": [ - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", - ]}) - df = df.select(daft.col("urls").url.download().image.decode()) - df.show() - ``` - -``` {title="Output"} - -╭──────────────╮ -│ urls │ -│ --- │ -│ Image[MIXED] │ -╞══════════════╡ -│ │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ │ -├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ -│ │ -╰──────────────╯ - -(Showing first 3 of 3 rows) -``` diff --git a/docs-v2/core_concepts/udf.md b/docs-v2/core_concepts/udf.md deleted file mode 100644 index a57913a05f..0000000000 --- a/docs-v2/core_concepts/udf.md +++ /dev/null @@ -1,213 +0,0 @@ -# User-Defined Functions (UDF) - -A key piece of functionality in Daft is the ability to flexibly define custom functions that can run on any data in your dataframe. This guide walks you through the different types of UDFs that Daft allows you to run. - -Let's first create a dataframe that will be used as a running example throughout this tutorial! - -=== "🐍 Python" - ``` python - import daft - import numpy as np - - df = daft.from_pydict({ - # the `image` column contains images represented as 2D numpy arrays - "image": [np.ones((128, 128)) for i in range(16)], - # the `crop` column contains a box to crop from our image, represented as a list of integers: [x1, x2, y1, y2] - "crop": [[0, 1, 0, 1] for i in range(16)], - }) - ``` - - -## Per-column per-row functions using [`.apply`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.apply.html) - -You can use [`.apply`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.apply.html) to run a Python function on every row in a column. - -For example, the following example creates a new `flattened_image` column by calling `.flatten()` on every object in the `image` column. 
- -=== "🐍 Python" - ``` python - df.with_column( - "flattened_image", - df["image"].apply(lambda img: img.flatten(), return_dtype=daft.DataType.python()) - ).show(2) - ``` - -``` {title="Output"} - -+----------------------+---------------+---------------------+ -| image | crop | flattened_image | -| Python | List[Int64] | Python | -+======================+===============+=====================+ -| [[1. 1. 1. ... 1. 1. | [0, 1, 0, 1] | [1. 1. 1. ... 1. 1. | -| 1.] [1. 1. 1. ... | | 1.] | -| 1. 1. 1.] [1. 1.... | | | -+----------------------+---------------+---------------------+ -| [[1. 1. 1. ... 1. 1. | [0, 1, 0, 1] | [1. 1. 1. ... 1. 1. | -| 1.] [1. 1. 1. ... | | 1.] | -| 1. 1. 1.] [1. 1.... | | | -+----------------------+---------------+---------------------+ -(Showing first 2 rows) -``` - -Note here that we use the `return_dtype` keyword argument to specify that our returned column type is a Python column! - -## Multi-column per-partition functions using [`@udf`](https://www.getdaft.io/projects/docs/en/stable/api_docs/udf.html#creating-udfs) - -[`.apply`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/expression_methods/daft.Expression.apply.html) is great for convenience, but has two main limitations: - -1. It can only run on single columns -2. It can only run on single items at a time - -Daft provides the [`@udf`](https://www.getdaft.io/projects/docs/en/stable/api_docs/udf.html#creating-udfs) decorator for defining your own UDFs that process multiple columns or multiple rows at a time. - -For example, let's try writing a function that will crop all our images in the `image` column by its corresponding value in the `crop` column: - -=== "🐍 Python" - ``` python - @daft.udf(return_dtype=daft.DataType.python()) - def crop_images(images, crops, padding=0): - cropped = [] - for img, crop in zip(images.to_pylist(), crops.to_pylist()): - x1, x2, y1, y2 = crop - cropped_img = img[x1:x2 + padding, y1:y2 + padding] - cropped.append(cropped_img) - return cropped - - df = df.with_column( - "cropped", - crop_images(df["image"], df["crop"], padding=1), - ) - df.show(2) - ``` - -``` {title="Output"} - -+----------------------+---------------+--------------------+ -| image | crop | cropped | -| Python | List[Int64] | Python | -+======================+===============+====================+ -| [[1. 1. 1. ... 1. 1. | [0, 1, 0, 1] | [[1. 1.] [1. 1.]] | -| 1.] [1. 1. 1. ... | | | -| 1. 1. 1.] [1. 1.... | | | -+----------------------+---------------+--------------------+ -| [[1. 1. 1. ... 1. 1. | [0, 1, 0, 1] | [[1. 1.] [1. 1.]] | -| 1.] [1. 1. 1. ... | | | -| 1. 1. 1.] [1. 1.... | | | -+----------------------+---------------+--------------------+ -(Showing first 2 rows) -``` - -There's a few things happening here, let's break it down: - -1. `crop_images` is a normal Python function. It takes as input: - - a. A list of images: `images` - - b. A list of cropping boxes: `crops` - - c. An integer indicating how much padding to apply to the right and bottom of the cropping: `padding` - -2. To allow Daft to pass column data into the `images` and `crops` arguments, we decorate the function with [`@udf`](https://www.getdaft.io/projects/docs/en/stable/api_docs/udf.html#creating-udfs) - - a. `return_dtype` defines the returned data type. In this case, we return a column containing Python objects of numpy arrays - - b. 
At runtime, because we call the UDF on the `image` and `crop` columns, the UDF will receive a [`daft.Series`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series) object for each argument. - -3. We can create a new column in our DataFrame by applying our UDF on the `"image"` and `"crop"` columns inside of a [`df.with_column()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.with_column.html#daft.DataFrame.with_column) call. - -### UDF Inputs - - -When you specify an Expression as an input to a UDF, Daft will calculate the result of that Expression and pass it into your function as a [`daft.Series`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series) object. - -The Daft [`daft.Series`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series) is just an abstraction on a "column" of data! You can obtain several different data representations from a [`daft.Series`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series): - -1. PyArrow Arrays (`pa.Array`): [`s.to_arrow()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series.to_arrow) -2. Python lists (`list`): [`s.to_pylist()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/series.html#daft.Series.to_pylist) - -Depending on your application, you may choose a different data representation that is more performant or more convenient! - -!!! info "Info" - - Certain array formats have some restrictions around the type of data that they can handle: - - 1. **Null Handling**: In Pandas and Numpy, nulls are represented as NaNs for numeric types, and Nones for non-numeric types. Additionally, the existence of nulls will trigger a type casting from integer to float arrays. If null handling is important to your use-case, we recommend using one of the other available options. - - 2. **Python Objects**: PyArrow array formats cannot support Python columns. - - We recommend using Python lists if performance is not a major consideration, and using the arrow-native formats such as PyArrow arrays and numpy arrays if performance is important. - -### Return Types - -The `return_dtype` argument specifies what type of column your UDF will return. Types can be specified using the [`daft.DataType`](https://www.getdaft.io/projects/docs/en/stable/api_docs/datatype.html#daft.DataType) class. - -Your UDF function itself needs to return a batch of columnar data, and can do so as any one of the following array types: - -1. Numpy Arrays (`np.ndarray`) -2. PyArrow Arrays (`pa.Array`) -3. Python lists (`list`) - -Note that if the data you have returned is not castable to the return_dtype that you specify (e.g. if you return a list of floats when you've specified a `return_dtype=DataType.bool()`), Daft will throw a runtime error! - -## Class UDFs - -UDFs can also be created on Classes, which allow for initialization on some expensive state that can be shared between invocations of the class, for example downloading data or creating a model. - -=== "🐍 Python" - ``` python - @daft.udf(return_dtype=daft.DataType.int64()) - class RunModel: - - def __init__(self): - # Perform expensive initializations - self._model = create_model() - - def __call__(self, features_col): - return self._model(features_col) - ``` - -Running Class UDFs are exactly the same as running their functional cousins. 
- -=== "🐍 Python" - ``` python - df = df.with_column("image_classifications", RunModel(df["images"])) - ``` - -## Resource Requests - -Sometimes, you may want to request for specific resources for your UDF. For example, some UDFs need one GPU to run as they will load a model onto the GPU. - -To do so, you can create your UDF and assign it a resource request: - -=== "🐍 Python" - ``` python - @daft.udf(return_dtype=daft.DataType.int64(), num_gpus=1) - class RunModelWithOneGPU: - - def __init__(self): - # Perform expensive initializations - self._model = create_model() - - def __call__(self, features_col): - return self._model(features_col) - ``` - - ``` python - df = df.with_column( - "image_classifications", - RunModelWithOneGPU(df["images"]), - ) - ``` - -In the above example, if Daft ran on a Ray cluster consisting of 8 GPUs and 64 CPUs, Daft would be able to run 8 replicas of your UDF in parallel, thus massively increasing the throughput of your UDF! - -UDFs can also be parametrized with new resource requests after being initialized. - -=== "🐍 Python" - ``` python - RunModelWithTwoGPUs = RunModelWithOneGPU.override_options(num_gpus=2) - df = df.with_column( - "image_classifications", - RunModelWithTwoGPUs(df["images"]), - ) - ``` diff --git a/docs-v2/distributed.md b/docs-v2/distributed.md new file mode 100644 index 0000000000..455cd726cf --- /dev/null +++ b/docs-v2/distributed.md @@ -0,0 +1,303 @@ +# Distributed Computing + +!!! failure "todo(docs): add daft launcher docs and review order of information" + +By default, Daft runs using your local machine's resources and your operations are thus limited by the CPUs, memory and GPUs available to you in your single local development machine. + +However, Daft has strong integrations with [Ray](https://www.ray.io) which is a distributed computing framework for distributing computations across a cluster of machines. Here is a snippet showing how you can connect Daft to a Ray cluster: + +=== "🐍 Python" + + ```python + import daft + + daft.context.set_runner_ray() + ``` + +By default, if no address is specified Daft will spin up a Ray cluster locally on your machine. If you are running Daft on a powerful machine (such as an AWS P3 machine which is equipped with multiple GPUs) this is already very useful because Daft can parallelize its execution of computation across your CPUs and GPUs. However, if instead you already have your own Ray cluster running remotely, you can connect Daft to it by supplying an address: + +=== "🐍 Python" + + ```python + daft.context.set_runner_ray(address="ray://url-to-mycluster") + ``` + +For more information about the `address` keyword argument, please see the [Ray documentation on initialization](https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html). + + +If you want to start a single node ray cluster on your local machine, you can do the following: + +```bash +> pip install ray[default] +> ray start --head --port=6379 +``` + +This should output something like: + +``` +Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. + +Local node IP: 127.0.0.1 + +-------------------- +Ray runtime started. +-------------------- + +... 
+```
+
+You can take the IP address and port and pass it to Daft:
+
+=== "🐍 Python"
+
+    ```python
+    >>> import daft
+    >>> daft.context.set_runner_ray("127.0.0.1:6379")
+    DaftContext(_daft_execution_config=, _daft_planning_config=, _runner_config=_RayRunnerConfig(address='127.0.0.1:6379', max_task_backlog=None), _disallow_set_runner=True, _runner=None)
+    >>> df = daft.from_pydict({
+    ...     'text': ['hello', 'world']
+    ... })
+    2024-07-29 15:49:26,610 INFO worker.py:1567 -- Connecting to existing Ray cluster at address: 127.0.0.1:6379...
+    2024-07-29 15:49:26,622 INFO worker.py:1752 -- Connected to Ray cluster.
+    >>> print(df)
+    ╭───────╮
+    │ text  │
+    │ ---   │
+    │ Utf8  │
+    ╞═══════╡
+    │ hello │
+    ├╌╌╌╌╌╌╌┤
+    │ world │
+    ╰───────╯
+
+    (Showing first 2 of 2 rows)
+    ```
+
+## Daft Launcher
+
+Daft Launcher is a convenient command-line tool that provides simple abstractions over Ray, letting users get up and running with Daft for distributed computations quickly. Rather than worrying about the complexities of managing Ray, users can simply run a few CLI commands to spin up a cluster, submit a job, observe the status of jobs and clusters, and spin down a cluster.
+
+### Prerequisites
+
+The following should be installed on your machine:
+
+- The [AWS CLI](https://aws.amazon.com/cli) tool. (Assuming you're using AWS as your cloud provider)
+
+- A Python package manager. We recommend using `uv` to manage everything (i.e., dependencies, as well as the Python version itself). It's much cleaner and faster than `pip`.
+
+### Install Daft Launcher
+
+Run the following commands in your terminal to initialize your project:
+
+```bash
+# Create a project directory
+cd some/working/directory
+mkdir launch-test
+cd launch-test
+
+# Initialize the project
+uv init --python 3.12
+uv venv
+source .venv/bin/activate
+
+# Install Daft Launcher
+uv pip install "daft-launcher"
+```
+
+In your virtual environment, you should now have Daft Launcher installed; you can verify this by running `daft --version`, which prints the installed version of Daft Launcher. You should also have a basic working directory that may look something like this:
+
+```bash
+/
+|- .venv/
+|- hello.py
+|- pyproject.toml
+|- README.md
+|- .python-version
+```
+
+### Configure AWS Credentials
+
+Establish an SSO connection to configure your AWS credentials:
+
+```bash
+# Configure your SSO
+aws configure sso
+
+# Login to your SSO
+aws sso login
+```
+
+These commands should open your browser. Accept the prompted requests and then return to your terminal; you should see a success message from the AWS CLI tool. At this point, your AWS CLI has been configured and your environment is fully set up.
+
+### Initialize Configuration File
+
+Initialize a default configuration file that stores the settings you can later tune; each entry below is marked as either required or optional.
+
+```bash
+# Initialize the default .daft.toml configuration file
+daft init-config
+
+# Optionally you can also specify a custom name for your file
+daft init-config my-custom-config.toml
+```
+
+Fill out the required values in your `.daft.toml` file. Optional configurations will have a default value pre-defined.
+
+```toml
+[setup]
+
+# (required)
+# The name of the cluster.
+name = ...
+
+# (required)
+# The cloud provider that this cluster will be created in.
+# Has to be one of the following:
+# - "aws"
+# - "gcp"
+# - "azure"
+provider = ...
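+# e.g., provider = "aws" for the AWS-based setup assumed in this guide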
+
+# (optional; default = None)
+# The IAM instance profile ARN that will provide this cluster with the permissions it needs to perform its actions.
+# Please note that if you don't specify this field, Ray will create an automatic instance profile for you.
+# That instance profile will be minimal and may restrict some of Daft's features.
+iam_instance_profile_arn = ...
+
+# (required)
+# The AWS region in which to place this cluster.
+region = ...
+
+# (optional; default = "ec2-user")
+# The SSH user name to use when connecting to the cluster.
+ssh_user = ...
+
+# (optional; default = 2)
+# The number of worker nodes to create in the cluster.
+number_of_workers = ...
+
+# (optional; default = "m7g.medium")
+# The instance type to use for the head and worker nodes.
+instance_type = ...
+
+# (optional; default = "ami-01c3c55948a949a52")
+# The AMI ID to use for the head and worker nodes.
+image_id = ...
+
+# (optional; default = [])
+# A list of dependencies to install on the head and worker nodes.
+# These will be installed using uv (https://docs.astral.sh/uv/).
+dependencies = [...]
+
+[run]
+
+# (optional; default = ['echo "Hello, World!"'])
+# Any post-setup commands that you want to have run on the head node.
+# This is a good location to install any custom dependencies or run some arbitrary script.
+setup_commands = [...]
+
+```
+
+### Spin Up a Cluster
+
+`daft up` will spin up a cluster given the configuration file you initialized earlier. The configuration file contains all the information Daft Launcher needs to spin up a cluster.
+
+```bash
+# Spin up a cluster using the default .daft.toml configuration file created earlier
+daft up
+
+# Alternatively spin up a cluster using a custom configuration file created earlier
+daft up -c my-custom-config.toml
+```
+
+This command will do a few things:
+
+1. First, it will reach into your cloud provider and spin up the necessary resources. This includes things such as the worker nodes, security groups, permissions, etc.
+
+2. When the nodes are spun up, the Ray and Daft dependencies will be installed into a Python virtual environment.
+
+3. Next, any other custom dependencies that you've specified in the configuration file will then be downloaded.
+
+4. Finally, the setup commands that you've specified in the configuration file will be run on the head node.
+
+!!! note "Note"
+
+    `daft up` will only return successfully when the head node is fully set up. Even though the command will request the worker nodes to also spin up, it will not wait for them before returning. Therefore, if you run `daft list` immediately after the command completes, the worker nodes may still be in a “pending” state. Give it a few seconds and they should be fully running.
+
+### Submit a Job
+
+`daft submit` enables you to submit a working directory and command (a “job”) to be run on the remote cluster.
+
+```bash
+# Submit a job using the default .daft.toml configuration file
+daft submit -i my-keypair.pem -w my-working-directory
+
+# Alternatively submit a job using a custom configuration file
+daft submit -c my-custom-config.toml -i my-keypair.pem -w my-working-directory
+```
+
+### Run a SQL Query
+
+Daft supports a SQL API, so you can use `daft sql` to run raw SQL queries against your data. The SQL dialect is the PostgreSQL standard.
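+
+If you are working from Python rather than the launcher CLI, the same kind of query can be run through Daft's [`daft.sql()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql) function. Below is a minimal sketch; the table and column names are hypothetical, and it relies on the Python API's ability to reference an in-scope DataFrame by its variable name:
+
+```python
+import daft
+
+# A small in-memory DataFrame standing in for "my_table"
+my_table = daft.from_pydict({"id": [1, 2, 3], "value": ["a", "b", "c"]})
+
+# Run a SQL query that refers to the DataFrame by its variable name
+result = daft.sql("SELECT id, value FROM my_table WHERE id > 1")
+result.show()
+```
+
+The equivalent CLI invocations through Daft Launcher look like this: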
+
+```bash
+# Run a SQL query using the default .daft.toml configuration file
+daft sql -- "\"SELECT * FROM my_table\""
+
+# Alternatively you can run a SQL query using a custom configuration file
+daft sql -c my-custom-config.toml -- "\"SELECT * FROM my_table\""
+```
+
+### View Ray Dashboard
+
+You can view the Ray dashboard of your running cluster with `daft connect`, which establishes a port-forward over SSH from your local machine to the head node of the cluster (connecting `localhost:8265` to the remote head's `8265`).
+
+```bash
+# Establish the port-forward using the default .daft.toml configuration file
+daft connect -i my-keypair.pem
+
+# Alternatively establish the port-forward using a custom configuration file
+daft connect -c my-custom-config.toml -i my-keypair.pem
+```
+
+!!! note "Note"
+
+    `daft connect` requires the SSH keypair that matches the public key on the remote head node in order to authenticate. Make sure to pass this keypair as an argument to the command.
+
+### Spin Down a Cluster
+
+`daft down` will spin down all instances of the cluster specified in the configuration file, not just the head node.
+
+```bash
+# Spin down a cluster using the default .daft.toml configuration file
+daft down
+
+# Alternatively spin down a cluster using a custom configuration file
+daft down -c my-custom-config.toml
+```
+
+### List Running and Terminated Clusters
+
+`daft list` allows you to view the current state of all clusters, running and terminated, and includes each instance's name and IP address (assuming the cluster is running). Here’s an example output after running `daft list`:
+
+```
+Running:
+  - daft-demo, head, i-053f9d4856d92ea3d, 35.94.91.91
+  - daft-demo, worker, i-00c340dc39d54772d
+  - daft-demo, worker, i-042a96ce1413c1dd6
+```
+
+Say we spun up another cluster, `new-cluster`, and then terminated it; here’s what the output of `daft list` would look like immediately after:
+
+```
+Running:
+  - daft-demo, head, i-053f9d4856d92ea3d, 35.94.91.91
+  - daft-demo, worker, i-00c340dc39d54772d, 44.234.112.173
+  - daft-demo, worker, i-042a96ce1413c1dd6, 35.94.206.130
+Shutting-down:
+  - new-cluster, head, i-0be0db9803bd06652, 35.86.200.101
+  - new-cluster, worker, i-056f46bd69e1dd3f1, 44.242.166.108
+  - new-cluster, worker, i-09ff0e1d8e67b8451, 35.87.221.180
+```
+
+A few seconds later, the state of `new-cluster` will be finalized to “Terminated”.
diff --git a/docs-v2/integrations/delta_lake.md b/docs-v2/integrations/delta_lake.md index d0c92bed5c..e19f05bb88 100644 --- a/docs-v2/integrations/delta_lake.md +++ b/docs-v2/integrations/delta_lake.md @@ -4,7 +4,7 @@ Daft currently supports: -1. **Parallel + Distributed Reads:** Daft parallelizes Delta Lake table reads over all cores of your machine, if using the default multithreading runner, or all cores + machines of your Ray cluster, if using the [distributed Ray runner](../advanced/distributed.md). +1. **Parallel + Distributed Reads:** Daft parallelizes Delta Lake table reads over all cores of your machine, if using the default multithreading runner, or all cores + machines of your Ray cluster, if using the [distributed Ray runner](../distributed.md). 2. **Skipping Filtered Data:** Daft ensures that only data that matches your [`df.where(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where) filter will be read, often skipping entire files/partitions.
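+
+A minimal sketch of what this looks like in practice (the table path and the `country` column here are hypothetical):
+
+```python
+import daft
+
+# Reads are parallelized across local cores, or across a Ray cluster if one is configured
+df = daft.read_deltalake("s3://bucket/path/to/delta-table")
+
+# Only files/partitions that can contain matching rows are read
+df = df.where(df["country"] == "US")
+df.show()
+```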
diff --git a/docs-v2/integrations/hudi.md b/docs-v2/integrations/hudi.md index e0a28ec7da..6c5b33fb01 100644 --- a/docs-v2/integrations/hudi.md +++ b/docs-v2/integrations/hudi.md @@ -4,7 +4,7 @@ Daft currently supports: -1. **Parallel + Distributed Reads:** Daft parallelizes Hudi table reads over all cores of your machine, if using the default multithreading runner, or all cores + machines of your Ray cluster, if using the [distributed Ray runner](../advanced/distributed.md). +1. **Parallel + Distributed Reads:** Daft parallelizes Hudi table reads over all cores of your machine, if using the default multithreading runner, or all cores + machines of your Ray cluster, if using the [distributed Ray runner](../distributed.md). 2. **Skipping Filtered Data:** Daft ensures that only data that matches your [`df.where(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where) filter will be read, often skipping entire files/partitions. diff --git a/docs-v2/integrations/ray.md b/docs-v2/integrations/ray.md index 0517248aa7..55c334ba35 100644 --- a/docs-v2/integrations/ray.md +++ b/docs-v2/integrations/ray.md @@ -1,5 +1,8 @@ # Ray +!!! failure "todo(docs): add reference to daft launcher" + + [Ray](https://docs.ray.io/en/latest/ray-overview/index.html) is an open-source framework for distributed computing. Daft's native support for Ray enables you to run distributed DataFrame workloads at scale. ## Usage diff --git a/docs-v2/integrations/sql.md b/docs-v2/integrations/sql.md index 95676d0082..ce32e8ca65 100644 --- a/docs-v2/integrations/sql.md +++ b/docs-v2/integrations/sql.md @@ -6,7 +6,7 @@ Daft currently supports: 1. **20+ SQL Dialects:** Daft supports over 20 databases, data warehouses, and query engines by using [SQLGlot](https://sqlglot.com/sqlglot.html) to convert SQL queries across dialects. See the full list of supported dialects [here](https://sqlglot.com/sqlglot/dialects.html). -2. **Parallel + Distributed Reads:** Daft parallelizes SQL reads by using all local machine cores with its default multithreading runner, or all cores across multiple machines if using the [distributed Ray runner](../advanced/distributed.md). +2. **Parallel + Distributed Reads:** Daft parallelizes SQL reads by using all local machine cores with its default multithreading runner, or all cores across multiple machines if using the [distributed Ray runner](../distributed.md). 3. **Skipping Filtered Data:** Daft ensures that only data that matches your [`df.select(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.select.html#daft.DataFrame.select), [`df.limit(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.limit.html#daft.DataFrame.limit), and [`df.where(...)`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.where.html#daft.DataFrame.where) expressions will be read, often skipping entire partitions/columns. @@ -80,7 +80,7 @@ You can also directly provide a SQL alchemy connection via a **connection factor ## Parallel + Distributed Reads -For large datasets, Daft can parallelize SQL reads by using all local machine cores with its default multithreading runner, or all cores across multiple machines if using the [distributed Ray runner](../advanced/distributed.md). 
+For large datasets, Daft can parallelize SQL reads by using all local machine cores with its default multithreading runner, or all cores across multiple machines if using the [distributed Ray runner](../distributed.md). Supply the [`daft.read_sql()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/io_functions/daft.read_sql.html#daft.read_sql) function with a **partition column** and optionally the **number of partitions** to enable parallel reads. diff --git a/docs-v2/migration/dask_migration.md b/docs-v2/migration/dask_migration.md index 359d2d91e5..af2d91b7ce 100644 --- a/docs-v2/migration/dask_migration.md +++ b/docs-v2/migration/dask_migration.md @@ -117,7 +117,7 @@ Dask supports the same data types as pandas. Daft is built to support many more ## Distributed Computing and Remote Clusters -Both Dask and Daft support distributed computing on remote clusters. In Dask, you create a Dask cluster either locally or remotely and perform computations in parallel there. Currently, Daft supports distributed cluster computing [with Ray](../advanced/distributed.md). Support for running Daft computations on Dask clusters is on the roadmap. +Both Dask and Daft support distributed computing on remote clusters. In Dask, you create a Dask cluster either locally or remotely and perform computations in parallel there. Currently, Daft supports distributed cluster computing [with Ray](../distributed.md). Support for running Daft computations on Dask clusters is on the roadmap. Cloud support for both Dask and Daft is the same. diff --git a/docs-v2/quickstart.md b/docs-v2/quickstart.md index c744d99c9e..2372b868fe 100644 --- a/docs-v2/quickstart.md +++ b/docs-v2/quickstart.md @@ -4,6 +4,11 @@ !!! failure "todo(docs): Incorporate SQL examples" +!!! failure "todo(docs): Add link to notebook to DIY (notebook is in docs-v2 dir, but idk how to host on colab)." + +!!! failure "todo(docs): What does the actual output look like for some of these examples?" + + In this quickstart, you will learn the basics of Daft's DataFrame and SQL API and the features that set it apart from frameworks like Pandas, PySpark, Dask, and Ray. @@ -57,6 +62,7 @@ See also [DataFrame Creation](https://www.getdaft.io/projects/docs/en/stable/api (Showing first 4 of 4 rows) + ``` You just created your first DataFrame! diff --git a/docs-v2/terms.md b/docs-v2/terms.md index eb703b37e7..8f8bcd62bb 100644 --- a/docs-v2/terms.md +++ b/docs-v2/terms.md @@ -83,9 +83,9 @@ You can examine a logical plan using [`df.explain()`](https://www.getdaft.io/pro | Clustering spec = { Num partitions = 1 } ``` -## Structured Query Language (SQL) +## SQL -SQL is a common query language for expressing queries over tables of data. Daft exposes a SQL API as an alternative (but often also complementary API) to the Python [`DataFrame`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame) and +[SQL (Structured Query Language)](https://en.wikipedia.org/wiki/SQL) is a common query language for expressing queries over tables of data. Daft exposes a SQL API as an alternative (but often also complementary API) to the Python [`DataFrame`](https://www.getdaft.io/projects/docs/en/stable/api_docs/doc_gen/dataframe_methods/daft.DataFrame.html#daft.DataFrame) and [`Expression`](https://www.getdaft.io/projects/docs/en/stable/api_docs/expressions.html) APIs for building queries. 
You can use SQL in Daft via the [`daft.sql()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql) function, and Daft will also convert many SQL-compatible strings into Expressions via [`daft.sql_expr()`](https://www.getdaft.io/projects/docs/en/stable/api_docs/sql.html#daft.sql_expr) for easy interoperability with DataFrames. diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index 4d44459215..170268f4f9 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -94,6 +94,7 @@ Logical Expression.__lt__ Expression.__le__ Expression.__eq__ + Expression.eq_null_safe Expression.__ne__ Expression.__gt__ Expression.__ge__ @@ -169,6 +170,22 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.tokenize_decode Expression.str.count_matches +.. _api-binary-expression-operations: + +Binary +###### + +The following methods are available under the ``expr.binary`` attribute. + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/expression_methods + :template: autosummary/accessor_method.rst + + Expression.binary.concat + Expression.binary.length + Expression.binary.slice + .. _api-float-expression-operations: Floats diff --git a/mkdocs.yml b/mkdocs.yml index cadf521f08..7816c8f330 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -6,6 +6,12 @@ site_name: Daft Documentation docs_dir: docs-v2 +# Scarf pixel for tracking analytics +image: + referrerpolicy: "no-referrer-when-downgrade" + src: "https://static.scarf.sh/a.png?x-pxid=c9065f3a-a090-4243-8f69-145d5de7bfca" + + # Repository repo_name: Daft repo_url: https://github.com/Eventual-Inc/Daft @@ -18,17 +24,10 @@ nav: - Installation: install.md - Quickstart: quickstart.md - Core Concepts: core_concepts.md - # - DataFrame: core_concepts/dataframe.md - # - Expressions: core_concepts/expressions.md - # - Reading/Writing Data: core_concepts/read_write.md - # - DataTypes: core_concepts/datatypes.md - # - SQL: core_concepts/sql.md - # - Aggregations and Grouping: core_concepts/aggregations.md - # - User-Defined Functions (UDF): core_concepts/udf.md + - Distributed Computing: distributed.md - Advanced: - Managing Memory Usage: advanced/memory.md - Partitioning: advanced/partitioning.md - - Distributed Computing: advanced/distributed.md - Integrations: - Ray: integrations/ray.md - Unity Catalog: integrations/unity_catalog.md @@ -58,6 +57,7 @@ theme: features: - search.suggest - search.highlight + - content.code.copy # add copy button to code sections - content.tabs.link # If one tab switches Python to SQL, all tabs switch - toc.follow - toc.integrate # adds page subsections to left-hand menu (instead of right-hand menu) diff --git a/pyproject.toml b/pyproject.toml index 7673bfbf24..8e7e872f80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ check-hidden = true ignore-words-list = "crate,arithmetics,ser" # Feel free to un-skip examples, and experimental, you will just need to # work through many typos (--write-changes and --interactive will help) -skip = "tests/series/*,target,.git,.venv,venv,data,*.csv,*.csv.*,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb,*.tiktoken,*.sql" +skip = "tests/series/*,target,.git,.venv,venv,data,*.csv,*.csv.*,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb,*.tiktoken,*.sql,tests/table/utf8/*,tests/table/binary/*" [tool.maturin] # "python" tells pyo3 we want to build an extension module (skips linking against libpython.so) diff --git a/requirements-docs.txt b/requirements-docs.txt new file mode 100644 
index 0000000000..bf91dca6c8 --- /dev/null +++ b/requirements-docs.txt @@ -0,0 +1,5 @@ +# not pinned at the moment +markdown-exec +mkdocs-jupyter +mkdocs-material +pymdown-extensions diff --git a/src/arrow2/src/bitmap/bitmap_ops.rs b/src/arrow2/src/bitmap/bitmap_ops.rs index 99e256328c..dc924ee1ea 100644 --- a/src/arrow2/src/bitmap/bitmap_ops.rs +++ b/src/arrow2/src/bitmap/bitmap_ops.rs @@ -211,6 +211,20 @@ pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { } } +#[inline] +/// Compute bitwise equality operation +pub fn bitwise_eq(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + assert_eq!(lhs.len(), rhs.len()); + // Fast path: if bitmaps are identical, return all true + if lhs == rhs { + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| !(x ^ y)) // XNOR operation + } +} + fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool { if lhs.len() != rhs.len() { return false; diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml index d132d74ee8..d6b7a4ffd8 100644 --- a/src/common/io-config/Cargo.toml +++ b/src/common/io-config/Cargo.toml @@ -1,7 +1,9 @@ [dependencies] aws-credential-types = {version = "0.55.3"} chrono = {workspace = true} +common-error = {path = "../error", default-features = false} common-py-serde = {path = "../py-serde", default-features = false} +derivative = {workspace = true} derive_more = {workspace = true} pyo3 = {workspace = true, optional = true} secrecy = {version = "0.8.0", features = ["alloc"], default-features = false} @@ -9,7 +11,7 @@ serde = {workspace = true} typetag = {workspace = true} [features] -python = ["dep:pyo3", "common-py-serde/python"] +python = ["dep:pyo3", "common-error/python", "common-py-serde/python"] [lints] workspace = true diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs index ae620a112d..eab0c6c803 100644 --- a/src/common/io-config/src/lib.rs +++ b/src/common/io-config/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(let_chains)] + #[cfg(feature = "python")] pub mod python; diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 1c28775da6..de60fbbc08 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -2,20 +2,20 @@ use std::{ any::Any, hash::{Hash, Hasher}, sync::Arc, - time::{Duration, SystemTime}, }; -use aws_credential_types::{ - provider::{error::CredentialsError, ProvideCredentials}, - Credentials, -}; +use chrono::{DateTime, Utc}; +use common_error::DaftResult; use common_py_serde::{ deserialize_py_object, impl_bincode_py_state_serialization, serialize_py_object, }; use pyo3::prelude::*; use serde::{Deserialize, Serialize}; -use crate::{config, s3::S3CredentialsProvider}; +use crate::{ + config, + s3::{S3CredentialsProvider, S3CredentialsProviderWrapper}, +}; /// Create configurations to be used when accessing an S3-compatible system /// @@ -60,10 +60,11 @@ pub struct S3Config { /// expiry (datetime.datetime, optional): Expiry time of the credentials, credentials are assumed to be permanent if not provided /// /// Example: +/// >>> from datetime import datetime, timedelta, timezone /// >>> get_credentials = lambda: S3Credentials( /// ... key_id="xxx", /// ... access_key="xxx", -/// ... expiry=(datetime.datetime.now() + datetime.timedelta(hours=1)) +/// ... expiry=(datetime.now(timezone.utc) + timedelta(hours=1)) /// ... 
) /// >>> io_config = IOConfig(s3=S3Config(credentials_provider=get_credentials)) /// >>> daft.read_parquet("s3://some-path", io_config=io_config) @@ -309,8 +310,9 @@ impl S3Config { access_key: access_key.map(std::convert::Into::into).or(def.access_key), credentials_provider: credentials_provider .map(|p| { - Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(p)?) - as Box) + Ok::<_, PyErr>(S3CredentialsProviderWrapper::new( + PyS3CredentialsProvider::new(p)?, + )) }) .transpose()? .or(def.credentials_provider), @@ -394,8 +396,9 @@ impl S3Config { .or_else(|| self.config.access_key.clone()), credentials_provider: credentials_provider .map(|p| { - Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(p)?) - as Box) + Ok::<_, PyErr>(S3CredentialsProviderWrapper::new( + PyS3CredentialsProvider::new(p)?, + )) }) .transpose()? .or_else(|| self.config.credentials_provider.clone()), @@ -489,7 +492,8 @@ impl S3Config { #[getter] pub fn credentials_provider(&self, py: Python) -> PyResult>> { Ok(self.config.credentials_provider.as_ref().and_then(|p| { - p.as_any() + p.provider + .as_any() .downcast_ref::() .map(|p| p.provider.clone_ref(py)) })) @@ -572,6 +576,18 @@ impl S3Config { pub fn profile_name(&self) -> PyResult> { Ok(self.config.profile_name.clone()) } + + pub fn provide_cached_credentials(&self) -> PyResult> { + self.config + .credentials_provider + .as_ref() + .map(|provider| { + Ok(S3Credentials { + credentials: provider.get_cached_credentials()?, + }) + }) + .transpose() + } } #[pymethods] @@ -579,66 +595,47 @@ impl S3Credentials { #[new] #[pyo3(signature = (key_id, access_key, session_token=None, expiry=None))] pub fn new( - py: Python, key_id: String, access_key: String, session_token: Option, - expiry: Option>, - ) -> PyResult { - // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288) - let expiry = expiry - .map(|e| { - let ts = e.call_method0(pyo3::intern!(py, "timestamp"))?.extract()?; - - Ok::<_, PyErr>(SystemTime::UNIX_EPOCH + Duration::from_secs_f64(ts)) - }) - .transpose()?; - - Ok(Self { + expiry: Option>, + ) -> Self { + Self { credentials: crate::S3Credentials { key_id, access_key, session_token, expiry, }, - }) + } } - pub fn __repr__(&self) -> PyResult { - Ok(format!("{}", self.credentials)) + pub fn __repr__(&self) -> String { + format!("{}", self.credentials) } /// AWS Access Key ID #[getter] - pub fn key_id(&self) -> PyResult { - Ok(self.credentials.key_id.clone()) + pub fn key_id(&self) -> &str { + &self.credentials.key_id } /// AWS Secret Access Key #[getter] - pub fn access_key(&self) -> PyResult { - Ok(self.credentials.access_key.clone()) + pub fn access_key(&self) -> &str { + &self.credentials.access_key } /// AWS Session Token #[getter] - pub fn expiry<'py>(&self, py: Python<'py>) -> PyResult>> { - // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288) - self.credentials - .expiry - .map(|e| { - let datetime = py.import(pyo3::intern!(py, "datetime"))?; - - datetime - .getattr(pyo3::intern!(py, "datetime"))? 
- .call_method1( - pyo3::intern!(py, "fromtimestamp"), - (e.duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs_f64(),), - ) - }) - .transpose() + pub fn session_token(&self) -> Option<&str> { + self.credentials.session_token.as_deref() + } + + /// AWS Credentials Expiry + #[getter] + pub fn expiry(&self) -> Option> { + self.credentials.expiry } } @@ -662,32 +659,6 @@ impl PyS3CredentialsProvider { } } -impl ProvideCredentials for PyS3CredentialsProvider { - fn provide_credentials<'a>( - &'a self, - ) -> aws_credential_types::provider::future::ProvideCredentials<'a> - where - Self: 'a, - { - aws_credential_types::provider::future::ProvideCredentials::ready( - Python::with_gil(|py| { - let py_creds = self.provider.call0(py)?; - py_creds.extract::(py) - }) - .map_err(|e| CredentialsError::provider_error(Box::new(e))) - .map(|creds| { - Credentials::new( - creds.credentials.key_id, - creds.credentials.access_key, - creds.credentials.session_token, - creds.credentials.expiry, - "daft_custom_provider", - ) - }), - ) - } -} - impl PartialEq for PyS3CredentialsProvider { fn eq(&self, other: &Self) -> bool { self.hash == other.hash @@ -722,6 +693,13 @@ impl S3CredentialsProvider for PyS3CredentialsProvider { fn dyn_hash(&self, mut state: &mut dyn Hasher) { self.hash(&mut state); } + + fn provide_credentials(&self) -> DaftResult { + Python::with_gil(|py| { + let py_creds = self.provider.call0(py)?; + Ok(py_creds.extract::(py)?.credentials) + }) + } } #[pymethods] diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index 41db6c8b29..eba2fd5926 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -2,11 +2,16 @@ use std::{ any::Any, fmt::{Debug, Display, Formatter}, hash::{Hash, Hasher}, - time::SystemTime, + sync::{Arc, Mutex}, }; -use aws_credential_types::provider::ProvideCredentials; +use aws_credential_types::{ + provider::{error::CredentialsError, ProvideCredentials}, + Credentials, +}; use chrono::{offset::Utc, DateTime}; +use common_error::DaftResult; +use derivative::Derivative; use serde::{Deserialize, Serialize}; pub use crate::ObfuscatedString; @@ -18,7 +23,7 @@ pub struct S3Config { pub key_id: Option, pub session_token: Option, pub access_key: Option, - pub credentials_provider: Option>, + pub credentials_provider: Option, pub buffer_time: Option, pub max_connections_per_io_thread: u32, pub retry_initial_backoff_ms: u64, @@ -40,15 +45,54 @@ pub struct S3Credentials { pub key_id: String, pub access_key: String, pub session_token: Option, - pub expiry: Option, + pub expiry: Option>, } #[typetag::serde(tag = "type")] -pub trait S3CredentialsProvider: ProvideCredentials + Debug { +pub trait S3CredentialsProvider: Debug + Send + Sync { fn as_any(&self) -> &dyn Any; fn clone_box(&self) -> Box; fn dyn_eq(&self, other: &dyn S3CredentialsProvider) -> bool; fn dyn_hash(&self, state: &mut dyn Hasher); + fn provide_credentials(&self) -> DaftResult; +} + +#[derive(Derivative, Clone, Debug, Deserialize, Serialize)] +#[derivative(PartialEq, Eq, Hash)] +pub struct S3CredentialsProviderWrapper { + pub provider: Box, + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] + cached_creds: Arc>>, +} + +impl S3CredentialsProviderWrapper { + pub fn new(provider: impl S3CredentialsProvider + 'static) -> Self { + Self { + provider: Box::new(provider), + cached_creds: Arc::new(Mutex::new(None)), + } + } + + pub fn get_new_credentials(&self) -> DaftResult { + let creds = self.provider.provide_credentials()?; + 
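+        // Refresh the cache so that later calls to get_cached_credentials() can reuse
+        // these credentials until they expire.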
*self.cached_creds.lock().unwrap() = Some(creds.clone()); + Ok(creds) + } + + pub fn get_cached_credentials(&self) -> DaftResult { + let mut cached_creds = self.cached_creds.lock().unwrap(); + + if let Some(creds) = cached_creds.clone() + && creds.expiry.map_or(true, |expiry| expiry > Utc::now()) + { + Ok(creds) + } else { + let creds = self.provider.provide_credentials()?; + *cached_creds = Some(creds.clone()); + Ok(creds) + } + } } impl Clone for Box { @@ -71,14 +115,26 @@ impl Hash for Box { } } -impl ProvideCredentials for Box { +impl ProvideCredentials for S3CredentialsProviderWrapper { fn provide_credentials<'a>( &'a self, ) -> aws_credential_types::provider::future::ProvideCredentials<'a> where Self: 'a, { - self.as_ref().provide_credentials() + aws_credential_types::provider::future::ProvideCredentials::ready( + self.get_new_credentials() + .map_err(|e| CredentialsError::provider_error(Box::new(e))) + .map(|creds| { + Credentials::new( + creds.key_id, + creds.access_key, + creds.session_token, + creds.expiry.map(|e| e.into()), + "daft_custom_provider", + ) + }), + ) } } @@ -225,8 +281,6 @@ impl S3Credentials { res.push(format!("Session token = {session_token}")); } if let Some(expiry) = &self.expiry { - let expiry: DateTime = (*expiry).into(); - res.push(format!("Expiry = {}", expiry.format("%Y-%m-%dT%H:%M:%S"))); } res diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 2c8fc6acdd..df222fcfe9 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -69,13 +69,16 @@ impl Future for RuntimeTask { } pub struct Runtime { - runtime: tokio::runtime::Runtime, + pub runtime: Arc, pool_type: PoolType, } impl Runtime { pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef { - Arc::new(Self { runtime, pool_type }) + Arc::new(Self { + runtime: Arc::new(runtime), + pool_type, + }) } async fn execute_task(future: F, pool_type: PoolType) -> DaftResult diff --git a/src/common/scan-info/src/pushdowns.rs b/src/common/scan-info/src/pushdowns.rs index 8599123fae..39d56fce56 100644 --- a/src/common/scan-info/src/pushdowns.rs +++ b/src/common/scan-info/src/pushdowns.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use common_display::DisplayAs; -use daft_dsl::ExprRef; +use daft_dsl::{estimated_selectivity, ExprRef}; +use daft_schema::schema::Schema; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -103,6 +104,14 @@ impl Pushdowns { } res } + + pub fn estimated_selectivity(&self, schema: &Schema) -> f64 { + if let Some(filters) = &self.filters { + estimated_selectivity(filters, schema) + } else { + 1.0 + } + } } impl DisplayAs for Pushdowns { diff --git a/src/daft-catalog/python-catalog/src/python.rs b/src/daft-catalog/python-catalog/src/python.rs index 5d2ae500d6..8eab98ffd1 100644 --- a/src/daft-catalog/python-catalog/src/python.rs +++ b/src/daft-catalog/python-catalog/src/python.rs @@ -86,6 +86,7 @@ impl DataCatalogTable for PythonTable { } /// Wrapper around a `daft.catalog.python_catalog.PythonCatalog` +#[derive(Debug)] pub struct PythonCatalog { python_catalog_pyobj: PyObject, } diff --git a/src/daft-catalog/src/data_catalog.rs b/src/daft-catalog/src/data_catalog.rs index 0cf6d45708..59f6f0f491 100644 --- a/src/daft-catalog/src/data_catalog.rs +++ b/src/daft-catalog/src/data_catalog.rs @@ -4,7 +4,7 @@ use crate::{data_catalog_table::DataCatalogTable, errors::Result}; /// /// It allows registering and retrieving data sources, as well as querying their schemas. 
/// The catalog is used by the query planner to resolve table references in queries. -pub trait DataCatalog: Sync + Send { +pub trait DataCatalog: Sync + Send + std::fmt::Debug { /// Lists the fully-qualified names of tables in the catalog with the specified prefix fn list_tables(&self, prefix: &str) -> Result>; diff --git a/src/daft-catalog/src/lib.rs b/src/daft-catalog/src/lib.rs index 73f75864c8..87faed0c19 100644 --- a/src/daft-catalog/src/lib.rs +++ b/src/daft-catalog/src/lib.rs @@ -19,11 +19,11 @@ pub mod global_catalog { use lazy_static::lazy_static; - use crate::{DaftMetaCatalog, DataCatalog}; + use crate::{DaftCatalog, DataCatalog}; lazy_static! { - pub(crate) static ref GLOBAL_DAFT_META_CATALOG: RwLock = - RwLock::new(DaftMetaCatalog::new_from_env()); + pub(crate) static ref GLOBAL_DAFT_META_CATALOG: RwLock = + RwLock::new(DaftCatalog::new_from_env()); } /// Register a DataCatalog with the global DaftMetaCatalog @@ -50,7 +50,8 @@ static DEFAULT_CATALOG_NAME: &str = "default"; /// /// Users of Daft can register various [`DataCatalog`] with Daft, enabling /// discovery of tables across various [`DataCatalog`] implementations. -pub struct DaftMetaCatalog { +#[derive(Debug, Clone, Default)] +pub struct DaftCatalog { /// Map of catalog names to the DataCatalog impls. /// /// NOTE: The default catalog is always named "default" @@ -60,11 +61,11 @@ pub struct DaftMetaCatalog { named_tables: HashMap, } -impl DaftMetaCatalog { +impl DaftCatalog { /// Create a `DaftMetaCatalog` from the current environment pub fn new_from_env() -> Self { // TODO: Parse a YAML file to produce the catalog - DaftMetaCatalog { + DaftCatalog { data_catalogs: default::Default::default(), named_tables: default::Default::default(), } @@ -95,16 +96,25 @@ impl DaftMetaCatalog { } /// Registers a LogicalPlan with a name in the DaftMetaCatalog - pub fn register_named_table(&mut self, name: &str, view: LogicalPlanBuilder) -> Result<()> { + pub fn register_table( + &mut self, + name: &str, + view: impl Into, + ) -> Result<()> { if !name.chars().all(|c| c.is_alphanumeric() || c == '_') { return Err(Error::InvalidTableName { name: name.to_string(), }); } - self.named_tables.insert(name.to_string(), view); + self.named_tables.insert(name.to_string(), view.into()); Ok(()) } + /// Check if a named table is registered in the DaftCatalog + pub fn contains_table(&self, name: &str) -> bool { + self.named_tables.contains_key(name) + } + /// Provides high-level functionality for reading a table of data against a [`DaftMetaCatalog`] /// /// Resolves the provided table_identifier against the catalog: @@ -146,6 +156,15 @@ impl DaftMetaCatalog { table_id: searched_table_name.to_string(), }) } + /// Copy from another catalog, using tables from other in case of conflict + pub fn copy_from(&mut self, other: &Self) { + for (name, plan) in &other.named_tables { + self.named_tables.insert(name.clone(), plan.clone()); + } + for (name, catalog) in &other.data_catalogs { + self.data_catalogs.insert(name.clone(), catalog.clone()); + } + } } #[cfg(test)] @@ -181,26 +200,24 @@ mod tests { #[test] fn test_register_and_unregister_named_table() { - let mut catalog = DaftMetaCatalog::new_from_env(); + let mut catalog = DaftCatalog::new_from_env(); let plan = LogicalPlanBuilder::from(mock_plan()); // Register a table - assert!(catalog - .register_named_table("test_table", plan.clone()) - .is_ok()); + assert!(catalog.register_table("test_table", plan.clone()).is_ok()); // Try to register a table with invalid name assert!(catalog - 
.register_named_table("invalid name", plan.clone()) + .register_table("invalid name", plan.clone()) .is_err()); } #[test] fn test_read_registered_table() { - let mut catalog = DaftMetaCatalog::new_from_env(); + let mut catalog = DaftCatalog::new_from_env(); let plan = LogicalPlanBuilder::from(mock_plan()); - catalog.register_named_table("test_table", plan).unwrap(); + catalog.register_table("test_table", plan).unwrap(); assert!(catalog.read_table("test_table").is_ok()); assert!(catalog.read_table("non_existent_table").is_err()); diff --git a/src/daft-catalog/src/python.rs b/src/daft-catalog/src/python.rs index a4896402ec..9f4381bd09 100644 --- a/src/daft-catalog/src/python.rs +++ b/src/daft-catalog/src/python.rs @@ -61,7 +61,7 @@ fn py_register_table( global_catalog::GLOBAL_DAFT_META_CATALOG .write() .unwrap() - .register_named_table(table_identifier, logical_plan.builder.clone())?; + .register_table(table_identifier, logical_plan.builder.clone())?; Ok(table_identifier.to_string()) } diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index a72d574677..55c972b219 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,30 +1,51 @@ [dependencies] arrow2 = {workspace = true, features = ["io_json_integration"]} async-stream = "0.3.6" -common-daft-config = {workspace = true} -common-file-formats = {workspace = true} -daft-core = {workspace = true} -daft-dsl = {workspace = true} -daft-local-execution = {workspace = true} -daft-logical-plan = {workspace = true} -daft-micropartition = {workspace = true} -daft-scan = {workspace = true} -daft-schema = {workspace = true} -daft-sql = {workspace = true} -daft-table = {workspace = true} +common-error = {workspace = true, optional = true, features = ["python"]} +common-file-formats = {workspace = true, optional = true, features = ["python"]} +daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]} +daft-core = {workspace = true, optional = true, features = ["python"]} +daft-dsl = {workspace = true, optional = true, features = ["python"]} +daft-local-execution = {workspace = true, optional = true, features = ["python"]} +daft-logical-plan = {workspace = true, optional = true, features = ["python"]} +daft-micropartition = {workspace = true, optional = true, features = ["python"]} +daft-ray-execution = {workspace = true, optional = true, features = ["python"]} +daft-scan = {workspace = true, optional = true, features = ["python"]} +daft-schema = {workspace = true, optional = true, features = ["python"]} +daft-sql = {workspace = true, optional = true, features = ["python"]} +daft-table = {workspace = true, optional = true, features = ["python"]} dashmap = "6.1.0" eyre = "0.6.12" futures = "0.3.31" itertools = {workspace = true} +once_cell = {workspace = true} pyo3 = {workspace = true, optional = true} spark-connect = {workspace = true} +textwrap = "0.16.1" tokio = {version = "1.40.0", features = ["full"]} tonic = "0.12.3" tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} +common-runtime.workspace = true [features] -python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python", "daft-micropartition/python"] +default = ["python"] +python = [ + "dep:pyo3", + "dep:common-error", + "dep:common-file-formats", + "dep:daft-core", + "dep:daft-dsl", + "dep:daft-local-execution", + "dep:daft-logical-plan", + 
"dep:daft-micropartition", + "dep:daft-ray-execution", + "dep:daft-scan", + "dep:daft-schema", + "dep:daft-sql", + "dep:daft-table", + "dep:daft-catalog" +] [lints] workspace = true diff --git a/src/daft-connect/src/config.rs b/src/daft-connect/src/config.rs index b29215e668..3fb5a22fed 100644 --- a/src/daft-connect/src/config.rs +++ b/src/daft-connect/src/config.rs @@ -6,7 +6,7 @@ use spark_connect::{ }; use tonic::Status; -use crate::Session; +use crate::session::Session; impl Session { fn config_response(&self) -> ConfigResponse { diff --git a/src/daft-connect/src/connect_service.rs b/src/daft-connect/src/connect_service.rs new file mode 100644 index 0000000000..6cf907ff73 --- /dev/null +++ b/src/daft-connect/src/connect_service.rs @@ -0,0 +1,303 @@ +use dashmap::DashMap; +use spark_connect::{ + command::CommandType, plan::OpType, spark_connect_service_server::SparkConnectService, + AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, + ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse, + ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse, + InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest, + ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, +}; +use tonic::{Request, Response, Status}; +use tracing::debug; +use uuid::Uuid; + +use crate::{ + display::SparkDisplay, + invalid_argument_err, not_yet_implemented, + response_builder::ResponseBuilder, + session::Session, + spark_analyzer::{to_spark_datatype, SparkAnalyzer}, + util::FromOptionalField, +}; + +#[derive(Default)] +pub struct DaftSparkConnectService { + client_to_session: DashMap, // To track session data +} + +impl DaftSparkConnectService { + fn get_session( + &self, + session_id: &str, + ) -> Result, Status> { + let Ok(uuid) = Uuid::parse_str(session_id) else { + return Err(Status::invalid_argument( + "Invalid session_id format, must be a UUID", + )); + }; + + let res = self + .client_to_session + .entry(uuid) + .or_insert_with(|| Session::new(session_id.to_string())); + + Ok(res) + } +} + +#[tonic::async_trait] +impl SparkConnectService for DaftSparkConnectService { + type ExecutePlanStream = std::pin::Pin< + Box> + Send + 'static>, + >; + type ReattachExecuteStream = std::pin::Pin< + Box> + Send + 'static>, + >; + + #[tracing::instrument(skip_all)] + async fn execute_plan( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + let operation_id = request + .operation_id + .unwrap_or_else(|| Uuid::new_v4().to_string()); + + let rb = ResponseBuilder::new(&session, operation_id); + + // Proceed with executing the plan... 
+ let plan = request.plan.required("plan")?; + let plan = plan.op_type.required("op_type")?; + + match plan { + OpType::Root(relation) => { + let result = session.execute_command(relation, rb).await?; + Ok(Response::new(result)) + } + OpType::Command(command) => { + let command = command.command_type.required("command_type")?; + match command { + CommandType::WriteOperation(op) => { + let result = session.execute_write_operation(op, rb).await?; + Ok(Response::new(result)) + } + CommandType::CreateDataframeView(create_dataframe) => { + let result = session + .execute_create_dataframe_view(create_dataframe, rb) + .await?; + Ok(Response::new(result)) + } + CommandType::SqlCommand(sql) => { + let result = session.execute_sql_command(sql, rb).await?; + Ok(Response::new(result)) + } + other => { + not_yet_implemented!("CommandType '{:?}'", command_type_to_str(&other)) + } + } + } + } + } + + #[tracing::instrument(skip_all)] + async fn config( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let mut session = self.get_session(&request.session_id)?; + + let operation = request + .operation + .and_then(|op| op.op_type) + .required("operation.op_type")?; + + use spark_connect::config_request::operation::OpType; + + let response = match operation { + OpType::Set(op) => session.set(op), + OpType::Get(op) => session.get(op), + OpType::GetWithDefault(op) => session.get_with_default(op), + OpType::GetOption(op) => session.get_option(op), + OpType::GetAll(op) => session.get_all(op), + OpType::Unset(op) => session.unset(op), + OpType::IsModifiable(op) => session.is_modifiable(op), + }?; + + Ok(Response::new(response)) + } + + #[tracing::instrument(skip_all)] + async fn add_artifacts( + &self, + _request: Request>, + ) -> Result, Status> { + not_yet_implemented!("add_artifacts operation") + } + + #[tracing::instrument(skip_all)] + async fn analyze_plan( + &self, + request: Request, + ) -> Result, Status> { + use spark_connect::analyze_plan_request::*; + let request = request.into_inner(); + + let AnalyzePlanRequest { + session_id, + analyze, + .. + } = request; + + let session = self.get_session(&session_id)?; + let rb = ResponseBuilder::new(&session, Uuid::new_v4().to_string()); + + let analyze = analyze.required("analyze")?; + + match analyze { + Analyze::Schema(Schema { plan }) => { + let Plan { op_type } = plan.required("plan")?; + + let OpType::Root(relation) = op_type.required("op_type")? else { + return invalid_argument_err!("op_type must be Root"); + }; + + let translator = SparkAnalyzer::new(&session); + + let result = match translator.relation_to_spark_schema(relation).await { + Ok(schema) => schema, + Err(e) => { + return invalid_argument_err!( + "Failed to translate relation to schema: {e:?}" + ); + } + }; + Ok(Response::new(rb.schema_response(result))) + } + Analyze::DdlParse(DdlParse { ddl_string }) => { + let daft_schema = match daft_sql::sql_schema(&ddl_string) { + Ok(daft_schema) => daft_schema, + Err(e) => return invalid_argument_err!("{e}"), + }; + + let daft_schema = daft_schema.to_struct(); + + let schema = to_spark_datatype(&daft_schema); + + Ok(Response::new(rb.schema_response(schema))) + } + Analyze::TreeString(TreeString { plan, level }) => { + let plan = plan.required("plan")?; + + if let Some(level) = level { + debug!("ignoring tree string level: {level:?}"); + }; + + let OpType::Root(input) = plan.op_type.required("op_type")? 
else { + return invalid_argument_err!("op_type must be Root"); + }; + + if let Some(common) = &input.common { + if common.origin.is_some() { + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + } + } + + let translator = SparkAnalyzer::new(&session); + let plan = Box::pin(translator.to_logical_plan(input)) + .await + .unwrap() + .build(); + + let schema = plan.schema(); + let tree_string = schema.repr_spark_string(); + Ok(Response::new(rb.treestring_response(tree_string))) + } + other => not_yet_implemented!("Analyze '{other:?}'"), + } + } + + #[tracing::instrument(skip_all)] + async fn artifact_status( + &self, + _request: Request, + ) -> Result, Status> { + not_yet_implemented!("artifact_status operation") + } + + #[tracing::instrument(skip_all)] + async fn interrupt( + &self, + _request: Request, + ) -> Result, Status> { + not_yet_implemented!("interrupt operation") + } + + #[tracing::instrument(skip_all)] + async fn reattach_execute( + &self, + _request: Request, + ) -> Result, Status> { + not_yet_implemented!("reattach_execute operation") + } + + #[tracing::instrument(skip_all)] + async fn release_execute( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + + let response = ReleaseExecuteResponse { + session_id: session.client_side_session_id().to_string(), + server_side_session_id: session.server_side_session_id().to_string(), + operation_id: None, // todo: set but not strictly required + }; + + Ok(Response::new(response)) + } + + #[tracing::instrument(skip_all)] + async fn release_session( + &self, + _request: Request, + ) -> Result, Status> { + not_yet_implemented!("release_session operation") + } + + #[tracing::instrument(skip_all)] + async fn fetch_error_details( + &self, + _request: Request, + ) -> Result, Status> { + not_yet_implemented!("fetch_error_details operation") + } +} + +fn command_type_to_str(cmd_type: &CommandType) -> &str { + match cmd_type { + CommandType::RegisterFunction(_) => "RegisterFunction", + CommandType::WriteOperation(_) => "WriteOperation", + CommandType::CreateDataframeView(_) => "CreateDataframeView", + CommandType::WriteOperationV2(_) => "WriteOperationV2", + CommandType::SqlCommand(_) => "SqlCommand", + CommandType::WriteStreamOperationStart(_) => "WriteStreamOperationStart", + CommandType::StreamingQueryCommand(_) => "StreamingQueryCommand", + CommandType::GetResourcesCommand(_) => "GetResourcesCommand", + CommandType::StreamingQueryManagerCommand(_) => "StreamingQueryManagerCommand", + CommandType::RegisterTableFunction(_) => "RegisterTableFunction", + CommandType::StreamingQueryListenerBusCommand(_) => "StreamingQueryListenerBusCommand", + CommandType::RegisterDataSource(_) => "RegisterDataSource", + CommandType::CreateResourceProfileCommand(_) => "CreateResourceProfileCommand", + CommandType::CheckpointCommand(_) => "CheckpointCommand", + CommandType::RemoveCachedRemoteRelationCommand(_) => "RemoveCachedRemoteRelationCommand", + CommandType::MergeIntoTableCommand(_) => "MergeIntoTableCommand", + CommandType::Extension(_) => "Extension", + } +} diff --git a/src/daft-connect/src/display.rs b/src/daft-connect/src/display.rs index 83fce57fb5..8f80402997 100644 --- a/src/daft-connect/src/display.rs +++ b/src/daft-connect/src/display.rs @@ -114,7 +114,6 @@ fn type_to_string(dtype: &DataType) -> String { DataType::FixedShapeTensor(_, _) => "daft.fixed_shape_tensor".to_string(), DataType::SparseTensor(_) => 
"daft.sparse_tensor".to_string(), DataType::FixedShapeSparseTensor(_, _) => "daft.fixed_shape_sparse_tensor".to_string(), - #[cfg(feature = "python")] DataType::Python => "daft.python".to_string(), DataType::Unknown => "unknown".to_string(), DataType::UInt8 => "arrow.uint8".to_string(), diff --git a/src/daft-connect/src/err.rs b/src/daft-connect/src/err.rs index 4e0377912c..f9b7cf8b77 100644 --- a/src/daft-connect/src/err.rs +++ b/src/daft-connect/src/err.rs @@ -7,9 +7,11 @@ macro_rules! invalid_argument_err { } #[macro_export] -macro_rules! unimplemented_err { - ($arg: tt) => {{ - let msg = format!($arg); +macro_rules! not_yet_implemented { + ($($arg:tt)*) => {{ + let msg = format!($($arg)*); + let msg = format!(r#"Feature: {msg} is not yet implemented, please open an issue at https://github.com/Eventual-Inc/Daft/issues/new?assignees=&labels=enhancement%2Cneeds+triage&projects=&template=feature_request.yaml"#); + Err(::tonic::Status::unimplemented(msg)) }}; } diff --git a/src/daft-connect/src/execute.rs b/src/daft-connect/src/execute.rs new file mode 100644 index 0000000000..d6b443475b --- /dev/null +++ b/src/daft-connect/src/execute.rs @@ -0,0 +1,412 @@ +use std::{future::ready, sync::Arc}; + +use common_error::DaftResult; +use common_file_formats::FileFormat; +use daft_dsl::LiteralValue; +use daft_logical_plan::LogicalPlanBuilder; +use daft_micropartition::MicroPartition; +use daft_ray_execution::RayEngine; +use daft_table::Table; +use eyre::{bail, Context}; +use futures::{ + stream::{self, BoxStream}, + StreamExt, TryStreamExt, +}; +use pyo3::Python; +use spark_connect::{ + relation::RelType, + write_operation::{SaveMode, SaveType}, + CreateDataFrameViewCommand, ExecutePlanResponse, Relation, ShowString, SqlCommand, + WriteOperation, +}; +use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; +use tracing::debug; + +use crate::{ + not_yet_implemented, response_builder::ResponseBuilder, session::Session, + spark_analyzer::SparkAnalyzer, util::FromOptionalField, ExecuteStream, Runner, +}; + +impl Session { + pub fn get_runner(&self) -> eyre::Result { + let runner = match self.config_values().get("daft.runner") { + Some(runner) => match runner.as_str() { + "ray" => Runner::Ray, + "native" => Runner::Native, + _ => bail!("Invalid runner: {}", runner), + }, + None => Runner::Native, + }; + Ok(runner) + } + + pub async fn run_query( + &self, + lp: LogicalPlanBuilder, + ) -> eyre::Result>>> { + match self.get_runner()? 
{ + Runner::Ray => { + let runner_address = self.config_values().get("daft.runner.ray.address"); + let runner_address = runner_address.map(|s| s.to_string()); + + let runner = RayEngine::try_new(runner_address, None, None)?; + let result_set = tokio::task::spawn_blocking(move || { + Python::with_gil(|py| runner.run_iter_impl(py, lp, None)) + }) + .await??; + + Ok(Box::pin(stream::iter(result_set))) + } + + Runner::Native => { + let this = self.clone(); + + let plan = lp.optimize_async().await?; + + let results = this + .engine + .run(&plan, &*this.psets, Default::default(), None)?; + Ok(results.into_stream().boxed()) + } + } + } + + pub async fn execute_command( + &self, + command: Relation, + res: ResponseBuilder, + ) -> Result { + use futures::{StreamExt, TryStreamExt}; + + let result_complete = res.result_complete_response(); + + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + + let this = self.clone(); + self.compute_runtime.runtime.spawn(async move { + let execution_fut = async { + let translator = SparkAnalyzer::new(&this); + match command.rel_type { + Some(RelType::ShowString(ss)) => { + let response = this.show_string(*ss, res.clone()).await?; + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } + + Ok(()) + } + _ => { + let lp = translator.to_logical_plan(command).await?; + + let mut result_stream = this.run_query(lp).await?; + + while let Some(result) = result_stream.next().await { + let result = result?; + let tables = result.get_tables()?; + for table in tables.as_slice() { + let response = res.arrow_batch_response(table)?; + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } + } + } + Ok(()) + } + } + }; + if let Err(e) = execution_fut.await { + let _ = tx.send(Err(e)).await; + } + }); + + let stream = ReceiverStream::new(rx); + + let stream = stream + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + }) + .chain(stream::once(ready(Ok(result_complete)))); + + Ok(Box::pin(stream)) + } + + pub async fn execute_write_operation( + &self, + operation: WriteOperation, + res: ResponseBuilder, + ) -> Result { + fn check_write_operation(write_op: &WriteOperation) -> Result<(), Status> { + if !write_op.sort_column_names.is_empty() { + return not_yet_implemented!("Sort with column names"); + } + if !write_op.partitioning_columns.is_empty() { + return not_yet_implemented!("Partitioning with column names"); + } + if !write_op.clustering_columns.is_empty() { + return not_yet_implemented!("Clustering with column names"); + } + + if let Some(bucket_by) = &write_op.bucket_by { + return not_yet_implemented!("Bucketing by: {:?}", bucket_by); + } + + if !write_op.options.is_empty() { + // todo(completeness): implement options + debug!( + "Ignoring options: {:?} (not yet implemented)", + write_op.options + ); + } + + let mode = SaveMode::try_from(write_op.mode) + .map_err(|_| Status::internal("invalid write mode"))?; + + if mode == SaveMode::Unspecified { + Ok(()) + } else { + not_yet_implemented!("save mode: {}", mode.as_str_name()) + } + } + + let finished = res.result_complete_response(); + + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + + let this = self.clone(); + + self.compute_runtime.runtime.spawn(async move { + let result = async { + check_write_operation(&operation)?; + + let WriteOperation { + input, + source, + save_type, + .. 
+ } = operation; + + let input = input.required("input")?; + let source = source.required("source")?; + + let file_format: FileFormat = source.parse()?; + + let Some(save_type) = save_type else { + bail!("Save type is required"); + }; + + let path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(_) => { + return not_yet_implemented!("write to table").map_err(|e| e.into()) + } + }; + + let translator = SparkAnalyzer::new(&this); + + let plan = translator.to_logical_plan(input).await?; + + let plan = plan.table_write(&path, file_format, None, None, None)?; + + let mut result_stream = this.run_query(plan).await?; + + // this is so we make sure the operation is actually done + // before we return + // + // an example where this is important is if we write to a parquet file + // and then read immediately after, we need to wait for the write to finish + while let Some(_result) = result_stream.next().await {} + + Ok(()) + }; + + if let Err(e) = result.await { + let _ = tx.send(Err(e)).await; + } + }); + let stream = ReceiverStream::new(rx); + + let stream = stream + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + }) + .chain(stream::once(ready(Ok(finished)))); + + Ok(Box::pin(stream)) + } + + pub async fn execute_create_dataframe_view( + &self, + create_dataframe: CreateDataFrameViewCommand, + rb: ResponseBuilder, + ) -> Result { + let CreateDataFrameViewCommand { + input, + name, + is_global, + replace, + } = create_dataframe; + + if is_global { + return not_yet_implemented!("Global dataframe view"); + } + + let input = input.required("input")?; + let input = SparkAnalyzer::new(self) + .to_logical_plan(input) + .await + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + })?; + + { + let catalog = self.catalog.read().unwrap(); + if !replace && catalog.contains_table(&name) { + return Err(Status::internal("Dataframe view already exists")); + } + } + + let mut catalog = self.catalog.write().unwrap(); + + catalog.register_table(&name, input).map_err(|e| { + Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n")) + })?; + + let response = rb.result_complete_response(); + let stream = stream::once(ready(Ok(response))); + Ok(Box::pin(stream)) + } + + #[allow(deprecated)] + pub async fn execute_sql_command( + &self, + SqlCommand { + sql, + args, + pos_args, + named_arguments, + pos_arguments, + input, + }: SqlCommand, + res: ResponseBuilder, + ) -> Result { + if !args.is_empty() { + return not_yet_implemented!("Named arguments"); + } + if !pos_args.is_empty() { + return not_yet_implemented!("Positional arguments"); + } + if !named_arguments.is_empty() { + return not_yet_implemented!("Named arguments"); + } + if !pos_arguments.is_empty() { + return not_yet_implemented!("Positional arguments"); + } + + if input.is_some() { + return not_yet_implemented!("Input"); + } + + let catalog = self.catalog.read().unwrap(); + let catalog = catalog.clone(); + + let mut planner = daft_sql::SQLPlanner::new(catalog); + + let plan = planner + .plan_sql(&sql) + .wrap_err("Error planning SQL") + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + })?; + + let plan = LogicalPlanBuilder::from(plan); + + // TODO: code duplication + let result_complete = res.result_complete_response(); + + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + + let this = self.clone(); + + 
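// As in `execute_command` above, results are produced on a background task and forwarded
// over a bounded mpsc channel of capacity 1, so the producer is backpressured until the
// client has consumed each `ExecutePlanResponse`; the client-facing stream is the
// `ReceiverStream` chained with a final `ResultComplete` message.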
tokio::spawn(async move { + let execution_fut = async { + let mut result_stream = this.run_query(plan).await?; + while let Some(result) = result_stream.next().await { + let result = result?; + let tables = result.get_tables()?; + for table in tables.as_slice() { + let response = res.arrow_batch_response(table)?; + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } + } + } + Ok(()) + }; + if let Err(e) = execution_fut.await { + let _ = tx.send(Err(e)).await; + } + }); + + let stream = ReceiverStream::new(rx); + + let stream = stream + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + }) + .chain(stream::once(ready(Ok(result_complete)))); + + Ok(Box::pin(stream)) + } + + async fn show_string( + &self, + show_string: ShowString, + response_builder: ResponseBuilder, + ) -> eyre::Result { + let translator = SparkAnalyzer::new(self); + + let ShowString { + input, + num_rows, + truncate: _, + vertical, + } = show_string; + + if vertical { + bail!("Vertical show string is not supported"); + } + + let input = input.required("input")?; + + let plan = Box::pin(translator.to_logical_plan(*input)).await?; + let plan = plan.limit(num_rows as i64, true)?; + + let results = translator.session.run_query(plan).await?; + let results = results.try_collect::>().await?; + let single_batch = results + .into_iter() + .next() + .ok_or_else(|| eyre::eyre!("No results"))?; + + let tbls = single_batch.get_tables()?; + let tbl = Table::concat(&tbls)?; + let output = tbl.to_comfy_table(None).to_string(); + + let s = LiteralValue::Utf8(output) + .into_single_value_series()? + .rename("show_string"); + + let tbl = Table::from_nonempty_columns(vec![s])?; + response_builder.arrow_batch_response(&tbl) + } +} diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs new file mode 100644 index 0000000000..053e686265 --- /dev/null +++ b/src/daft-connect/src/functions.rs @@ -0,0 +1,55 @@ +use std::{collections::HashMap, sync::Arc}; + +use once_cell::sync::Lazy; +use spark_connect::Expression; + +use crate::spark_analyzer::SparkAnalyzer; +mod core; + +pub(crate) static CONNECT_FUNCTIONS: Lazy = Lazy::new(|| { + let mut functions = SparkFunctions::new(); + functions.register::(); + functions +}); + +pub trait SparkFunction: Send + Sync { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> eyre::Result; +} + +pub struct SparkFunctions { + pub(crate) map: HashMap>, +} + +impl SparkFunctions { + /// Create a new [SparkFunction] instance. + #[must_use] + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + + /// Register the module to the [SparkFunctions] instance. + pub fn register(&mut self) { + M::register(self); + } + /// Add a [FunctionExpr] to the [SparkFunction] instance. + pub fn add_fn(&mut self, name: &str, func: F) { + self.map.insert(name.to_string(), Arc::new(func)); + } + + /// Get a function by name from the [SparkFunctions] instance. + #[must_use] + pub fn get(&self, name: &str) -> Option<&Arc> { + self.map.get(name) + } +} + +pub trait FunctionModule { + /// Register this module to the given [SparkFunctions] table. 
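/// For example, `CoreFunctions` in `functions/core.rs` registers the binary operators
/// (`+`, `==`, ...) and simple unary functions (`isnull`, `sum`, ...) via
/// `SparkFunctions::add_fn`.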
+ fn register(_parent: &mut SparkFunctions); +} diff --git a/src/daft-connect/src/functions/core.rs b/src/daft-connect/src/functions/core.rs new file mode 100644 index 0000000000..61f29d7fde --- /dev/null +++ b/src/daft-connect/src/functions/core.rs @@ -0,0 +1,105 @@ +use daft_core::count_mode::CountMode; +use daft_dsl::{binary_op, col, ExprRef, Operator}; +use daft_schema::dtype::DataType; +use spark_connect::Expression; + +use super::{FunctionModule, SparkFunction}; +use crate::{invalid_argument_err, spark_analyzer::SparkAnalyzer}; + +// Core functions are the most basic functions such as `+`, `-`, `*`, `/`, not, notnull, etc. +pub struct CoreFunctions; + +impl FunctionModule for CoreFunctions { + fn register(parent: &mut super::SparkFunctions) { + parent.add_fn("==", BinaryOpFunction(Operator::Eq)); + parent.add_fn("!=", BinaryOpFunction(Operator::NotEq)); + parent.add_fn("<", BinaryOpFunction(Operator::Lt)); + parent.add_fn("<=", BinaryOpFunction(Operator::LtEq)); + parent.add_fn(">", BinaryOpFunction(Operator::Gt)); + parent.add_fn(">=", BinaryOpFunction(Operator::GtEq)); + parent.add_fn("+", BinaryOpFunction(Operator::Plus)); + parent.add_fn("-", BinaryOpFunction(Operator::Minus)); + parent.add_fn("*", BinaryOpFunction(Operator::Multiply)); + parent.add_fn("/", BinaryOpFunction(Operator::TrueDivide)); + parent.add_fn("//", BinaryOpFunction(Operator::FloorDivide)); + parent.add_fn("%", BinaryOpFunction(Operator::Modulus)); + parent.add_fn("&", BinaryOpFunction(Operator::And)); + parent.add_fn("|", BinaryOpFunction(Operator::Or)); + parent.add_fn("^", BinaryOpFunction(Operator::Xor)); + parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft)); + parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight)); + parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null())); + parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null())); + parent.add_fn("not", UnaryFunction(|arg| arg.not())); + parent.add_fn("sum", UnaryFunction(|arg| arg.sum())); + parent.add_fn("mean", UnaryFunction(|arg| arg.mean())); + parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev())); + parent.add_fn("min", UnaryFunction(|arg| arg.min())); + parent.add_fn("max", UnaryFunction(|arg| arg.max())); + parent.add_fn("count", CountFunction); + } +} + +pub struct BinaryOpFunction(Operator); +pub struct UnaryFunction(fn(ExprRef) -> ExprRef); +pub struct CountFunction; + +impl SparkFunction for BinaryOpFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> eyre::Result { + let args = args + .iter() + .map(|arg| analyzer.to_daft_expr(arg)) + .collect::>>()?; + + let [lhs, rhs] = args + .try_into() + .map_err(|args| eyre::eyre!("requires exactly two arguments; got {:?}", args))?; + + Ok(binary_op(self.0, lhs, rhs)) + } +} + +impl SparkFunction for UnaryFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> eyre::Result { + match args { + [arg] => { + let arg = analyzer.to_daft_expr(arg)?; + Ok(self.0(arg)) + } + _ => invalid_argument_err!("requires exactly one argument")?, + } + } +} + +impl SparkFunction for CountFunction { + fn to_expr( + &self, + args: &[Expression], + analyzer: &SparkAnalyzer, + ) -> eyre::Result { + match args { + [arg] => { + let arg = analyzer.to_daft_expr(arg)?; + + let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) { + col("*") + } else { + arg + }; + + let count = arg.count(CountMode::All).cast(&DataType::Int64); + + Ok(count) + } + _ => invalid_argument_err!("requires exactly one argument")?, + } + 
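// Note: Spark clients typically send `count(1)` with a literal `1` argument; the branch
// above rewrites it to `col("*")` so that all rows are counted, and the result is cast to
// Int64 to match Spark's LongType.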
} +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index bd55024825..23a182a271 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -6,37 +6,46 @@ #![feature(stmt_expr_attributes)] #![feature(try_trait_v2_residual)] -use daft_micropartition::partitioning::InMemoryPartitionSetCache; -use dashmap::DashMap; -use eyre::Context; #[cfg(feature = "python")] -use pyo3::types::PyModuleMethods; -use spark_connect::{ - analyze_plan_response, - command::CommandType, - plan::OpType, - spark_connect_service_server::{SparkConnectService, SparkConnectServiceServer}, - AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, - ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse, - ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse, - InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest, - ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, -}; -use tonic::{transport::Server, Request, Response, Status}; -use tracing::{info, warn}; -use uuid::Uuid; +mod config; + +#[cfg(feature = "python")] +mod connect_service; -use crate::{display::SparkDisplay, session::Session, translation::SparkAnalyzer}; +#[cfg(feature = "python")] +mod functions; -mod config; +#[cfg(feature = "python")] mod display; +#[cfg(feature = "python")] mod err; -mod op; - +#[cfg(feature = "python")] +mod execute; +#[cfg(feature = "python")] +mod response_builder; +#[cfg(feature = "python")] mod session; -mod translation; +#[cfg(feature = "python")] +mod spark_analyzer; +#[cfg(feature = "python")] pub mod util; +#[cfg(feature = "python")] +use connect_service::DaftSparkConnectService; +#[cfg(feature = "python")] +use eyre::Context; +#[cfg(feature = "python")] +use pyo3::types::PyModuleMethods; +#[cfg(feature = "python")] +use spark_connect::spark_connect_service_server::{SparkConnectService, SparkConnectServiceServer}; +#[cfg(feature = "python")] +use tonic::transport::Server; +#[cfg(feature = "python")] +use tracing::info; + +#[cfg(feature = "python")] +pub type ExecuteStream = ::ExecutePlanStream; + #[cfg_attr(feature = "python", pyo3::pyclass)] pub struct ConnectionHandle { shutdown_signal: Option>, @@ -57,6 +66,7 @@ impl ConnectionHandle { } } +#[cfg(feature = "python")] pub fn start(addr: &str) -> eyre::Result { info!("Daft-Connect server listening on {addr}"); let addr = util::parse_spark_connect_address(addr)?; @@ -74,10 +84,10 @@ pub fn start(addr: &str) -> eyre::Result { shutdown_signal: Some(shutdown_signal), port, }; + let runtime = common_runtime::get_io_runtime(true); std::thread::spawn(move || { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let result = runtime.block_on(async { + let result = runtime.block_on_current_thread(async { let incoming = { let listener = tokio::net::TcpListener::from_std(listener) .wrap_err("Failed to create TcpListener from std::net::TcpListener")?; @@ -117,346 +127,14 @@ pub fn start(addr: &str) -> eyre::Result { Ok(handle) } -#[derive(Default)] -pub struct DaftSparkConnectService { - client_to_session: DashMap, // To track session data -} - -impl DaftSparkConnectService { - fn get_session( - &self, - session_id: &str, - ) -> Result, Status> { - let Ok(uuid) = Uuid::parse_str(session_id) else { - return Err(Status::invalid_argument( - "Invalid session_id format, must be a UUID", - )); - }; - - let res = self - .client_to_session - .entry(uuid) - .or_insert_with(|| 
Session::new(session_id.to_string())); - - Ok(res) - } -} - -#[tonic::async_trait] -impl SparkConnectService for DaftSparkConnectService { - type ExecutePlanStream = std::pin::Pin< - Box> + Send + 'static>, - >; - type ReattachExecuteStream = std::pin::Pin< - Box> + Send + 'static>, - >; - - #[tracing::instrument(skip_all)] - async fn execute_plan( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - - let session = self.get_session(&request.session_id)?; - - let Some(operation) = request.operation_id else { - return invalid_argument_err!("Operation ID is required"); - }; - - // Proceed with executing the plan... - let Some(plan) = request.plan else { - return invalid_argument_err!("Plan is required"); - }; - - let Some(plan) = plan.op_type else { - return invalid_argument_err!("Plan operation is required"); - }; - - use spark_connect::plan::OpType; - - match plan { - OpType::Root(relation) => { - let result = session.handle_root_command(relation, operation).await?; - return Ok(Response::new(result)); - } - OpType::Command(command) => { - let Some(command) = command.command_type else { - return invalid_argument_err!("Command type is required"); - }; - - match command { - CommandType::RegisterFunction(_) => { - unimplemented_err!("RegisterFunction not implemented") - } - CommandType::WriteOperation(op) => { - let result = session.handle_write_command(op, operation).await?; - return Ok(Response::new(result)); - } - CommandType::CreateDataframeView(_) => { - unimplemented_err!("CreateDataframeView not implemented") - } - CommandType::WriteOperationV2(_) => { - unimplemented_err!("WriteOperationV2 not implemented") - } - CommandType::SqlCommand(..) => { - unimplemented_err!("SQL execution not yet implemented") - } - CommandType::WriteStreamOperationStart(_) => { - unimplemented_err!("WriteStreamOperationStart not implemented") - } - CommandType::StreamingQueryCommand(_) => { - unimplemented_err!("StreamingQueryCommand not implemented") - } - CommandType::GetResourcesCommand(_) => { - unimplemented_err!("GetResourcesCommand not implemented") - } - CommandType::StreamingQueryManagerCommand(_) => { - unimplemented_err!("StreamingQueryManagerCommand not implemented") - } - CommandType::RegisterTableFunction(_) => { - unimplemented_err!("RegisterTableFunction not implemented") - } - CommandType::StreamingQueryListenerBusCommand(_) => { - unimplemented_err!("StreamingQueryListenerBusCommand not implemented") - } - CommandType::RegisterDataSource(_) => { - unimplemented_err!("RegisterDataSource not implemented") - } - CommandType::CreateResourceProfileCommand(_) => { - unimplemented_err!("CreateResourceProfileCommand not implemented") - } - CommandType::CheckpointCommand(_) => { - unimplemented_err!("CheckpointCommand not implemented") - } - CommandType::RemoveCachedRemoteRelationCommand(_) => { - unimplemented_err!("RemoveCachedRemoteRelationCommand not implemented") - } - CommandType::MergeIntoTableCommand(_) => { - unimplemented_err!("MergeIntoTableCommand not implemented") - } - CommandType::Extension(_) => unimplemented_err!("Extension not implemented"), - } - } - }? 
- } - - #[tracing::instrument(skip_all)] - async fn config( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - - let mut session = self.get_session(&request.session_id)?; - - let Some(operation) = request.operation.and_then(|op| op.op_type) else { - return Err(Status::invalid_argument("Missing operation")); - }; - - use spark_connect::config_request::operation::OpType; - - let response = match operation { - OpType::Set(op) => session.set(op), - OpType::Get(op) => session.get(op), - OpType::GetWithDefault(op) => session.get_with_default(op), - OpType::GetOption(op) => session.get_option(op), - OpType::GetAll(op) => session.get_all(op), - OpType::Unset(op) => session.unset(op), - OpType::IsModifiable(op) => session.is_modifiable(op), - }?; - - Ok(Response::new(response)) - } - - #[tracing::instrument(skip_all)] - async fn add_artifacts( - &self, - _request: Request>, - ) -> Result, Status> { - unimplemented_err!("add_artifacts operation is not yet implemented") - } - - #[tracing::instrument(skip_all)] - async fn analyze_plan( - &self, - request: Request, - ) -> Result, Status> { - use spark_connect::analyze_plan_request::*; - let request = request.into_inner(); - - let AnalyzePlanRequest { - session_id, - analyze, - .. - } = request; - - let Some(analyze) = analyze else { - return Err(Status::invalid_argument("analyze is required")); - }; - - match analyze { - Analyze::Schema(Schema { plan }) => { - let Some(Plan { op_type }) = plan else { - return Err(Status::invalid_argument("plan is required")); - }; - - let Some(OpType::Root(relation)) = op_type else { - return Err(Status::invalid_argument("op_type is required to be root")); - }; - - let result = match translation::relation_to_spark_schema(relation).await { - Ok(schema) => schema, - Err(e) => { - return invalid_argument_err!( - "Failed to translate relation to schema: {e:?}" - ); - } - }; - - let schema = analyze_plan_response::Schema { - schema: Some(result), - }; - - let response = AnalyzePlanResponse { - session_id, - server_side_session_id: String::new(), - result: Some(analyze_plan_response::Result::Schema(schema)), - }; - - Ok(Response::new(response)) - } - Analyze::DdlParse(DdlParse { ddl_string }) => { - let daft_schema = match daft_sql::sql_schema(&ddl_string) { - Ok(daft_schema) => daft_schema, - Err(e) => return invalid_argument_err!("{e}"), - }; - - let daft_schema = daft_schema.to_struct(); - - let schema = translation::to_spark_datatype(&daft_schema); - - let schema = analyze_plan_response::Schema { - schema: Some(schema), - }; - - let response = AnalyzePlanResponse { - session_id, - server_side_session_id: String::new(), - result: Some(analyze_plan_response::Result::Schema(schema)), - }; - - Ok(Response::new(response)) - } - Analyze::TreeString(TreeString { plan, level }) => { - let Some(plan) = plan else { - return invalid_argument_err!("plan is required"); - }; - - if let Some(level) = level { - warn!("ignoring tree string level: {level:?}"); - }; - - let Some(op_type) = plan.op_type else { - return invalid_argument_err!("op_type is required"); - }; - - let OpType::Root(input) = op_type else { - return invalid_argument_err!("op_type must be Root"); - }; - - if let Some(common) = &input.common { - if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); - } - } - - // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used - let pset = InMemoryPartitionSetCache::empty(); - let 
translator = SparkAnalyzer::new(&pset); - let plan = Box::pin(translator.to_logical_plan(input)) - .await - .unwrap() - .build(); - - let schema = plan.schema(); - let tree_string = schema.repr_spark_string(); - - let response = AnalyzePlanResponse { - session_id, - server_side_session_id: String::new(), - result: Some(analyze_plan_response::Result::TreeString( - analyze_plan_response::TreeString { tree_string }, - )), - }; - - Ok(Response::new(response)) - } - other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"), - } - } - - #[tracing::instrument(skip_all)] - async fn artifact_status( - &self, - _request: Request, - ) -> Result, Status> { - unimplemented_err!("artifact_status operation is not yet implemented") - } - - #[tracing::instrument(skip_all)] - async fn interrupt( - &self, - _request: Request, - ) -> Result, Status> { - unimplemented_err!("interrupt operation is not yet implemented") - } - - #[tracing::instrument(skip_all)] - async fn reattach_execute( - &self, - _request: Request, - ) -> Result, Status> { - unimplemented_err!("reattach_execute operation is not yet implemented") - } - - #[tracing::instrument(skip_all)] - async fn release_execute( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - - let session = self.get_session(&request.session_id)?; - - let response = ReleaseExecuteResponse { - session_id: session.client_side_session_id().to_string(), - server_side_session_id: session.server_side_session_id().to_string(), - operation_id: None, // todo: set but not strictly required - }; - - Ok(Response::new(response)) - } - - #[tracing::instrument(skip_all)] - async fn release_session( - &self, - _request: Request, - ) -> Result, Status> { - unimplemented_err!("release_session operation is not yet implemented") - } - - #[tracing::instrument(skip_all)] - async fn fetch_error_details( - &self, - _request: Request, - ) -> Result, Status> { - unimplemented_err!("fetch_error_details operation is not yet implemented") - } +#[cfg(feature = "python")] +pub enum Runner { + Ray, + Native, } #[cfg(feature = "python")] -#[pyo3::pyfunction] +#[cfg_attr(feature = "python", pyo3::pyfunction)] #[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))] pub fn py_connect_start(addr: &str) -> pyo3::PyResult { start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}"))) diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs deleted file mode 100644 index 2e8bdddf98..0000000000 --- a/src/daft-connect/src/op.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod execute; diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs deleted file mode 100644 index 41baf88b09..0000000000 --- a/src/daft-connect/src/op/execute.rs +++ /dev/null @@ -1,96 +0,0 @@ -use arrow2::io::ipc::write::StreamWriter; -use daft_table::Table; -use eyre::Context; -use spark_connect::{ - execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, - spark_connect_service_server::SparkConnectService, - ExecutePlanResponse, -}; -use uuid::Uuid; - -use crate::{DaftSparkConnectService, Session}; - -mod root; -mod write; - -pub type ExecuteStream = ::ExecutePlanStream; - -pub struct PlanIds { - session: String, - server_side_session: String, - operation: String, -} - -impl PlanIds { - pub fn new( - client_side_session_id: impl Into, - server_side_session_id: impl Into, - ) -> Self { - let client_side_session_id = client_side_session_id.into(); - let server_side_session_id = 
server_side_session_id.into(); - Self { - session: client_side_session_id, - server_side_session: server_side_session_id, - operation: Uuid::new_v4().to_string(), - } - } - - pub fn finished(&self) -> ExecutePlanResponse { - ExecutePlanResponse { - session_id: self.session.to_string(), - server_side_session_id: self.server_side_session.to_string(), - operation_id: self.operation.to_string(), - response_id: Uuid::new_v4().to_string(), - metrics: None, - observed_metrics: vec![], - schema: None, - response_type: Some(ResponseType::ResultComplete(ResultComplete {})), - } - } - - pub fn gen_response(&self, table: &Table) -> eyre::Result { - let mut data = Vec::new(); - - let mut writer = StreamWriter::new( - &mut data, - arrow2::io::ipc::write::WriteOptions { compression: None }, - ); - - let row_count = table.num_rows(); - - let schema = table - .schema - .to_arrow() - .wrap_err("Failed to convert Daft schema to Arrow schema")?; - - writer - .start(&schema, None) - .wrap_err("Failed to start Arrow stream writer with schema")?; - - let arrays = table.get_inner_arrow_arrays().collect(); - let chunk = arrow2::chunk::Chunk::new(arrays); - - writer - .write(&chunk, None) - .wrap_err("Failed to write Arrow chunk to stream writer")?; - - let response = ExecutePlanResponse { - session_id: self.session.to_string(), - server_side_session_id: self.server_side_session.to_string(), - operation_id: self.operation.to_string(), - response_id: Uuid::new_v4().to_string(), // todo: implement this - metrics: None, // todo: implement this - observed_metrics: vec![], - schema: None, - response_type: Some(ResponseType::ArrowBatch(ArrowBatch { - row_count: row_count as i64, - data, - start_offset: None, - })), - }; - - Ok(response) - } -} - -impl Session {} diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs deleted file mode 100644 index ab11a87c27..0000000000 --- a/src/daft-connect/src/op/execute/root.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::{future::ready, pin::pin, sync::Arc}; - -use common_daft_config::DaftExecutionConfig; -use daft_local_execution::NativeExecutor; -use futures::stream; -use spark_connect::{ExecutePlanResponse, Relation}; -use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; - -use crate::{ - op::execute::{ExecuteStream, PlanIds}, - session::Session, - translation, -}; - -impl Session { - pub async fn handle_root_command( - &self, - command: Relation, - operation_id: String, - ) -> Result { - use futures::{StreamExt, TryStreamExt}; - - let context = PlanIds { - session: self.client_side_session_id().to_string(), - server_side_session: self.server_side_session_id().to_string(), - operation: operation_id, - }; - - let finished = context.finished(); - - let (tx, rx) = tokio::sync::mpsc::channel::>(1); - - let pset = self.psets.clone(); - - tokio::spawn(async move { - let execution_fut = async { - let translator = translation::SparkAnalyzer::new(&pset); - let lp = translator.to_logical_plan(command).await?; - - // todo: convert optimize to async (looks like A LOT of work)... 
it touches a lot of API - // I tried and spent about an hour and gave up ~ Andrew Gazelka 🪦 2024-12-09 - let optimized_plan = tokio::task::spawn_blocking(move || lp.optimize()) - .await - .unwrap()?; - - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - - let mut result_stream = pin!(native_executor.run(&pset, cfg, None)?.into_stream()); - - while let Some(result) = result_stream.next().await { - let result = result?; - let tables = result.get_tables()?; - for table in tables.as_slice() { - let response = context.gen_response(table)?; - if tx.send(Ok(response)).await.is_err() { - return Ok(()); - } - } - } - Ok(()) - }; - - if let Err(e) = execution_fut.await { - let _ = tx.send(Err(e)).await; - } - }); - - let stream = ReceiverStream::new(rx); - - let stream = stream - .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))) - .chain(stream::once(ready(Ok(finished)))); - - Ok(Box::pin(stream)) - } -} diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs deleted file mode 100644 index 257a3f4ba7..0000000000 --- a/src/daft-connect/src/op/execute/write.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::{future::ready, pin::pin}; - -use common_daft_config::DaftExecutionConfig; -use common_file_formats::FileFormat; -use daft_local_execution::NativeExecutor; -use eyre::{bail, WrapErr}; -use spark_connect::{ - write_operation::{SaveMode, SaveType}, - WriteOperation, -}; -use tonic::Status; -use tracing::warn; - -use crate::{ - op::execute::{ExecuteStream, PlanIds}, - session::Session, - translation, -}; - -impl Session { - pub async fn handle_write_command( - &self, - operation: WriteOperation, - operation_id: String, - ) -> Result { - use futures::StreamExt; - - let context = PlanIds { - session: self.client_side_session_id().to_string(), - server_side_session: self.server_side_session_id().to_string(), - operation: operation_id, - }; - - let finished = context.finished(); - let pset = self.psets.clone(); - - let result = async move { - let WriteOperation { - input, - source, - mode, - sort_column_names, - partitioning_columns, - bucket_by, - options, - clustering_columns, - save_type, - } = operation; - - let Some(input) = input else { - bail!("Input is required"); - }; - - let Some(source) = source else { - bail!("Source is required"); - }; - - let file_format: FileFormat = source.parse()?; - - let Ok(mode) = SaveMode::try_from(mode) else { - bail!("Invalid save mode: {mode}"); - }; - - if !sort_column_names.is_empty() { - // todo(completeness): implement sort - warn!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)"); - } - - if !partitioning_columns.is_empty() { - // todo(completeness): implement partitioning - warn!( - "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)" - ); - } - - if let Some(bucket_by) = bucket_by { - // todo(completeness): implement bucketing - warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)"); - } - - if !options.is_empty() { - // todo(completeness): implement options - warn!("Ignoring options: {options:?} (not yet implemented)"); - } - - if !clustering_columns.is_empty() { - // todo(completeness): implement clustering - warn!("Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)"); - } - - match mode { - SaveMode::Unspecified => {} - SaveMode::Append => {} - SaveMode::Overwrite => {} - SaveMode::ErrorIfExists => {} - SaveMode::Ignore => {} - } 
- - let Some(save_type) = save_type else { - bail!("Save type is required"); - }; - - let path = match save_type { - SaveType::Path(path) => path, - SaveType::Table(table) => { - let name = table.table_name; - bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead."); - } - }; - - let translator = translation::SparkAnalyzer::new(&pset); - - let plan = translator.to_logical_plan(input).await?; - - let plan = plan - .table_write(&path, file_format, None, None, None) - .wrap_err("Failed to create table write plan")?; - - let optimized_plan = plan.optimize()?; - let cfg = DaftExecutionConfig::default(); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - - let mut result_stream = - pin!(native_executor.run(&pset, cfg.into(), None)?.into_stream()); - - // this is so we make sure the operation is actually done - // before we return - // - // an example where this is important is if we write to a parquet file - // and then read immediately after, we need to wait for the write to finish - while let Some(_result) = result_stream.next().await {} - - Ok(()) - }; - - use futures::TryFutureExt; - - let result = result.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))); - - let future = result.and_then(|()| ready(Ok(finished))); - let stream = futures::stream::once(future); - - Ok(Box::pin(stream)) - } -} diff --git a/src/daft-connect/src/response_builder.rs b/src/daft-connect/src/response_builder.rs new file mode 100644 index 0000000000..30496d8f45 --- /dev/null +++ b/src/daft-connect/src/response_builder.rs @@ -0,0 +1,133 @@ +use arrow2::io::ipc::write::StreamWriter; +use daft_table::Table; +use eyre::Context; +use spark_connect::{ + analyze_plan_response, + execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, + AnalyzePlanResponse, DataType, ExecutePlanResponse, +}; +use uuid::Uuid; + +use crate::session::Session; + +/// A utility for constructing responses to send back to the client, +/// It's generic over the type of response it can build, which is determined by the type parameter `T` +/// +/// spark responses are stateful, so we need to keep track of the session id, operation id, and server side session id +#[derive(Clone)] +pub struct ResponseBuilder { + pub(crate) session: String, + pub(crate) operation_id: String, + pub(crate) server_side_session_id: String, + pub(crate) phantom: std::marker::PhantomData, +} +impl ResponseBuilder { + pub fn new(session: &Session, operation_id: String) -> Self { + Self::new_with_op_id( + session.client_side_session_id(), + session.server_side_session_id(), + operation_id, + ) + } + pub fn new_with_op_id( + client_side_session_id: impl Into, + server_side_session_id: impl Into, + operation_id: impl Into, + ) -> Self { + let client_side_session_id = client_side_session_id.into(); + let server_side_session_id = server_side_session_id.into(); + let operation_id = operation_id.into(); + + Self { + session: client_side_session_id, + server_side_session_id, + operation_id, + phantom: std::marker::PhantomData, + } + } +} + +impl ResponseBuilder { + /// Send a result complete response to the client + pub fn result_complete_response(&self) -> ExecutePlanResponse { + ExecutePlanResponse { + session_id: self.session.to_string(), + server_side_session_id: self.server_side_session_id.to_string(), + operation_id: self.operation_id.to_string(), + response_id: Uuid::new_v4().to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: 
Some(ResponseType::ResultComplete(ResultComplete {})), + } + } + + /// Send an arrow batch response to the client + pub fn arrow_batch_response(&self, table: &Table) -> eyre::Result { + let mut data = Vec::new(); + + let mut writer = StreamWriter::new( + &mut data, + arrow2::io::ipc::write::WriteOptions { compression: None }, + ); + + let row_count = table.num_rows(); + + let schema = table + .schema + .to_arrow() + .wrap_err("Failed to convert Daft schema to Arrow schema")?; + + writer + .start(&schema, None) + .wrap_err("Failed to start Arrow stream writer with schema")?; + + let arrays = table.get_inner_arrow_arrays().collect(); + let chunk = arrow2::chunk::Chunk::new(arrays); + + writer + .write(&chunk, None) + .wrap_err("Failed to write Arrow chunk to stream writer")?; + + let response = ExecutePlanResponse { + session_id: self.session.clone(), + server_side_session_id: self.server_side_session_id.clone(), + operation_id: self.operation_id.clone(), + response_id: Uuid::new_v4().to_string(), // todo: implement this + metrics: None, // todo: implement this + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ArrowBatch(ArrowBatch { + row_count: row_count as i64, + data, + start_offset: None, + })), + }; + + Ok(response) + } +} + +impl ResponseBuilder { + pub fn schema_response(&self, dtype: DataType) -> AnalyzePlanResponse { + let schema = analyze_plan_response::Schema { + schema: Some(dtype), + }; + + AnalyzePlanResponse { + session_id: self.session.clone(), + server_side_session_id: self.server_side_session_id.clone(), + result: Some(analyze_plan_response::Result::Schema(schema)), + } + } + + pub fn treestring_response(&self, tree_string: String) -> AnalyzePlanResponse { + AnalyzePlanResponse { + session_id: self.session.clone(), + server_side_session_id: self.server_side_session_id.clone(), + result: Some(analyze_plan_response::Result::TreeString( + analyze_plan_response::TreeString { tree_string }, + )), + } + } +} diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 7de8d5851b..234918d946 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -1,8 +1,15 @@ -use std::collections::BTreeMap; - +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; + +use common_runtime::RuntimeRef; +use daft_catalog::DaftCatalog; +use daft_local_execution::NativeExecutor; use daft_micropartition::partitioning::InMemoryPartitionSetCache; use uuid::Uuid; +#[derive(Clone)] pub struct Session { /// so order is preserved, and so we can efficiently do a prefix search /// @@ -13,7 +20,10 @@ pub struct Session { server_side_session_id: String, /// MicroPartitionSet associated with this session /// this will be filled up as the user runs queries - pub(crate) psets: InMemoryPartitionSetCache, + pub(crate) psets: Arc, + pub(crate) compute_runtime: RuntimeRef, + pub(crate) engine: Arc, + pub(crate) catalog: Arc>, } impl Session { @@ -28,11 +38,16 @@ impl Session { pub fn new(id: String) -> Self { let server_side_session_id = Uuid::new_v4(); let server_side_session_id = server_side_session_id.to_string(); + let rt = common_runtime::get_compute_runtime(); + Self { config_values: Default::default(), id, server_side_session_id, - psets: InMemoryPartitionSetCache::empty(), + psets: Arc::new(InMemoryPartitionSetCache::empty()), + compute_runtime: rt.clone(), + engine: Arc::new(NativeExecutor::default().with_runtime(rt.runtime.clone())), + catalog: Arc::new(RwLock::new(DaftCatalog::default())), } } diff --git 
a/src/daft-connect/src/spark_analyzer.rs b/src/daft-connect/src/spark_analyzer.rs new file mode 100644 index 0000000000..f98cbec714 --- /dev/null +++ b/src/daft-connect/src/spark_analyzer.rs @@ -0,0 +1,885 @@ +//! Translation between Spark Connect and Daft + +mod datatype; +mod literal; + +use std::{io::Cursor, sync::Arc}; + +use arrow2::io::ipc::read::{read_stream_metadata, StreamReader, StreamState}; +use daft_core::series::Series; +use daft_dsl::col; +use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +use daft_micropartition::{ + partitioning::{ + MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, PartitionSet, PartitionSetCache, + }, + python::PyMicroPartition, + MicroPartition, +}; +use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder}; +use daft_schema::schema::{Schema, SchemaRef}; +use daft_sql::SQLPlanner; +use daft_table::Table; +use datatype::to_daft_datatype; +pub use datatype::to_spark_datatype; +use eyre::{bail, ensure, Context}; +use itertools::zip_eq; +use literal::to_daft_literal; +use pyo3::{intern, prelude::*}; +use spark_connect::{ + aggregate::GroupType, + data_type::StructField, + expression::{ + self as spark_expr, + cast::{CastToType, EvalMode}, + sort_order::{NullOrdering, SortDirection}, + ExprType, SortOrder, UnresolvedFunction, + }, + read::ReadType, + relation::RelType, + Deduplicate, Expression, Limit, Range, Relation, Sort, Sql, +}; +use tracing::debug; + +use crate::{ + functions::CONNECT_FUNCTIONS, invalid_argument_err, not_yet_implemented, session::Session, + util::FromOptionalField, Runner, +}; + +#[derive(Clone)] +pub struct SparkAnalyzer<'a> { + pub session: &'a Session, +} + +impl SparkAnalyzer<'_> { + pub fn new(session: &Session) -> SparkAnalyzer<'_> { + SparkAnalyzer { session } + } + + pub fn create_in_memory_scan( + &self, + plan_id: usize, + schema: Arc, + tables: Vec, + ) -> eyre::Result { + let runner = self.session.get_runner()?; + + match runner { + Runner::Ray => { + let mp = + MicroPartition::new_loaded(tables[0].schema.clone(), Arc::new(tables), None); + Python::with_gil(|py| { + // Convert MicroPartition to a logical plan using Python interop. + let py_micropartition = py + .import(intern!(py, "daft.table"))? + .getattr(intern!(py, "MicroPartition"))? + .getattr(intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(mp),))?; + + // ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema' + let py_plan_builder = py + .import(intern!(py, "daft.dataframe.dataframe"))? + .getattr(intern!(py, "to_logical_plan_builder"))? + .call1((py_micropartition,))?; + let py_plan_builder = py_plan_builder.getattr(intern!(py, "_builder"))?; + let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?; + + Ok::<_, eyre::Error>(dbg!(plan.builder)) + }) + } + Runner::Native => { + let partition_key = uuid::Uuid::new_v4().to_string(); + + let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?); + + let PartitionMetadata { + num_rows, + size_bytes, + } = pset.metadata(); + let num_partitions = pset.num_partitions(); + + self.session.psets.put_partition_set(&partition_key, &pset); + + let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset); + + Ok(LogicalPlanBuilder::in_memory_scan( + &partition_key, + cache_entry, + schema, + num_partitions, + size_bytes, + num_rows, + )?) 
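// The freshly generated `partition_key` ties the cached `MicroPartitionSet` to this
// in-memory scan node, so the partitions can be looked up again from `session.psets`
// when the plan is executed.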
+ } + } + } + + pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result { + let Some(common) = relation.common else { + bail!("Common metadata is required"); + }; + + if common.origin.is_some() { + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + } + + let Some(rel_type) = relation.rel_type else { + bail!("Relation type is required"); + }; + + match rel_type { + RelType::Limit(l) => self.limit(*l).await, + RelType::Range(r) => self.range(r), + RelType::Project(p) => self.project(*p).await, + RelType::Aggregate(a) => self.aggregate(*a).await, + RelType::WithColumns(w) => self.with_columns(*w).await, + RelType::ToDf(t) => self.to_df(*t).await, + RelType::LocalRelation(l) => { + let plan_id = common.plan_id.required("plan_id")?; + self.local_relation(plan_id, l) + } + RelType::WithColumnsRenamed(w) => self.with_columns_renamed(*w).await, + RelType::Read(r) => self.read(r).await, + RelType::Drop(d) => self.drop(*d).await, + RelType::Filter(f) => self.filter(*f).await, + RelType::ShowString(_) => unreachable!("should already be handled in execute"), + RelType::Deduplicate(rel) => self.deduplicate(*rel).await, + RelType::Sort(rel) => self.sort(*rel).await, + RelType::Sql(sql) => self.sql(sql).await, + plan => not_yet_implemented!("relation type: \"{}\"", rel_name(&plan))?, + } + } + + async fn limit(&self, limit: Limit) -> eyre::Result { + let Limit { input, limit } = limit; + + let Some(input) = input else { + bail!("input must be set"); + }; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + plan.limit(i64::from(limit), false).map_err(Into::into) + } + + async fn deduplicate(&self, deduplicate: Deduplicate) -> eyre::Result { + let Deduplicate { + input, + column_names, + .. + } = deduplicate; + + if !column_names.is_empty() { + not_yet_implemented!("Deduplicate with column names")?; + } + + let input = input.required("input")?; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + plan.distinct().map_err(Into::into) + } + + async fn sort(&self, sort: Sort) -> eyre::Result { + let Sort { + input, + order, + is_global, + } = sort; + + let input = input.required("input")?; + + if is_global == Some(false) { + not_yet_implemented!("Non Global sort")?; + } + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + if order.is_empty() { + return plan + .sort(vec![col("*")], vec![false], vec![false]) + .map_err(Into::into); + } + let mut sort_by = Vec::with_capacity(order.len()); + let mut descending = Vec::with_capacity(order.len()); + let mut nulls_first = Vec::with_capacity(order.len()); + + for SortOrder { + child, + direction, + null_ordering, + } in order + { + let expr = child.required("child")?; + let expr = self.to_daft_expr(&expr)?; + + let sort_direction = SortDirection::try_from(direction) + .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?; + + let desc = match sort_direction { + SortDirection::Ascending => false, + SortDirection::Descending | SortDirection::Unspecified => true, + }; + + let null_ordering = NullOrdering::try_from(null_ordering) + .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?; + + let nf = match null_ordering { + NullOrdering::SortNullsUnspecified => desc, + NullOrdering::SortNullsFirst => true, + NullOrdering::SortNullsLast => false, + }; + + sort_by.push(expr); + descending.push(desc); + nulls_first.push(nf); + } + + plan.sort(sort_by, descending, nulls_first) + .map_err(Into::into) + } + + fn range(&self, range: Range) -> eyre::Result { + use 
daft_scan::python::pylib::ScanOperatorHandle; + let Range { + start, + end, + step, + num_partitions, + } = range; + + let partitions = num_partitions.unwrap_or(1); + + ensure!(partitions > 0, "num_partitions must be greater than 0"); + + let start = start.unwrap_or(0); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let plan = Python::with_gil(|py| { + let range_module = + PyModule::import(py, "daft.io._range").wrap_err("Failed to import range module")?; + + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; + + let range = range + .call1((start, end, step, partitions)) + .wrap_err("Failed to create range scan operator")? + .into_pyobject(py) + .unwrap() + .unbind(); + + let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; + + Ok(plan) + } + + async fn read(&self, read: spark_connect::Read) -> eyre::Result { + let spark_connect::Read { + is_streaming, + read_type, + } = read; + + if is_streaming { + not_yet_implemented!("Streaming read")?; + } + + let read_type = read_type.required("read_type")?; + + match read_type { + ReadType::NamedTable(table) => { + let name = table.unparsed_identifier; + not_yet_implemented!("NamedTable").context(format!("table: {name}")) + } + ReadType::DataSource(source) => self.read_datasource(source).await, + } + } + + async fn read_datasource( + &self, + data_source: spark_connect::read::DataSource, + ) -> eyre::Result { + let spark_connect::read::DataSource { + format, + schema, + options, + paths, + predicates, + } = data_source; + + let format = format.required("format")?; + + ensure!(!paths.is_empty(), "Paths are required"); + + if let Some(schema) = schema { + debug!("Ignoring schema: {schema:?}; not yet implemented"); + } + + if !options.is_empty() { + debug!("Ignoring options: {options:?}; not yet implemented"); + } + + if !predicates.is_empty() { + debug!("Ignoring predicates: {predicates:?}; not yet implemented"); + } + + Ok(match &*format { + "parquet" => ParquetScanBuilder::new(paths).finish().await?, + "csv" => CsvScanBuilder::new(paths).finish().await?, + "json" => { + // todo(completeness): implement json reading + not_yet_implemented!("read json")? 
+ } + other => { + bail!("Unsupported format: {other}; only parquet and csv are supported"); + } + }) + } + + async fn aggregate( + &self, + aggregate: spark_connect::Aggregate, + ) -> eyre::Result { + fn check_grouptype(group_type: GroupType) -> eyre::Result<()> { + match group_type { + GroupType::Groupby => {} + GroupType::Unspecified => { + invalid_argument_err!("GroupType must be specified; got Unspecified")?; + } + GroupType::Rollup => { + not_yet_implemented!("GroupType.Rollup not yet supported")?; + } + GroupType::Cube => { + not_yet_implemented!("GroupType.Cube")?; + } + GroupType::Pivot => { + not_yet_implemented!("GroupType.Pivot")?; + } + GroupType::GroupingSets => { + not_yet_implemented!("GroupType.GroupingSets")?; + } + }; + Ok(()) + } + + let spark_connect::Aggregate { + input, + group_type, + grouping_expressions, + aggregate_expressions, + pivot, + grouping_sets, + } = aggregate; + + let input = input.required("input")?; + + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; + + let group_type = GroupType::try_from(group_type)?; + + check_grouptype(group_type)?; + + if let Some(pivot) = pivot { + bail!("Pivot not yet supported; got {pivot:?}"); + } + + if !grouping_sets.is_empty() { + bail!("Grouping sets not yet supported; got {grouping_sets:?}"); + } + + let grouping_expressions: Vec<_> = grouping_expressions + .iter() + .map(|e| self.to_daft_expr(e)) + .try_collect()?; + + let aggregate_expressions: Vec<_> = aggregate_expressions + .iter() + .map(|e| self.to_daft_expr(e)) + .try_collect()?; + + plan = plan.aggregate(aggregate_expressions, grouping_expressions)?; + + Ok(plan) + } + + async fn drop(&self, drop: spark_connect::Drop) -> eyre::Result { + let spark_connect::Drop { + input, + columns, + column_names, + } = drop; + + let input = input.required("input")?; + + if !columns.is_empty() { + not_yet_implemented!("columns is not supported; use column_names instead")?; + } + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + let to_select = plan + .schema() + .exclude(&column_names)? + .names() + .into_iter() + .map(daft_dsl::col) + .collect(); + + // Use select to keep only the columns we want + Ok(plan.select(to_select)?) + } + + pub async fn filter(&self, filter: spark_connect::Filter) -> eyre::Result { + let spark_connect::Filter { input, condition } = filter; + + let input = input.required("input")?; + let condition = condition.required("condition")?; + let condition = self.to_daft_expr(&condition)?; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + Ok(plan.filter(condition)?) + } + + pub fn local_relation( + &self, + plan_id: i64, + plan: spark_connect::LocalRelation, + ) -> eyre::Result { + // We can ignore spark schema. The true schema is sent in the + // arrow data. (see read_stream_metadata) + // the schema inside the plan is actually wrong. 
See https://issues.apache.org/jira/browse/SPARK-50627 + let spark_connect::LocalRelation { data, schema: _ } = plan; + + let data = data.required("data")?; + + let mut reader = Cursor::new(&data); + let metadata = read_stream_metadata(&mut reader)?; + + let arrow_schema = metadata.schema.clone(); + let daft_schema = Arc::new( + Schema::try_from(&arrow_schema) + .wrap_err("Failed to convert Arrow schema to Daft schema.")?, + ); + + let reader = StreamReader::new(reader, metadata, None); + + let tables = reader.into_iter().map(|ss| { + let ss = ss.wrap_err("Failed to read next chunk from StreamReader.")?; + + let chunk = match ss { + StreamState::Some(chunk) => chunk, + StreamState::Waiting => { + bail!("StreamReader is waiting for data, but a chunk was expected. This likely indicates that the spark provided data is incomplete.") + } + }; + + + let arrays = chunk.into_arrays(); + let columns = zip_eq(arrays, &arrow_schema.fields) + .map(|(array, arrow_field)| { + let field = Arc::new(arrow_field.into()); + + let series = Series::from_arrow(field, array) + .wrap_err("Failed to create Series from Arrow array.")?; + + Ok(series) + }) + .collect::>>()?; + + let batch = Table::from_nonempty_columns(columns)?; + + Ok(batch) + }).collect::>>()?; + + self.create_in_memory_scan(plan_id as _, daft_schema, tables) + } + + async fn project(&self, project: spark_connect::Project) -> eyre::Result { + let spark_connect::Project { input, expressions } = project; + + let input = input.required("input")?; + + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; + + let daft_exprs: Vec<_> = expressions + .iter() + .map(|e| self.to_daft_expr(e)) + .try_collect()?; + plan = plan.select(daft_exprs)?; + + Ok(plan) + } + + async fn with_columns( + &self, + with_columns: spark_connect::WithColumns, + ) -> eyre::Result { + let spark_connect::WithColumns { input, aliases } = with_columns; + + let input = input.required("input")?; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + let daft_exprs: Vec<_> = aliases + .into_iter() + .map(|alias| { + let expression = Expression { + common: None, + expr_type: Some(ExprType::Alias(Box::new(alias))), + }; + + self.to_daft_expr(&expression) + }) + .try_collect()?; + + Ok(plan.with_columns(daft_exprs)?) 
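// Each Spark alias is re-wrapped as an `Expression` with `ExprType::Alias` so it goes
// through the same `to_daft_expr` translation path as any other expression before being
// passed to `with_columns`.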
+ } + + async fn with_columns_renamed( + &self, + with_columns_renamed: spark_connect::WithColumnsRenamed, + ) -> eyre::Result { + let spark_connect::WithColumnsRenamed { + input, + rename_columns_map, + renames, + } = with_columns_renamed; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + // todo: let's implement this directly into daft + + // Convert the rename mappings into expressions + let rename_exprs = if !rename_columns_map.is_empty() { + // Use rename_columns_map if provided (legacy format) + rename_columns_map + .into_iter() + .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) + .collect() + } else { + // Use renames if provided (new format) + renames + .into_iter() + .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) + .collect() + }; + + // Apply the rename expressions to the plan + let plan = plan + .select(rename_exprs) + .wrap_err("Failed to apply rename expressions to logical plan")?; + + Ok(plan) + } + + async fn to_df(&self, to_df: spark_connect::ToDf) -> eyre::Result { + let spark_connect::ToDf { + input, + column_names, + } = to_df; + + let input = input.required("input")?; + + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; + + let column_names: Vec<_> = column_names.into_iter().map(daft_dsl::col).collect(); + + plan = plan + .select(column_names) + .wrap_err("Failed to add columns to logical plan")?; + Ok(plan) + } + + pub async fn relation_to_spark_schema( + &self, + input: Relation, + ) -> eyre::Result { + let result = self.relation_to_daft_schema(input).await?; + + let fields: eyre::Result> = result + .fields + .iter() + .map(|(name, field)| { + let field_type = to_spark_datatype(&field.dtype); + Ok(StructField { + name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? + data_type: Some(field_type), + nullable: true, // todo(correctness): is this correct? 
+ metadata: None, // todo(completeness): might want to add metadata here + }) + }) + .collect(); + + Ok(spark_connect::DataType { + kind: Some(spark_connect::data_type::Kind::Struct( + spark_connect::data_type::Struct { + fields: fields?, + type_variation_reference: 0, + }, + )), + }) + } + + pub async fn relation_to_daft_schema(&self, input: Relation) -> eyre::Result { + if let Some(common) = &input.common { + if common.origin.is_some() { + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + } + } + + let plan = Box::pin(self.to_logical_plan(input)).await?; + + let result = plan.schema(); + + Ok(result) + } + + #[allow(deprecated)] + async fn sql(&self, sql: Sql) -> eyre::Result { + let Sql { + query, + args, + pos_args, + named_arguments, + pos_arguments, + } = sql; + if !args.is_empty() { + not_yet_implemented!("args")?; + } + if !pos_args.is_empty() { + not_yet_implemented!("pos_args")?; + } + if !named_arguments.is_empty() { + not_yet_implemented!("named_arguments")?; + } + if !pos_arguments.is_empty() { + not_yet_implemented!("pos_arguments")?; + } + + let catalog = self + .session + .catalog + .read() + .map_err(|e| eyre::eyre!("Failed to read catalog: {e}"))?; + let catalog = catalog.clone(); + + let mut planner = SQLPlanner::new(catalog); + let plan = planner.plan_sql(&query)?; + Ok(plan.into()) + } + + pub fn to_daft_expr(&self, expression: &Expression) -> eyre::Result { + if let Some(common) = &expression.common { + if common.origin.is_some() { + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + } + }; + + let Some(expr) = &expression.expr_type else { + bail!("Expression is required"); + }; + + match expr { + spark_expr::ExprType::Literal(l) => to_daft_literal(l), + spark_expr::ExprType::UnresolvedAttribute(attr) => { + let spark_expr::UnresolvedAttribute { + unparsed_identifier, + plan_id, + is_metadata_column, + } = attr; + + if let Some(plan_id) = plan_id { + debug!( + "Ignoring plan_id {plan_id} for attribute expressions; not yet implemented" + ); + } + + if let Some(is_metadata_column) = is_metadata_column { + debug!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); + } + + Ok(daft_dsl::col(unparsed_identifier.as_str())) + } + spark_expr::ExprType::UnresolvedFunction(f) => self.process_function(f), + spark_expr::ExprType::ExpressionString(_) => { + bail!("Expression string not yet supported") + } + spark_expr::ExprType::UnresolvedStar(_) => { + bail!("Unresolved star expressions not yet supported") + } + spark_expr::ExprType::Alias(alias) => { + let spark_expr::Alias { + expr, + name, + metadata, + } = &**alias; + + let Some(expr) = expr else { + bail!("Alias expr is required"); + }; + + let [name] = name.as_slice() else { + bail!("Alias name is required and currently only works with a single string; got {name:?}"); + }; + + if let Some(metadata) = metadata { + bail!("Alias metadata is not yet supported; got {metadata:?}"); + } + + let child = self.to_daft_expr(expr)?; + + let name = Arc::from(name.as_str()); + + Ok(child.alias(name)) + } + spark_expr::ExprType::Cast(c) => { + let spark_expr::Cast { + expr, + eval_mode, + cast_to_type, + } = &**c; + + let Some(expr) = expr else { + bail!("Cast expression is required"); + }; + + let expr = self.to_daft_expr(expr)?; + + let Some(cast_to_type) = cast_to_type else { + bail!("Cast to type is required"); + }; + + let data_type = match cast_to_type { + CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| 
{ + format!("Failed to convert spark datatype to daft datatype: {kind:?}") + })?, + CastToType::TypeStr(s) => { + bail!("Cast to type string not yet supported; tried to cast to {s}"); + } + }; + + let eval_mode = EvalMode::try_from(*eval_mode) + .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?; + + debug!("Ignoring cast eval mode: {eval_mode:?}"); + + Ok(expr.cast(&data_type)) + } + spark_expr::ExprType::SortOrder(s) => { + let spark_expr::SortOrder { + child, + direction, + null_ordering, + } = &**s; + + let Some(_child) = child else { + bail!("Sort order child is required"); + }; + + let _sort_direction = SortDirection::try_from(*direction) + .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?; + + let _sort_nulls = NullOrdering::try_from(*null_ordering) + .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?; + + bail!("Sort order expressions not yet supported"); + } + other => not_yet_implemented!("expression type: {other:?}")?, + } + } + + fn process_function(&self, f: &UnresolvedFunction) -> eyre::Result { + let UnresolvedFunction { + function_name, + arguments, + is_distinct, + is_user_defined_function, + } = f; + + if *is_distinct { + not_yet_implemented!("Distinct ")?; + } + + if *is_user_defined_function { + not_yet_implemented!("User-defined functions")?; + } + + let Some(f) = CONNECT_FUNCTIONS.get(function_name.as_str()) else { + return not_yet_implemented!("function: {function_name}")?; + }; + + f.to_expr(arguments, self) + } +} + +fn rel_name(rel: &RelType) -> &str { + match rel { + RelType::Read(_) => "Read", + RelType::Project(_) => "Project", + RelType::Filter(_) => "Filter", + RelType::Join(_) => "Join", + RelType::SetOp(_) => "SetOp", + RelType::Sort(_) => "Sort", + RelType::Limit(_) => "Limit", + RelType::Aggregate(_) => "Aggregate", + RelType::Sql(_) => "Sql", + RelType::LocalRelation(_) => "LocalRelation", + RelType::Sample(_) => "Sample", + RelType::Offset(_) => "Offset", + RelType::Deduplicate(_) => "Deduplicate", + RelType::Range(_) => "Range", + RelType::SubqueryAlias(_) => "SubqueryAlias", + RelType::Repartition(_) => "Repartition", + RelType::ToDf(_) => "ToDf", + RelType::WithColumnsRenamed(_) => "WithColumnsRenamed", + RelType::ShowString(_) => "ShowString", + RelType::Drop(_) => "Drop", + RelType::Tail(_) => "Tail", + RelType::WithColumns(_) => "WithColumns", + RelType::Hint(_) => "Hint", + RelType::Unpivot(_) => "Unpivot", + RelType::ToSchema(_) => "ToSchema", + RelType::RepartitionByExpression(_) => "RepartitionByExpression", + RelType::MapPartitions(_) => "MapPartitions", + RelType::CollectMetrics(_) => "CollectMetrics", + RelType::Parse(_) => "Parse", + RelType::GroupMap(_) => "GroupMap", + RelType::CoGroupMap(_) => "CoGroupMap", + RelType::WithWatermark(_) => "WithWatermark", + RelType::ApplyInPandasWithState(_) => "ApplyInPandasWithState", + RelType::HtmlString(_) => "HtmlString", + RelType::CachedLocalRelation(_) => "CachedLocalRelation", + RelType::CachedRemoteRelation(_) => "CachedRemoteRelation", + RelType::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction", + RelType::AsOfJoin(_) => "AsOfJoin", + RelType::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource", + RelType::WithRelations(_) => "WithRelations", + RelType::Transpose(_) => "Transpose", + RelType::FillNa(_) => "FillNa", + RelType::DropNa(_) => "DropNa", + RelType::Replace(_) => "Replace", + RelType::Summary(_) => "Summary", + RelType::Crosstab(_) => "Crosstab", + RelType::Describe(_) => 
"Describe", + RelType::Cov(_) => "Cov", + RelType::Corr(_) => "Corr", + RelType::ApproxQuantile(_) => "ApproxQuantile", + RelType::FreqItems(_) => "FreqItems", + RelType::SampleBy(_) => "SampleBy", + RelType::Catalog(_) => "Catalog", + RelType::Extension(_) => "Extension", + RelType::Unknown(_) => "Unknown", + } +} diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/spark_analyzer/datatype.rs similarity index 59% rename from src/daft-connect/src/translation/datatype.rs rename to src/daft-connect/src/spark_analyzer/datatype.rs index d6e51250c7..3b3065b56b 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/spark_analyzer/datatype.rs @@ -1,70 +1,31 @@ use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit}; use eyre::{bail, ensure, WrapErr}; use spark_connect::data_type::Kind; -use tracing::warn; +use tracing::debug; pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { + macro_rules! simple_spark_type { + ($kind:ident) => { + spark_connect::DataType { + kind: Some(Kind::$kind(spark_connect::data_type::$kind { + type_variation_reference: 0, + })), + } + }; + } match datatype { - DataType::Null => spark_connect::DataType { - kind: Some(Kind::Null(spark_connect::data_type::Null { - type_variation_reference: 0, - })), - }, - DataType::Boolean => spark_connect::DataType { - kind: Some(Kind::Boolean(spark_connect::data_type::Boolean { - type_variation_reference: 0, - })), - }, - DataType::Int8 => spark_connect::DataType { - kind: Some(Kind::Byte(spark_connect::data_type::Byte { - type_variation_reference: 0, - })), - }, - DataType::Int16 => spark_connect::DataType { - kind: Some(Kind::Short(spark_connect::data_type::Short { - type_variation_reference: 0, - })), - }, - DataType::Int32 => spark_connect::DataType { - kind: Some(Kind::Integer(spark_connect::data_type::Integer { - type_variation_reference: 0, - })), - }, - DataType::Int64 => spark_connect::DataType { - kind: Some(Kind::Long(spark_connect::data_type::Long { - type_variation_reference: 0, - })), - }, - DataType::UInt8 => spark_connect::DataType { - kind: Some(Kind::Byte(spark_connect::data_type::Byte { - type_variation_reference: 0, - })), - }, - DataType::UInt16 => spark_connect::DataType { - kind: Some(Kind::Short(spark_connect::data_type::Short { - type_variation_reference: 0, - })), - }, - DataType::UInt32 => spark_connect::DataType { - kind: Some(Kind::Integer(spark_connect::data_type::Integer { - type_variation_reference: 0, - })), - }, - DataType::UInt64 => spark_connect::DataType { - kind: Some(Kind::Long(spark_connect::data_type::Long { - type_variation_reference: 0, - })), - }, - DataType::Float32 => spark_connect::DataType { - kind: Some(Kind::Float(spark_connect::data_type::Float { - type_variation_reference: 0, - })), - }, - DataType::Float64 => spark_connect::DataType { - kind: Some(Kind::Double(spark_connect::data_type::Double { - type_variation_reference: 0, - })), - }, + DataType::Null => simple_spark_type!(Null), + DataType::Boolean => simple_spark_type!(Boolean), + DataType::Int8 => simple_spark_type!(Byte), + DataType::Int16 => simple_spark_type!(Short), + DataType::Int32 => simple_spark_type!(Integer), + DataType::Int64 => simple_spark_type!(Long), + DataType::UInt8 => simple_spark_type!(Byte), + DataType::UInt16 => simple_spark_type!(Short), + DataType::UInt32 => simple_spark_type!(Integer), + DataType::UInt64 => simple_spark_type!(Long), + DataType::Float32 => simple_spark_type!(Float), + DataType::Float64 => 
simple_spark_type!(Double), DataType::Decimal128(precision, scale) => spark_connect::DataType { kind: Some(Kind::Decimal(spark_connect::data_type::Decimal { scale: Some(*scale as i32), @@ -73,23 +34,15 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { })), }, DataType::Timestamp(unit, _) => { - warn!("Ignoring time unit {unit:?} for timestamp type"); + debug!("Ignoring time unit {unit:?} for timestamp type"); spark_connect::DataType { kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp { type_variation_reference: 0, })), } } - DataType::Date => spark_connect::DataType { - kind: Some(Kind::Date(spark_connect::data_type::Date { - type_variation_reference: 0, - })), - }, - DataType::Binary => spark_connect::DataType { - kind: Some(Kind::Binary(spark_connect::data_type::Binary { - type_variation_reference: 0, - })), - }, + DataType::Date => simple_spark_type!(Date), + DataType::Binary => simple_spark_type!(Binary), DataType::Utf8 => spark_connect::DataType { kind: Some(Kind::String(spark_connect::data_type::String { type_variation_reference: 0, @@ -122,42 +75,40 @@ pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result {{ + ensure!($value.type_variation_reference == 0, type_variation_err); + Ok($dtype) + }}; + } + match kind { Kind::Null(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Null) + simple_type_case!(value, DataType::Null) } Kind::Binary(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Binary) + simple_type_case!(value, DataType::Binary) } Kind::Boolean(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Boolean) + simple_type_case!(value, DataType::Boolean) } Kind::Byte(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Int8) + simple_type_case!(value, DataType::Int8) } Kind::Short(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Int16) + simple_type_case!(value, DataType::Int16) } Kind::Integer(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Int32) + simple_type_case!(value, DataType::Int32) } Kind::Long(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Int64) + simple_type_case!(value, DataType::Int64) } Kind::Float(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Float32) + simple_type_case!(value, DataType::Float32) } Kind::Double(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Float64) + simple_type_case!(value, DataType::Float64) } Kind::Decimal(value) => { ensure!(value.type_variation_reference == 0, type_variation_err); @@ -179,20 +130,16 @@ pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Utf8) + simple_type_case!(value, DataType::Utf8) } Kind::Char(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Utf8) + simple_type_case!(value, DataType::Utf8) } Kind::VarChar(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Utf8) + simple_type_case!(value, DataType::Utf8) } Kind::Date(value) => { - ensure!(value.type_variation_reference == 0, type_variation_err); - Ok(DataType::Date) + simple_type_case!(value, DataType::Date) 
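// Illustrative sketch, not part of this patch: the two converters in this module are meant to
// be inverses for the types Spark can represent exactly. A hypothetical unit test of that
// property, assuming `to_spark_datatype` and `to_daft_datatype` stay importable from this module:
#[cfg(test)]
mod datatype_roundtrip_sketch {
    use daft_schema::dtype::DataType;

    use super::{to_daft_datatype, to_spark_datatype};

    #[test]
    fn signed_ints_roundtrip() -> eyre::Result<()> {
        for dtype in [DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64] {
            let spark = to_spark_datatype(&dtype);
            // Unsigned Daft types would not roundtrip: they are widened to Spark's signed kinds.
            assert_eq!(to_daft_datatype(&spark)?, dtype);
        }
        Ok(())
    }
}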
} Kind::Timestamp(value) => { ensure!(value.type_variation_reference == 0, type_variation_err); diff --git a/src/daft-connect/src/translation/literal.rs b/src/daft-connect/src/spark_analyzer/literal.rs similarity index 69% rename from src/daft-connect/src/translation/literal.rs rename to src/daft-connect/src/spark_analyzer/literal.rs index f6a26db84a..914c3bba23 100644 --- a/src/daft-connect/src/translation/literal.rs +++ b/src/daft-connect/src/spark_analyzer/literal.rs @@ -2,6 +2,8 @@ use daft_core::datatypes::IntervalValue; use eyre::bail; use spark_connect::expression::{literal::LiteralType, Literal}; +use crate::not_yet_implemented; + // todo(test): add tests for this esp in Python pub fn to_daft_literal(literal: &Literal) -> eyre::Result { let Some(literal) = &literal.literal_type else { @@ -9,18 +11,14 @@ pub fn to_daft_literal(literal: &Literal) -> eyre::Result { }; match literal { - LiteralType::Array(_) => bail!("Array literals not yet supported"), + LiteralType::Array(_) => not_yet_implemented!("array literals")?, LiteralType::Binary(bytes) => Ok(daft_dsl::lit(bytes.as_slice())), LiteralType::Boolean(b) => Ok(daft_dsl::lit(*b)), - LiteralType::Byte(_) => bail!("Byte literals not yet supported"), - LiteralType::CalendarInterval(_) => { - bail!("Calendar interval literals not yet supported") - } + LiteralType::Byte(_) => not_yet_implemented!("Byte literals")?, + LiteralType::CalendarInterval(_) => not_yet_implemented!("Calendar interval literals")?, LiteralType::Date(d) => Ok(daft_dsl::lit(*d)), - LiteralType::DayTimeInterval(_) => { - bail!("Day-time interval literals not yet supported") - } - LiteralType::Decimal(_) => bail!("Decimal literals not yet supported"), + LiteralType::DayTimeInterval(_) => not_yet_implemented!("Day-time interval literals")?, + LiteralType::Decimal(_) => not_yet_implemented!("Decimal literals")?, LiteralType::Double(d) => Ok(daft_dsl::lit(*d)), LiteralType::Float(f) => { let f = f64::from(*f); @@ -28,14 +26,14 @@ pub fn to_daft_literal(literal: &Literal) -> eyre::Result { } LiteralType::Integer(i) => Ok(daft_dsl::lit(*i)), LiteralType::Long(l) => Ok(daft_dsl::lit(*l)), - LiteralType::Map(_) => bail!("Map literals not yet supported"), + LiteralType::Map(_) => not_yet_implemented!("Map literals")?, LiteralType::Null(_) => { // todo(correctness): is it ok to assume type is i32 here? Ok(daft_dsl::null_lit()) } - LiteralType::Short(_) => bail!("Short literals not yet supported"), + LiteralType::Short(_) => not_yet_implemented!("Short literals")?, LiteralType::String(s) => Ok(daft_dsl::lit(s.as_str())), - LiteralType::Struct(_) => bail!("Struct literals not yet supported"), + LiteralType::Struct(_) => not_yet_implemented!("Struct literals")?, LiteralType::Timestamp(ts) => { // todo(correctness): is it ok that the type is different logically? Ok(daft_dsl::lit(*ts)) diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs deleted file mode 100644 index 73dc2f998d..0000000000 --- a/src/daft-connect/src/translation.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! 
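// Illustrative sketch, not part of this patch: literals reach `to_daft_literal` above as the
// prost-generated `Literal` message. A hypothetical call, assuming the message derives
// `Default` (standard for prost) and that the function returns `eyre::Result<daft_dsl::ExprRef>`
// like the rest of this module:
use spark_connect::expression::{literal::LiteralType, Literal};

fn int_literal_example() -> eyre::Result<daft_dsl::ExprRef> {
    let message = Literal {
        literal_type: Some(LiteralType::Long(42)),
        ..Default::default()
    };
    // Equivalent to building the expression directly with `daft_dsl::lit(42i64)`.
    to_daft_literal(&message)
}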
Translation between Spark Connect and Daft - -mod datatype; -mod expr; -mod literal; -mod logical_plan; -mod schema; - -pub use datatype::{to_daft_datatype, to_spark_datatype}; -pub use expr::to_daft_expr; -pub use literal::to_daft_literal; -pub use logical_plan::SparkAnalyzer; -pub use schema::relation_to_spark_schema; diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs deleted file mode 100644 index 0354dc504c..0000000000 --- a/src/daft-connect/src/translation/expr.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::sync::Arc; - -use eyre::{bail, Context}; -use spark_connect::{ - expression as spark_expr, - expression::{ - cast::{CastToType, EvalMode}, - sort_order::{NullOrdering, SortDirection}, - }, - Expression, -}; -use tracing::warn; -use unresolved_function::unresolved_to_daft_expr; - -use crate::translation::{to_daft_datatype, to_daft_literal}; - -mod unresolved_function; - -pub fn to_daft_expr(expression: &Expression) -> eyre::Result { - if let Some(common) = &expression.common { - if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); - } - }; - - let Some(expr) = &expression.expr_type else { - bail!("Expression is required"); - }; - - match expr { - spark_expr::ExprType::Literal(l) => to_daft_literal(l), - spark_expr::ExprType::UnresolvedAttribute(attr) => { - let spark_expr::UnresolvedAttribute { - unparsed_identifier, - plan_id, - is_metadata_column, - } = attr; - - if let Some(plan_id) = plan_id { - warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); - } - - if let Some(is_metadata_column) = is_metadata_column { - warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); - } - - Ok(daft_dsl::col(unparsed_identifier.as_str())) - } - spark_expr::ExprType::UnresolvedFunction(f) => { - unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function") - } - spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"), - spark_expr::ExprType::UnresolvedStar(_) => { - bail!("Unresolved star expressions not yet supported") - } - spark_expr::ExprType::Alias(alias) => { - let spark_expr::Alias { - expr, - name, - metadata, - } = &**alias; - - let Some(expr) = expr else { - bail!("Alias expr is required"); - }; - - let [name] = name.as_slice() else { - bail!("Alias name is required and currently only works with a single string; got {name:?}"); - }; - - if let Some(metadata) = metadata { - bail!("Alias metadata is not yet supported; got {metadata:?}"); - } - - let child = to_daft_expr(expr)?; - - let name = Arc::from(name.as_str()); - - Ok(child.alias(name)) - } - spark_expr::ExprType::Cast(c) => { - // Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) } - // thread 'tokio-runtime-worker' panicked at src/daft-connect/src/trans - let spark_expr::Cast { - expr, - eval_mode, - cast_to_type, - } = &**c; - - let Some(expr) = expr else { - bail!("Cast expression is required"); - }; - - let expr = to_daft_expr(expr)?; - - let Some(cast_to_type) = cast_to_type else { - bail!("Cast to type is required"); - }; - - let data_type = match cast_to_type { - CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| { - 
format!("Failed to convert spark datatype to daft datatype: {kind:?}") - })?, - CastToType::TypeStr(s) => { - bail!("Cast to type string not yet supported; tried to cast to {s}"); - } - }; - - let eval_mode = EvalMode::try_from(*eval_mode) - .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?; - - warn!("Ignoring cast eval mode: {eval_mode:?}"); - - Ok(expr.cast(&data_type)) - } - spark_expr::ExprType::UnresolvedRegex(_) => { - bail!("Unresolved regex expressions not yet supported") - } - spark_expr::ExprType::SortOrder(s) => { - let spark_expr::SortOrder { - child, - direction, - null_ordering, - } = &**s; - - let Some(_child) = child else { - bail!("Sort order child is required"); - }; - - let _sort_direction = SortDirection::try_from(*direction) - .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?; - - let _sort_nulls = NullOrdering::try_from(*null_ordering) - .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?; - - bail!("Sort order expressions not yet supported"); - } - spark_expr::ExprType::LambdaFunction(_) => { - bail!("Lambda function expressions not yet supported") - } - spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"), - spark_expr::ExprType::UnresolvedExtractValue(_) => { - bail!("Unresolved extract value expressions not yet supported") - } - spark_expr::ExprType::UpdateFields(_) => { - bail!("Update fields expressions not yet supported") - } - spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => { - bail!("Unresolved named lambda variable expressions not yet supported") - } - spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => { - bail!("Common inline user defined function expressions not yet supported") - } - spark_expr::ExprType::CallFunction(_) => { - bail!("Call function expressions not yet supported") - } - spark_expr::ExprType::NamedArgumentExpression(_) => { - bail!("Named argument expressions not yet supported") - } - spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"), - spark_expr::ExprType::TypedAggregateExpression(_) => { - bail!("Typed aggregate expressions not yet supported") - } - spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"), - } -} diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs deleted file mode 100644 index fc1e2dcba6..0000000000 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ /dev/null @@ -1,128 +0,0 @@ -use daft_core::count_mode::CountMode; -use eyre::{bail, Context}; -use spark_connect::expression::UnresolvedFunction; - -use crate::translation::to_daft_expr; - -pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result { - let UnresolvedFunction { - function_name, - arguments, - is_distinct, - is_user_defined_function, - } = f; - - let arguments: Vec<_> = arguments.iter().map(to_daft_expr).try_collect()?; - - if *is_distinct { - bail!("Distinct not yet supported"); - } - - if *is_user_defined_function { - bail!("User-defined functions not yet supported"); - } - - match function_name.as_str() { - "%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus), - "<" => handle_binary_op(arguments, daft_dsl::Operator::Lt), - "<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq), - "==" => handle_binary_op(arguments, daft_dsl::Operator::Eq), - ">" => handle_binary_op(arguments, daft_dsl::Operator::Gt), - ">=" => handle_binary_op(arguments, 
daft_dsl::Operator::GtEq), - "count" => handle_count(arguments), - "isnotnull" => handle_isnotnull(arguments), - "isnull" => handle_isnull(arguments), - "not" => not(arguments), - "sum" => handle_sum(arguments), - n => bail!("Unresolved function {n:?} not yet supported"), - } - .wrap_err_with(|| format!("Failed to handle function {function_name:?}")) -} - -pub fn handle_sum(arguments: Vec) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly one argument; got {arguments:?}"); - } - }; - - let [arg] = arguments; - Ok(arg.sum()) -} - -/// If the arguments are exactly one, return it. Otherwise, return an error. -pub fn to_single(arguments: Vec) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly one argument; got {arguments:?}"); - } - }; - - let [arg] = arguments; - - Ok(arg) -} - -pub fn not(arguments: Vec) -> eyre::Result { - let arg = to_single(arguments)?; - Ok(arg.not()) -} - -pub fn handle_binary_op( - arguments: Vec, - op: daft_dsl::Operator, -) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 2] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly two arguments; got {arguments:?}"); - } - }; - - let [left, right] = arguments; - - Ok(daft_dsl::binary_op(op, left, right)) -} - -pub fn handle_count(arguments: Vec) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly one argument; got {arguments:?}"); - } - }; - - let [arg] = arguments; - - let count = arg.count(CountMode::All); - - Ok(count) -} - -pub fn handle_isnull(arguments: Vec) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly one argument; got {arguments:?}"); - } - }; - - let [arg] = arguments; - - Ok(arg.is_null()) -} - -pub fn handle_isnotnull(arguments: Vec) -> eyre::Result { - let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { - Ok(arguments) => arguments, - Err(arguments) => { - bail!("requires exactly one argument; got {arguments:?}"); - } - }; - - let [arg] = arguments; - - Ok(arg.not_null()) -} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs deleted file mode 100644 index 5bf831756e..0000000000 --- a/src/daft-connect/src/translation/logical_plan.rs +++ /dev/null @@ -1,201 +0,0 @@ -use std::sync::Arc; - -use common_daft_config::DaftExecutionConfig; -use daft_core::prelude::Schema; -use daft_dsl::LiteralValue; -use daft_local_execution::NativeExecutor; -use daft_logical_plan::LogicalPlanBuilder; -use daft_micropartition::{ - partitioning::{ - InMemoryPartitionSetCache, MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, - PartitionSet, PartitionSetCache, - }, - MicroPartition, -}; -use daft_table::Table; -use eyre::{bail, Context}; -use futures::TryStreamExt; -use spark_connect::{relation::RelType, Limit, Relation, ShowString}; -use tracing::warn; - -mod aggregate; -mod drop; -mod filter; -mod local_relation; -mod project; -mod range; -mod read; -mod to_df; -mod with_columns; -mod with_columns_renamed; - -pub struct SparkAnalyzer<'a> { - pub psets: &'a InMemoryPartitionSetCache, -} - -impl SparkAnalyzer<'_> { - pub fn 
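// Illustrative sketch, not part of this patch: the removed helpers above all rely on the same
// arity check, converting a `Vec` into a fixed-size array to demand an exact argument count.
// A standalone version of that pattern:
fn exactly_two<T: std::fmt::Debug>(args: Vec<T>) -> eyre::Result<(T, T)> {
    // `Vec<T>` only converts into `[T; 2]` when it holds exactly two elements;
    // on failure the original vector is handed back for the error message.
    let [left, right]: [T; 2] = args
        .try_into()
        .map_err(|args: Vec<T>| eyre::eyre!("requires exactly two arguments; got {args:?}"))?;
    Ok((left, right))
}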
new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer { - SparkAnalyzer { psets: pset } - } - pub fn create_in_memory_scan( - &self, - plan_id: usize, - schema: Arc<Schema>, - tables: Vec
, - ) -> eyre::Result { - let partition_key = uuid::Uuid::new_v4().to_string(); - - let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?); - - let PartitionMetadata { - num_rows, - size_bytes, - } = pset.metadata(); - let num_partitions = pset.num_partitions(); - - self.psets.put_partition_set(&partition_key, &pset); - - let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset); - - Ok(LogicalPlanBuilder::in_memory_scan( - &partition_key, - cache_entry, - schema, - num_partitions, - size_bytes, - num_rows, - )?) - } - - pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result { - let Some(common) = relation.common else { - bail!("Common metadata is required"); - }; - - if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); - } - - let Some(rel_type) = relation.rel_type else { - bail!("Relation type is required"); - }; - - match rel_type { - RelType::Limit(l) => self - .limit(*l) - .await - .wrap_err("Failed to apply limit to logical plan"), - RelType::Range(r) => self - .range(r) - .wrap_err("Failed to apply range to logical plan"), - RelType::Project(p) => self - .project(*p) - .await - .wrap_err("Failed to apply project to logical plan"), - RelType::Aggregate(a) => self - .aggregate(*a) - .await - .wrap_err("Failed to apply aggregate to logical plan"), - RelType::WithColumns(w) => self - .with_columns(*w) - .await - .wrap_err("Failed to apply with_columns to logical plan"), - RelType::ToDf(t) => self - .to_df(*t) - .await - .wrap_err("Failed to apply to_df to logical plan"), - RelType::LocalRelation(l) => { - let Some(plan_id) = common.plan_id else { - bail!("Plan ID is required for LocalRelation"); - }; - self.local_relation(plan_id, l) - .wrap_err("Failed to apply local_relation to logical plan") - } - RelType::WithColumnsRenamed(w) => self - .with_columns_renamed(*w) - .await - .wrap_err("Failed to apply with_columns_renamed to logical plan"), - RelType::Read(r) => read::read(r) - .await - .wrap_err("Failed to apply read to logical plan"), - RelType::Drop(d) => self - .drop(*d) - .await - .wrap_err("Failed to apply drop to logical plan"), - RelType::Filter(f) => self - .filter(*f) - .await - .wrap_err("Failed to apply filter to logical plan"), - RelType::ShowString(ss) => { - let Some(plan_id) = common.plan_id else { - bail!("Plan ID is required for LocalRelation"); - }; - self.show_string(plan_id, *ss) - .await - .wrap_err("Failed to show string") - } - plan => bail!("Unsupported relation type: {plan:?}"), - } - } - - async fn limit(&self, limit: Limit) -> eyre::Result { - let Limit { input, limit } = limit; - - let Some(input) = input else { - bail!("input must be set"); - }; - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - - plan.limit(i64::from(limit), false) - .wrap_err("Failed to apply limit to logical plan") - } - - /// right now this just naively applies a limit to the logical plan - /// In the future, we want this to more closely match our daft implementation - async fn show_string( - &self, - plan_id: i64, - show_string: ShowString, - ) -> eyre::Result { - let ShowString { - input, - num_rows, - truncate: _, - vertical, - } = show_string; - - if vertical { - bail!("Vertical show string is not supported"); - } - - let Some(input) = input else { - bail!("input must be set"); - }; - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - let plan = plan.limit(num_rows as i64, true)?; - - let optimized_plan = tokio::task::spawn_blocking(move || 
plan.optimize()) - .await - .unwrap()?; - - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - let result_stream = native_executor.run(self.psets, cfg, None)?.into_stream(); - let batch = result_stream.try_collect::>().await?; - let single_batch = MicroPartition::concat(batch)?; - let tbls = single_batch.get_tables()?; - let tbl = Table::concat(&tbls)?; - let output = tbl.to_comfy_table(None).to_string(); - - let s = LiteralValue::Utf8(output) - .into_single_value_series()? - .rename("show_string"); - - let tbl = Table::from_nonempty_columns(vec![s])?; - let schema = tbl.schema.clone(); - - self.create_in_memory_scan(plan_id as _, schema, vec![tbl]) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs deleted file mode 100644 index 2a46b0cbba..0000000000 --- a/src/daft-connect/src/translation/logical_plan/aggregate.rs +++ /dev/null @@ -1,78 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, WrapErr}; -use spark_connect::aggregate::GroupType; - -use super::SparkAnalyzer; -use crate::translation::to_daft_expr; - -impl SparkAnalyzer<'_> { - pub async fn aggregate( - &self, - aggregate: spark_connect::Aggregate, - ) -> eyre::Result { - let spark_connect::Aggregate { - input, - group_type, - grouping_expressions, - aggregate_expressions, - pivot, - grouping_sets, - } = aggregate; - - let Some(input) = input else { - bail!("input is required"); - }; - - let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - - let group_type = GroupType::try_from(group_type) - .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?; - - assert_groupby(group_type)?; - - if let Some(pivot) = pivot { - bail!("Pivot not yet supported; got {pivot:?}"); - } - - if !grouping_sets.is_empty() { - bail!("Grouping sets not yet supported; got {grouping_sets:?}"); - } - - let grouping_expressions: Vec<_> = grouping_expressions - .iter() - .map(to_daft_expr) - .try_collect()?; - - let aggregate_expressions: Vec<_> = aggregate_expressions - .iter() - .map(to_daft_expr) - .try_collect()?; - - plan = plan - .aggregate(aggregate_expressions.clone(), grouping_expressions.clone()) - .wrap_err_with(|| format!("Failed to apply aggregate to logical plan aggregate_expressions={aggregate_expressions:?} grouping_expressions={grouping_expressions:?}"))?; - - Ok(plan) - } -} - -fn assert_groupby(plan: GroupType) -> eyre::Result<()> { - match plan { - GroupType::Unspecified => { - bail!("GroupType must be specified; got Unspecified") - } - GroupType::Groupby => Ok(()), - GroupType::Rollup => { - bail!("Rollup not yet supported") - } - GroupType::Cube => { - bail!("Cube not yet supported") - } - GroupType::Pivot => { - bail!("Pivot not yet supported") - } - GroupType::GroupingSets => { - bail!("GroupingSets not yet supported") - } - } -} diff --git a/src/daft-connect/src/translation/logical_plan/drop.rs b/src/daft-connect/src/translation/logical_plan/drop.rs deleted file mode 100644 index b5cac5a41b..0000000000 --- a/src/daft-connect/src/translation/logical_plan/drop.rs +++ /dev/null @@ -1,40 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::bail; - -use super::SparkAnalyzer; - -impl SparkAnalyzer<'_> { - pub async fn drop(&self, drop: spark_connect::Drop) -> eyre::Result { - let spark_connect::Drop { - input, - columns, - column_names, - } = drop; - - let Some(input) = input else { - bail!("input is required"); - }; - - 
if !columns.is_empty() { - bail!("columns is not supported; use column_names instead"); - } - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - - // Get all column names from the schema - let all_columns = plan.schema().names(); - - // Create a set of columns to drop for efficient lookup - let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect(); - - // Create expressions for all columns except the ones being dropped - let to_select = all_columns - .iter() - .filter(|col_name| !columns_to_drop.contains(*col_name)) - .map(|col_name| daft_dsl::col(col_name.clone())) - .collect(); - - // Use select to keep only the columns we want - Ok(plan.select(to_select)?) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/filter.rs b/src/daft-connect/src/translation/logical_plan/filter.rs deleted file mode 100644 index 43ad4c7a52..0000000000 --- a/src/daft-connect/src/translation/logical_plan/filter.rs +++ /dev/null @@ -1,24 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::bail; - -use super::SparkAnalyzer; -use crate::translation::to_daft_expr; - -impl SparkAnalyzer<'_> { - pub async fn filter(&self, filter: spark_connect::Filter) -> eyre::Result { - let spark_connect::Filter { input, condition } = filter; - - let Some(input) = input else { - bail!("input is required"); - }; - - let Some(condition) = condition else { - bail!("condition is required"); - }; - - let condition = to_daft_expr(&condition)?; - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - Ok(plan.filter(condition)?) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/local_relation.rs b/src/daft-connect/src/translation/logical_plan/local_relation.rs deleted file mode 100644 index bdb7c01a16..0000000000 --- a/src/daft-connect/src/translation/logical_plan/local_relation.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::{io::Cursor, sync::Arc}; - -use arrow2::io::ipc::read::{read_stream_metadata, StreamReader, StreamState}; -use daft_core::{prelude::Schema, series::Series}; -use daft_logical_plan::LogicalPlanBuilder; -use daft_table::Table; -use eyre::{bail, WrapErr}; -use itertools::zip_eq; - -use super::SparkAnalyzer; - -impl SparkAnalyzer<'_> { - pub fn local_relation( - &self, - plan_id: i64, - plan: spark_connect::LocalRelation, - ) -> eyre::Result { - // We can ignore spark schema. The true schema is sent in the - // arrow data. (see read_stream_metadata) - let spark_connect::LocalRelation { data, schema: _ } = plan; - - let Some(data) = data else { - bail!("Data is required but was not provided in the LocalRelation plan.") - }; - - let mut reader = Cursor::new(&data); - let metadata = read_stream_metadata(&mut reader)?; - - let arrow_schema = metadata.schema.clone(); - let daft_schema = Arc::new( - Schema::try_from(&arrow_schema) - .wrap_err("Failed to convert Arrow schema to Daft schema.")?, - ); - - let reader = StreamReader::new(reader, metadata, None); - - let tables = reader.into_iter().map(|ss| { - let ss = ss.wrap_err("Failed to read next chunk from StreamReader.")?; - - let chunk = match ss { - StreamState::Some(chunk) => chunk, - StreamState::Waiting => { - bail!("StreamReader is waiting for data, but a chunk was expected. 
This likely indicates that the spark provided data is incomplete.") - } - }; - - - let arrays = chunk.into_arrays(); - let columns = zip_eq(arrays, &arrow_schema.fields) - .map(|(array, arrow_field)| { - let field = Arc::new(arrow_field.into()); - - let series = Series::from_arrow(field, array) - .wrap_err("Failed to create Series from Arrow array.")?; - - Ok(series) - }) - .collect::>>()?; - - let batch = Table::from_nonempty_columns(columns)?; - - Ok(batch) - }).collect::>>()?; - - self.create_in_memory_scan(plan_id as _, daft_schema, tables) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs deleted file mode 100644 index 448242d31d..0000000000 --- a/src/daft-connect/src/translation/logical_plan/project.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Project operation for selecting and manipulating columns from a dataset -//! -//! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases, -//! and creates new columns from expressions. Example: `df.select(col("id").alias("my_number"))` - -use daft_logical_plan::LogicalPlanBuilder; -use eyre::bail; -use spark_connect::Project; - -use super::SparkAnalyzer; -use crate::translation::to_daft_expr; - -impl SparkAnalyzer<'_> { - pub async fn project(&self, project: Project) -> eyre::Result { - let Project { input, expressions } = project; - - let Some(input) = input else { - bail!("Project input is required"); - }; - - let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - - let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?; - plan = plan.select(daft_exprs)?; - - Ok(plan) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs deleted file mode 100644 index c1ec7197ad..0000000000 --- a/src/daft-connect/src/translation/logical_plan/range.rs +++ /dev/null @@ -1,62 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::{ensure, Context}; -use spark_connect::Range; - -use super::SparkAnalyzer; - -impl SparkAnalyzer<'_> { - pub fn range(&self, range: Range) -> eyre::Result { - #[cfg(not(feature = "python"))] - { - use eyre::bail; - bail!("Range operations require Python feature to be enabled"); - } - - #[cfg(feature = "python")] - { - use daft_scan::python::pylib::ScanOperatorHandle; - use pyo3::prelude::*; - let Range { - start, - end, - step, - num_partitions, - } = range; - - let partitions = num_partitions.unwrap_or(1); - - ensure!(partitions > 0, "num_partitions must be greater than 0"); - - let start = start.unwrap_or(0); - - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); - - let plan = Python::with_gil(|py| { - let range_module = PyModule::import(py, "daft.io._range") - .wrap_err("Failed to import range module")?; - - let range = range_module - .getattr(pyo3::intern!(py, "RangeScanOperator")) - .wrap_err("Failed to get range function")?; - - let range = range - .call1((start, end, step, partitions)) - .wrap_err("Failed to create range scan operator")? 
- .into_pyobject(py) - .unwrap() - .unbind(); - - let scan_operator_handle = - ScanOperatorHandle::from_python_scan_operator(range, py)?; - - let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; - - eyre::Result::<_>::Ok(plan) - }) - .wrap_err("Failed to create range scan")?; - - Ok(plan) - } - } -} diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs deleted file mode 100644 index 9a73783191..0000000000 --- a/src/daft-connect/src/translation/logical_plan/read.rs +++ /dev/null @@ -1,31 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, WrapErr}; -use spark_connect::read::ReadType; -use tracing::warn; - -mod data_source; - -pub async fn read(read: spark_connect::Read) -> eyre::Result { - let spark_connect::Read { - is_streaming, - read_type, - } = read; - - warn!("Ignoring is_streaming: {is_streaming}"); - - let Some(read_type) = read_type else { - bail!("Read type is required"); - }; - - let builder = match read_type { - ReadType::NamedTable(table) => { - let name = table.unparsed_identifier; - bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead."); - } - ReadType::DataSource(source) => data_source::data_source(source) - .await - .wrap_err("Failed to create data source"), - }?; - - Ok(builder) -} diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs deleted file mode 100644 index 863b5e8f1d..0000000000 --- a/src/daft-connect/src/translation/logical_plan/read/data_source.rs +++ /dev/null @@ -1,54 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder}; -use eyre::{bail, ensure, WrapErr}; -use tracing::warn; - -pub async fn data_source( - data_source: spark_connect::read::DataSource, -) -> eyre::Result { - let spark_connect::read::DataSource { - format, - schema, - options, - paths, - predicates, - } = data_source; - - let Some(format) = format else { - bail!("Format is required"); - }; - - ensure!(!paths.is_empty(), "Paths are required"); - - if let Some(schema) = schema { - warn!("Ignoring schema: {schema:?}; not yet implemented"); - } - - if !options.is_empty() { - warn!("Ignoring options: {options:?}; not yet implemented"); - } - - if !predicates.is_empty() { - warn!("Ignoring predicates: {predicates:?}; not yet implemented"); - } - - let plan = match &*format { - "parquet" => ParquetScanBuilder::new(paths) - .finish() - .await - .wrap_err("Failed to create parquet scan builder")?, - "csv" => CsvScanBuilder::new(paths) - .finish() - .await - .wrap_err("Failed to create csv scan builder")?, - "json" => { - // todo(completeness): implement json reading - bail!("json reading is not yet implemented"); - } - other => { - bail!("Unsupported format: {other}; only parquet and csv are supported"); - } - }; - - Ok(plan) -} diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs deleted file mode 100644 index e3d172661b..0000000000 --- a/src/daft-connect/src/translation/logical_plan/to_df.rs +++ /dev/null @@ -1,28 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, WrapErr}; - -use super::SparkAnalyzer; -impl SparkAnalyzer<'_> { - pub async fn to_df(&self, to_df: spark_connect::ToDf) -> eyre::Result { - let spark_connect::ToDf { - input, - column_names, - } = to_df; - - let Some(input) = input else { - 
bail!("Input is required"); - }; - - let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - - let column_names: Vec<_> = column_names - .iter() - .map(|s| daft_dsl::col(s.as_str())) - .collect(); - - plan = plan - .select(column_names) - .wrap_err("Failed to add columns to logical plan")?; - Ok(plan) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs deleted file mode 100644 index 97b3c3d1d1..0000000000 --- a/src/daft-connect/src/translation/logical_plan/with_columns.rs +++ /dev/null @@ -1,35 +0,0 @@ -use daft_logical_plan::LogicalPlanBuilder; -use eyre::bail; -use spark_connect::{expression::ExprType, Expression}; - -use super::SparkAnalyzer; -use crate::translation::to_daft_expr; - -impl SparkAnalyzer<'_> { - pub async fn with_columns( - &self, - with_columns: spark_connect::WithColumns, - ) -> eyre::Result { - let spark_connect::WithColumns { input, aliases } = with_columns; - - let Some(input) = input else { - bail!("input is required"); - }; - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - - let daft_exprs: Vec<_> = aliases - .into_iter() - .map(|alias| { - let expression = Expression { - common: None, - expr_type: Some(ExprType::Alias(Box::new(alias))), - }; - - to_daft_expr(&expression) - }) - .try_collect()?; - - Ok(plan.with_columns(daft_exprs)?) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs deleted file mode 100644 index 856a7214fc..0000000000 --- a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs +++ /dev/null @@ -1,48 +0,0 @@ -use daft_dsl::col; -use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, Context}; - -use crate::translation::SparkAnalyzer; - -impl SparkAnalyzer<'_> { - pub async fn with_columns_renamed( - &self, - with_columns_renamed: spark_connect::WithColumnsRenamed, - ) -> eyre::Result { - let spark_connect::WithColumnsRenamed { - input, - rename_columns_map, - renames, - } = with_columns_renamed; - - let Some(input) = input else { - bail!("Input is required"); - }; - - let plan = Box::pin(self.to_logical_plan(*input)).await?; - - // todo: let's implement this directly into daft - - // Convert the rename mappings into expressions - let rename_exprs = if !rename_columns_map.is_empty() { - // Use rename_columns_map if provided (legacy format) - rename_columns_map - .into_iter() - .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) - .collect() - } else { - // Use renames if provided (new format) - renames - .into_iter() - .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) - .collect() - }; - - // Apply the rename expressions to the plan - let plan = plan - .select(rename_exprs) - .wrap_err("Failed to apply rename expressions to logical plan")?; - - Ok(plan) - } -} diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs deleted file mode 100644 index 0cbd3cd7a1..0000000000 --- a/src/daft-connect/src/translation/schema.rs +++ /dev/null @@ -1,54 +0,0 @@ -use daft_micropartition::partitioning::InMemoryPartitionSetCache; -use daft_schema::schema::SchemaRef; -use spark_connect::{ - data_type::{Kind, Struct, StructField}, - DataType, Relation, -}; -use tracing::warn; - -use super::SparkAnalyzer; -use crate::translation::to_spark_datatype; - -#[tracing::instrument(skip_all)] -pub async fn 
relation_to_spark_schema(input: Relation) -> eyre::Result { - let result = relation_to_daft_schema(input).await?; - - let fields: eyre::Result> = result - .fields - .iter() - .map(|(name, field)| { - let field_type = to_spark_datatype(&field.dtype); - Ok(StructField { - name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? - data_type: Some(field_type), - nullable: true, // todo(correctness): is this correct? - metadata: None, // todo(completeness): might want to add metadata here - }) - }) - .collect(); - - Ok(DataType { - kind: Some(Kind::Struct(Struct { - fields: fields?, - type_variation_reference: 0, - })), - }) -} - -#[tracing::instrument(skip_all)] -pub async fn relation_to_daft_schema(input: Relation) -> eyre::Result { - if let Some(common) = &input.common { - if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); - } - } - - // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used - let pset = InMemoryPartitionSetCache::empty(); - let translator = SparkAnalyzer::new(&pset); - let plan = Box::pin(translator.to_logical_plan(input)).await?; - - let result = plan.schema(); - - Ok(result) -} diff --git a/src/daft-connect/src/util.rs b/src/daft-connect/src/util.rs index cbec2211b2..8ebdb79903 100644 --- a/src/daft-connect/src/util.rs +++ b/src/daft-connect/src/util.rs @@ -1,5 +1,7 @@ use std::net::ToSocketAddrs; +use tonic::Status; + pub fn parse_spark_connect_address(addr: &str) -> eyre::Result { // Check if address starts with "sc://" if !addr.starts_with("sc://") { @@ -19,6 +21,26 @@ pub fn parse_spark_connect_address(addr: &str) -> eyre::Result { + /// Converts an optional protobuf field to a different type, returning an + /// error if None. 
+ fn required(self, field: impl Into) -> Result; +} + +impl FromOptionalField for Option { + fn required(self, field: impl Into) -> Result { + match self { + None => Err(Status::internal(format!( + "Required field '{}' is missing", + field.into() + ))), + Some(t) => Ok(t), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/daft-core/src/array/ops/binary.rs b/src/daft-core/src/array/ops/binary.rs new file mode 100644 index 0000000000..6f1ed25556 --- /dev/null +++ b/src/daft-core/src/array/ops/binary.rs @@ -0,0 +1,308 @@ +use std::iter; + +use arrow2::bitmap::utils::{BitmapIter, ZipValidity}; +use common_error::{DaftError, DaftResult}; + +use crate::{ + array::ops::as_arrow::AsArrow, + datatypes::{ + BinaryArray, DaftIntegerType, DaftNumericType, DataArray, FixedSizeBinaryArray, UInt64Array, + }, +}; + +enum BroadcastedBinaryIter<'a> { + Repeat(std::iter::Take>>), + NonRepeat( + ZipValidity< + &'a [u8], + arrow2::array::ArrayValuesIter<'a, arrow2::array::BinaryArray>, + arrow2::bitmap::utils::BitmapIter<'a>, + >, + ), +} + +enum BroadcastedFixedSizeBinaryIter<'a> { + Repeat(std::iter::Take>>), + NonRepeat(ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>), +} + +enum BroadcastedNumericIter<'a, T: 'a, U> +where + T: DaftIntegerType, + T::Native: TryInto + Ord, +{ + Repeat( + std::iter::Take::Native>>>, + std::marker::PhantomData, + ), + NonRepeat( + ZipValidity< + &'a ::Native, + std::slice::Iter<'a, ::Native>, + BitmapIter<'a>, + >, + std::marker::PhantomData, + ), +} + +impl<'a> Iterator for BroadcastedFixedSizeBinaryIter<'a> { + type Item = Option<&'a [u8]>; + + fn next(&mut self) -> Option { + match self { + BroadcastedFixedSizeBinaryIter::Repeat(iter) => iter.next(), + BroadcastedFixedSizeBinaryIter::NonRepeat(iter) => iter.next(), + } + } +} + +impl<'a> Iterator for BroadcastedBinaryIter<'a> { + type Item = Option<&'a [u8]>; + + fn next(&mut self) -> Option { + match self { + BroadcastedBinaryIter::Repeat(iter) => iter.next(), + BroadcastedBinaryIter::NonRepeat(iter) => iter.next(), + } + } +} + +impl<'a, T: 'a, U> Iterator for BroadcastedNumericIter<'a, T, U> +where + T: DaftIntegerType + Clone, + T::Native: TryInto + Ord, +{ + type Item = DaftResult>; + + fn next(&mut self) -> Option { + match self { + BroadcastedNumericIter::Repeat(iter, _) => iter.next().map(|x| { + x.map(|x| { + x.try_into().map_err(|_| { + DaftError::ComputeError( + "Failed to cast numeric value to target type".to_string(), + ) + }) + }) + .transpose() + }), + BroadcastedNumericIter::NonRepeat(iter, _) => iter.next().map(|x| { + x.map(|x| { + (*x).try_into().map_err(|_| { + DaftError::ComputeError( + "Failed to cast numeric value to target type".to_string(), + ) + }) + }) + .transpose() + }), + } + } +} + +fn create_broadcasted_binary_iter(arr: &BinaryArray, len: usize) -> BroadcastedBinaryIter<'_> { + if arr.len() == 1 { + BroadcastedBinaryIter::Repeat(std::iter::repeat(arr.as_arrow().get(0)).take(len)) + } else { + BroadcastedBinaryIter::NonRepeat(arr.as_arrow().iter()) + } +} + +fn create_broadcasted_fixed_size_binary_iter( + arr: &FixedSizeBinaryArray, + len: usize, +) -> BroadcastedFixedSizeBinaryIter<'_> { + if arr.len() == 1 { + BroadcastedFixedSizeBinaryIter::Repeat(iter::repeat(arr.as_arrow().get(0)).take(len)) + } else { + BroadcastedFixedSizeBinaryIter::NonRepeat(arr.as_arrow().iter()) + } +} + +fn create_broadcasted_numeric_iter( + arr: &DataArray, + len: usize, +) -> BroadcastedNumericIter +where + T: DaftIntegerType, + T::Native: TryInto + Ord, +{ + if arr.len() 
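// Illustrative sketch, not part of this patch: the broadcasted iterators above implement a
// simple convention that the binary kernels below rely on: a length-1 array acts as a scalar
// and is repeated to the length of the other operand, and a null on either side yields a null
// output. The same rule restated with plain std iterators:
fn broadcast<'a, T: Copy>(values: &'a [Option<T>], len: usize) -> Box<dyn Iterator<Item = Option<T>> + 'a> {
    if values.len() == 1 {
        // Single element: repeat it so it lines up with the longer side.
        Box::new(std::iter::repeat(values[0]).take(len))
    } else {
        // Full-length input: iterate as-is (the caller ensures values.len() == len).
        Box::new(values.iter().copied())
    }
}

fn concat_with_broadcast(left: &[Option<&[u8]>], right: &[Option<&[u8]>]) -> Vec<Option<Vec<u8>>> {
    let len = left.len().max(right.len());
    broadcast(left, len)
        .zip(broadcast(right, len))
        .map(|(l, r)| match (l, r) {
            (Some(l), Some(r)) => Some([l, r].concat()),
            // Mirror the kernels: any null operand produces a null result.
            _ => None,
        })
        .collect()
}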
== 1 { + BroadcastedNumericIter::Repeat( + iter::repeat(arr.as_arrow().get(0)).take(len), + std::marker::PhantomData, + ) + } else { + let x = arr.as_arrow().iter(); + BroadcastedNumericIter::NonRepeat(x, std::marker::PhantomData) + } +} + +impl BinaryArray { + pub fn length(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let offsets = self_arrow.offsets(); + let arrow_result = arrow2::array::UInt64Array::from_iter( + offsets.windows(2).map(|w| Some((w[1] - w[0]) as u64)), + ) + .with_validity(self_arrow.validity().cloned()); + Ok(UInt64Array::from((self.name(), Box::new(arrow_result)))) + } + + pub fn binary_concat(&self, other: &Self) -> DaftResult { + let self_arrow = self.as_arrow(); + let other_arrow = other.as_arrow(); + + if self_arrow.len() == 0 || other_arrow.len() == 0 { + return Ok(Self::from(( + self.name(), + Box::new(arrow2::array::BinaryArray::::new_empty( + self_arrow.data_type().clone(), + )), + ))); + } + + let output_len = if self_arrow.len() == 1 || other_arrow.len() == 1 { + std::cmp::max(self_arrow.len(), other_arrow.len()) + } else { + self_arrow.len() + }; + + let self_iter = create_broadcasted_binary_iter(self, output_len); + let other_iter = create_broadcasted_binary_iter(other, output_len); + + let arrow_result = self_iter + .zip(other_iter) + .map(|(left_val, right_val)| match (left_val, right_val) { + (Some(left), Some(right)) => Some([left, right].concat()), + _ => None, + }) + .collect::>(); + + Ok(Self::from((self.name(), Box::new(arrow_result)))) + } + + pub fn binary_slice( + &self, + start: &DataArray, + length: Option<&DataArray>, + ) -> DaftResult + where + I: DaftIntegerType, + ::Native: Ord + TryInto, + J: DaftIntegerType, + ::Native: Ord + TryInto, + { + let self_arrow = self.as_arrow(); + let output_len = if self_arrow.len() == 1 { + std::cmp::max(start.len(), length.map_or(1, |l| l.len())) + } else { + self_arrow.len() + }; + + let self_iter = create_broadcasted_binary_iter(self, output_len); + let start_iter = create_broadcasted_numeric_iter::(start, output_len); + + let mut builder = arrow2::array::MutableBinaryArray::::new(); + + let arrow_result = match length { + Some(length) => { + let length_iter = create_broadcasted_numeric_iter::(length, output_len); + + for ((val, start), length) in self_iter.zip(start_iter).zip(length_iter) { + match (val, start?, length?) { + (Some(val), Some(start), Some(length)) => { + if start >= val.len() || length == 0 { + builder.push(Some(&[])); + } else { + let end = (start + length).min(val.len()); + let slice = &val[start..end]; + builder.push(Some(slice)); + } + } + _ => { + builder.push::<&[u8]>(None); + } + } + } + builder.into() + } + None => { + for (val, start) in self_iter.zip(start_iter) { + match (val, start?) 
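// Illustrative sketch, not part of this patch: the slicing rule implemented by `binary_slice`
// in both of its branches, restated over plain slices. Out-of-range starts and zero lengths
// give an empty value rather than an error, and the end is clamped to the input length:
fn slice_rule(val: &[u8], start: usize, length: Option<usize>) -> &[u8] {
    if start >= val.len() || length == Some(0) {
        return &[];
    }
    match length {
        Some(len) => &val[start..(start + len).min(val.len())],
        None => &val[start..],
    }
}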
{ + (Some(val), Some(start)) => { + if start >= val.len() { + builder.push(Some(&[])); + } else { + let slice = &val[start..]; + builder.push(Some(slice)); + } + } + _ => { + builder.push::<&[u8]>(None); + } + } + } + builder.into() + } + }; + + Ok(Self::from((self.name(), Box::new(arrow_result)))) + } +} + +impl FixedSizeBinaryArray { + pub fn length(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let size = self_arrow.size(); + let arrow_result = arrow2::array::UInt64Array::from_iter( + iter::repeat(Some(size as u64)).take(self_arrow.len()), + ) + .with_validity(self_arrow.validity().cloned()); + Ok(UInt64Array::from((self.name(), Box::new(arrow_result)))) + } + + pub fn binary_concat(&self, other: &Self) -> std::result::Result { + let self_arrow = self.as_arrow(); + let other_arrow = other.as_arrow(); + let self_size = self_arrow.size(); + let other_size = other_arrow.size(); + let combined_size = self_size + other_size; + + // Create a new FixedSizeBinaryArray with the combined size + let mut values = Vec::with_capacity(self_arrow.len() * combined_size); + let mut validity = arrow2::bitmap::MutableBitmap::new(); + + let output_len = if self_arrow.len() == 1 || other_arrow.len() == 1 { + std::cmp::max(self_arrow.len(), other_arrow.len()) + } else { + self_arrow.len() + }; + + let self_iter = create_broadcasted_fixed_size_binary_iter(self, output_len); + let other_iter = create_broadcasted_fixed_size_binary_iter(other, output_len); + + for (val1, val2) in self_iter.zip(other_iter) { + match (val1, val2) { + (Some(val1), Some(val2)) => { + values.extend_from_slice(val1); + values.extend_from_slice(val2); + validity.push(true); + } + _ => { + values.extend(std::iter::repeat(0u8).take(combined_size)); + validity.push(false); + } + } + } + + // Create a new FixedSizeBinaryArray with the combined size + let result = arrow2::array::FixedSizeBinaryArray::try_new( + arrow2::datatypes::DataType::FixedSizeBinary(combined_size), + values.into(), + Some(validity.into()), + )?; + + Ok(Self::from((self.name(), Box::new(result)))) + } +} diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 2b9f855286..f6f7d2b216 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -69,6 +69,83 @@ where } } + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => { + let l_validity = self.as_arrow().validity(); + let r_validity = rhs.as_arrow().validity(); + + let mut result_values = comparison::eq(self.as_arrow(), rhs.as_arrow()) + .values() + .clone(); + + match (l_validity, r_validity) { + (None, None) => {} + (None, Some(r_valid)) => { + result_values = arrow2::bitmap::and(&result_values, r_valid); + } + (Some(l_valid), None) => { + result_values = arrow2::bitmap::and(&result_values, l_valid); + } + (Some(l_valid), Some(r_valid)) => { + let nulls_match = arrow2::bitmap::bitwise_eq(l_valid, r_valid); + result_values = arrow2::bitmap::and(&result_values, &nulls_match); + } + } + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + Ok(self.eq_null_safe(value)) + } else { + let result_values = match self.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(l_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + 
arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + Ok(rhs.eq_null_safe(value)) + } else { + let result_values = match rhs.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(r_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { @@ -335,6 +412,30 @@ where NumCast::from(rhs).expect("could not cast to underlying DataArray type"); self.compare_to_scalar(rhs, comparison::gt_eq_scalar) } + + fn eq_null_safe(&self, rhs: Scalar) -> Self::Output { + let rhs: T::Native = + NumCast::from(rhs).expect("could not cast to underlying DataArray type"); + + let arrow_array = self.as_arrow(); + let scalar = PrimitiveScalar::new(arrow_array.data_type().clone(), Some(rhs)); + + let result_values = comparison::eq_scalar(arrow_array, &scalar).values().clone(); + + let final_values = match arrow_array.validity() { + None => result_values, + Some(valid) => arrow2::bitmap::and(&result_values, valid), + }; + + BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + final_values, + None, + ), + )) + } } impl DaftCompare<&Self> for BooleanArray { @@ -531,6 +632,83 @@ impl DaftCompare<&Self> for BooleanArray { ))), } } + + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => { + let l_validity = self.as_arrow().validity(); + let r_validity = rhs.as_arrow().validity(); + + let mut result_values = comparison::eq(self.as_arrow(), rhs.as_arrow()) + .values() + .clone(); + + match (l_validity, r_validity) { + (None, None) => {} + (None, Some(r_valid)) => { + result_values = arrow2::bitmap::and(&result_values, r_valid); + } + (Some(l_valid), None) => { + result_values = arrow2::bitmap::and(&result_values, l_valid); + } + (Some(l_valid), Some(r_valid)) => { + let nulls_match = arrow2::bitmap::bitwise_eq(l_valid, r_valid); + result_values = arrow2::bitmap::and(&result_values, &nulls_match); + } + } + + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + Ok(self.eq_null_safe(value)?) + } else { + let result_values = match self.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(l_size), + Some(validity) => validity.not(), + }; + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + Ok(rhs.eq_null_safe(value)?) 
+ } else { + let result_values = match rhs.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(r_size), + Some(validity) => validity.not(), + }; + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } } impl DaftCompare for BooleanArray { @@ -583,6 +761,26 @@ impl DaftCompare for BooleanArray { Ok(Self::from((self.name(), arrow_result))) } + + fn eq_null_safe(&self, rhs: bool) -> Self::Output { + let result_values = comparison::boolean::eq_scalar(self.as_arrow(), rhs) + .values() + .clone(); + + let final_values = match self.as_arrow().validity() { + None => result_values, + Some(valid) => arrow2::bitmap::and(&result_values, valid), + }; + + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + final_values, + None, + ), + ))) + } } impl Not for &BooleanArray { @@ -819,6 +1017,7 @@ impl DaftCompare<&Self> for NullArray { null_array_comparison_method!(lte); null_array_comparison_method!(gt); null_array_comparison_method!(gte); + null_array_comparison_method!(eq_null_safe); } impl DaftLogical for BooleanArray { @@ -1104,6 +1303,83 @@ impl DaftCompare<&Self> for Utf8Array { ))), } } + + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => { + let l_validity = self.as_arrow().validity(); + let r_validity = rhs.as_arrow().validity(); + + let mut result_values = comparison::eq(self.as_arrow(), rhs.as_arrow()) + .values() + .clone(); + + match (l_validity, r_validity) { + (None, None) => {} + (None, Some(r_valid)) => { + result_values = arrow2::bitmap::and(&result_values, r_valid); + } + (Some(l_valid), None) => { + result_values = arrow2::bitmap::and(&result_values, l_valid); + } + (Some(l_valid), Some(r_valid)) => { + let nulls_match = arrow2::bitmap::bitwise_eq(l_valid, r_valid); + result_values = arrow2::bitmap::and(&result_values, &nulls_match); + } + } + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + Ok(self.eq_null_safe(value)?) + } else { + let result_values = match self.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(l_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + Ok(rhs.eq_null_safe(value)?) 
+ } else { + let result_values = match rhs.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(r_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } } impl DaftCompare<&str> for Utf8Array { @@ -1156,6 +1432,28 @@ impl DaftCompare<&str> for Utf8Array { Ok(BooleanArray::from((self.name(), arrow_result))) } + + fn eq_null_safe(&self, rhs: &str) -> Self::Output { + let arrow_array = self.as_arrow(); + + let result_values = comparison::utf8::eq_scalar(arrow_array, rhs) + .values() + .clone(); + + let final_values = match arrow_array.validity() { + None => result_values, + Some(valid) => arrow2::bitmap::and(&result_values, valid), + }; + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + final_values, + None, + ), + ))) + } } impl DaftCompare<&Self> for BinaryArray { @@ -1400,6 +1698,83 @@ impl DaftCompare<&Self> for BinaryArray { ))), } } + + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => { + let l_validity = self.as_arrow().validity(); + let r_validity = rhs.as_arrow().validity(); + + let mut result_values = comparison::eq(self.as_arrow(), rhs.as_arrow()) + .values() + .clone(); + + match (l_validity, r_validity) { + (None, None) => {} + (None, Some(r_valid)) => { + result_values = arrow2::bitmap::and(&result_values, r_valid); + } + (Some(l_valid), None) => { + result_values = arrow2::bitmap::and(&result_values, l_valid); + } + (Some(l_valid), Some(r_valid)) => { + let nulls_match = arrow2::bitmap::bitwise_eq(l_valid, r_valid); + result_values = arrow2::bitmap::and(&result_values, &nulls_match); + } + } + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + Ok(self.eq_null_safe(value)?) + } else { + let result_values = match self.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(l_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + Ok(rhs.eq_null_safe(value)?) 
+ } else { + let result_values = match rhs.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(r_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } } impl DaftCompare<&[u8]> for BinaryArray { @@ -1452,6 +1827,28 @@ impl DaftCompare<&[u8]> for BinaryArray { Ok(BooleanArray::from((self.name(), arrow_result))) } + + fn eq_null_safe(&self, rhs: &[u8]) -> Self::Output { + let arrow_array = self.as_arrow(); + + let result_values = comparison::binary::eq_scalar(arrow_array, rhs) + .values() + .clone(); + + let final_values = match arrow_array.validity() { + None => result_values, + Some(valid) => arrow2::bitmap::and(&result_values, valid), + }; + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + final_values, + None, + ), + ))) + } } fn compare_fixed_size_binary( @@ -1711,6 +2108,83 @@ impl DaftCompare<&Self> for FixedSizeBinaryArray { ))), } } + + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => { + let l_validity = self.as_arrow().validity(); + let r_validity = rhs.as_arrow().validity(); + + let mut result_values = comparison::eq(self.as_arrow(), rhs.as_arrow()) + .values() + .clone(); + + match (l_validity, r_validity) { + (None, None) => {} + (None, Some(r_valid)) => { + result_values = arrow2::bitmap::and(&result_values, r_valid); + } + (Some(l_valid), None) => { + result_values = arrow2::bitmap::and(&result_values, l_valid); + } + (Some(l_valid), Some(r_valid)) => { + let nulls_match = arrow2::bitmap::bitwise_eq(l_valid, r_valid); + result_values = arrow2::bitmap::and(&result_values, &nulls_match); + } + } + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + Ok(self.eq_null_safe(value)?) + } else { + let result_values = match self.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(l_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + Ok(rhs.eq_null_safe(value)?)
+ } else { + let result_values = match rhs.as_arrow().validity() { + None => arrow2::bitmap::Bitmap::new_zeroed(r_size), + Some(validity) => validity.not(), + }; + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + result_values, + None, + ), + ))) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } } impl DaftCompare<&[u8]> for FixedSizeBinaryArray { @@ -1739,6 +2215,28 @@ impl DaftCompare<&[u8]> for FixedSizeBinaryArray { fn gte(&self, rhs: &[u8]) -> Self::Output { cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs >= rhs) } + + fn eq_null_safe(&self, rhs: &[u8]) -> Self::Output { + let arrow_array = self.as_arrow(); + + let result_values = arrow2::bitmap::Bitmap::from_trusted_len_iter( + arrow_array.values_iter().map(|lhs| lhs == rhs), + ); + + let final_values = match arrow_array.validity() { + None => result_values, + Some(valid) => arrow2::bitmap::and(&result_values, valid), + }; + + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + final_values, + None, + ), + ))) + } } #[cfg(test)] diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 97583cca0d..89392c6c8b 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -7,6 +7,7 @@ mod arithmetic; pub mod arrow2; pub mod as_arrow; mod between; +mod binary; mod bitwise; pub(crate) mod broadcast; pub(crate) mod cast; @@ -75,6 +76,9 @@ pub trait DaftCompare { /// equality. fn equal(&self, rhs: Rhs) -> Self::Output; + /// null-safe equality. + fn eq_null_safe(&self, rhs: Rhs) -> Self::Output; + /// inequality. 
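The `eq_null_safe` kernels added above (for primitive, boolean, UTF-8, binary and fixed-size-binary arrays) all implement the SQL `<=>` semantics declared here on `DaftCompare`: present values compare with ordinary equality, two NULLs compare as true, and NULL against a non-NULL compares as false, so the resulting boolean array is built with no validity mask. A minimal sketch of that truth table over `Option` values (editor's illustration only; this is not the array kernel itself):

```rust
/// Null-safe equality over scalars, mirroring the `<=>` semantics of the kernels above.
fn eq_null_safe_scalar<T: PartialEq>(l: Option<&T>, r: Option<&T>) -> bool {
    match (l, r) {
        (Some(a), Some(b)) => a == b, // both present: ordinary equality
        (None, None) => true,         // NULL <=> NULL is true (plain `==` would yield NULL)
        _ => false,                   // NULL <=> non-NULL is false
    }
}

fn main() {
    assert!(eq_null_safe_scalar::<i32>(None, None));
    assert!(!eq_null_safe_scalar(Some(&1), None));
    assert!(eq_null_safe_scalar(Some(&1), Some(&1)));
    assert!(!eq_null_safe_scalar(Some(&1), Some(&2)));
}
```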
fn not_equal(&self, rhs: Rhs) -> Self::Output; diff --git a/src/daft-core/src/series/ops/binary.rs b/src/daft-core/src/series/ops/binary.rs new file mode 100644 index 0000000000..6a69bd6331 --- /dev/null +++ b/src/daft-core/src/series/ops/binary.rs @@ -0,0 +1,22 @@ +use common_error::{DaftError, DaftResult}; + +use crate::{datatypes::*, series::Series}; + +impl Series { + pub fn with_binary_array( + &self, + f: impl Fn(&BinaryArray) -> DaftResult, + ) -> DaftResult { + match self.data_type() { + DataType::Binary => f(self.binary()?), + DataType::FixedSizeBinary(_) => Err(DaftError::TypeError(format!( + "Operation not implemented for type {}", + self.data_type() + ))), + DataType::Null => Ok(self.clone()), + dt => Err(DaftError::TypeError(format!( + "Operation not implemented for type {dt}" + ))), + } + } +} diff --git a/src/daft-core/src/series/ops/comparison.rs b/src/daft-core/src/series/ops/comparison.rs index 2d0fd65c79..8eb18bd20a 100644 --- a/src/daft-core/src/series/ops/comparison.rs +++ b/src/daft-core/src/series/ops/comparison.rs @@ -65,4 +65,5 @@ impl DaftCompare<&Self> for Series { impl_compare_method!(lte, le); impl_compare_method!(gt, gt); impl_compare_method!(gte, ge); + impl_compare_method!(eq_null_safe, eq_null_safe); } diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 4e5b9b404d..0fc54cb4fc 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -7,6 +7,7 @@ pub mod abs; pub mod agg; pub mod arithmetic; pub mod between; +pub mod binary; pub mod broadcast; pub mod cast; pub mod cbrt; diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index 6e04a977aa..87b2bc1bbc 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -10,10 +10,8 @@ daft-sketch = {path = "../daft-sketch", default-features = false} derive_more = {workspace = true} indexmap = {workspace = true} itertools = {workspace = true} -log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true} -typed-builder = {workspace = true} typetag = {workspace = true} [features] diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index a8df3040a0..460c7bc3d9 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -5,6 +5,7 @@ use std::{ any::Any, hash::{DefaultHasher, Hash, Hasher}, io::{self, Write}, + str::FromStr, sync::Arc, }; @@ -999,7 +1000,8 @@ impl Expr { | Operator::Eq | Operator::NotEq | Operator::LtEq - | Operator::GtEq => { + | Operator::GtEq + | Operator::EqNullSafe => { let (result_type, _intermediate, _comp_type) = InferDataType::from(&left_field.dtype) .comparison_op(&InferDataType::from(&right_field.dtype))?; @@ -1153,6 +1155,7 @@ impl Expr { to_sql_inner(left, buffer)?; let op = match op { Operator::Eq => "=", + Operator::EqNullSafe => "<=>", Operator::NotEq => "!=", Operator::Lt => "<", Operator::LtEq => "<=", @@ -1247,12 +1250,18 @@ impl Expr { Self::InSubquery(expr, _) => expr.has_compute(), } } + + pub fn eq_null_safe(self: ExprRef, other: ExprRef) -> ExprRef { + binary_op(Operator::EqNullSafe, self, other) + } } #[derive(Display, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum Operator { #[display("==")] Eq, + #[display("<=>")] + EqNullSafe, #[display("!=")] NotEq, #[display("<")] @@ -1293,6 +1302,7 @@ impl Operator { matches!( self, Self::Eq + | Self::EqNullSafe | Self::NotEq | Self::Lt | Self::LtEq @@ -1309,6 +1319,32 @@ impl Operator { } } +impl FromStr for Operator { + type Err = DaftError; + 
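On the DSL side, `Operator::EqNullSafe` displays and serializes to SQL as `<=>`, and the new `ExprRef::eq_null_safe` helper builds the corresponding `BinaryOp`. A hedged usage sketch (assumes the `daft-dsl` crate as a dependency; `col` and `lit` are its existing expression constructors):

```rust
use daft_dsl::{col, lit};

fn main() {
    // Equivalent to the SQL predicate `a <=> 1`: true where a == 1,
    // false where a is NULL, and never NULL itself.
    let predicate = col("a").eq_null_safe(lit(1));
    let _ = predicate;
}
```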
fn from_str(s: &str) -> Result { + match s { + "==" => Ok(Self::Eq), + "!=" => Ok(Self::NotEq), + "<" => Ok(Self::Lt), + "<=" => Ok(Self::LtEq), + ">" => Ok(Self::Gt), + ">=" => Ok(Self::GtEq), + "+" => Ok(Self::Plus), + "-" => Ok(Self::Minus), + "*" => Ok(Self::Multiply), + "/" => Ok(Self::TrueDivide), + "//" => Ok(Self::FloorDivide), + "%" => Ok(Self::Modulus), + "&" => Ok(Self::And), + "|" => Ok(Self::Or), + "^" => Ok(Self::Xor), + "<<" => Ok(Self::ShiftLeft), + ">>" => Ok(Self::ShiftRight), + _ => Err(DaftError::ComputeError(format!("Invalid operator: {}", s))), + } + } +} + // Check if one set of columns is a reordering of the other pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool { // sort a and b by name @@ -1353,3 +1389,85 @@ pub fn count_actor_pool_udfs(exprs: &[ExprRef]) -> usize { }) .sum() } + +pub fn estimated_selectivity(expr: &Expr, schema: &Schema) -> f64 { + match expr { + // Boolean operations that filter rows + Expr::BinaryOp { op, left, right } => { + let left_selectivity = estimated_selectivity(left, schema); + let right_selectivity = estimated_selectivity(right, schema); + match op { + // Fixed selectivity for all common comparisons + Operator::Eq => 0.1, + Operator::EqNullSafe => 0.1, + Operator::NotEq => 0.9, + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => 0.2, + + // Logical operators with fixed estimates + // P(A and B) = P(A) * P(B) + Operator::And => left_selectivity * right_selectivity, + // P(A or B) = P(A) + P(B) - P(A and B) + Operator::Or => left_selectivity + .mul_add(-right_selectivity, left_selectivity + right_selectivity), + // P(A xor B) = P(A) + P(B) - 2 * P(A and B) + Operator::Xor => 2.0f64.mul_add( + -(left_selectivity * right_selectivity), + left_selectivity + right_selectivity, + ), + + // Non-boolean operators don't filter + Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::TrueDivide + | Operator::FloorDivide + | Operator::Modulus + | Operator::ShiftLeft + | Operator::ShiftRight => 1.0, + } + } + + // Revert selectivity for NOT + Expr::Not(expr) => 1.0 - estimated_selectivity(expr, schema), + + // Fixed selectivity for IS NULL and IS NOT NULL, assume not many nulls + Expr::IsNull(_) => 0.1, + Expr::NotNull(_) => 0.9, + + // All membership operations use same selectivity + Expr::IsIn(_, _) | Expr::Between(_, _, _) | Expr::InSubquery(_, _) | Expr::Exists(_) => 0.2, + + // Pass through for expressions that wrap other expressions + Expr::Cast(expr, _) | Expr::Alias(expr, _) => estimated_selectivity(expr, schema), + + // Boolean literals + Expr::Literal(lit) => match lit { + lit::LiteralValue::Boolean(true) => 1.0, + lit::LiteralValue::Boolean(false) => 0.0, + _ => 1.0, + }, + + // Everything else that could be boolean gets 0.2, non-boolean gets 1.0 + Expr::ScalarFunction(_) + | Expr::Function { .. } + | Expr::Column(_) + | Expr::OuterReferenceColumn(_) + | Expr::IfElse { .. 
} + | Expr::FillNull(_, _) => match expr.to_field(schema) { + Ok(field) if field.dtype == DataType::Boolean => 0.2, + _ => 1.0, + }, + + // Everything else doesn't filter + Expr::Subquery(_) => 1.0, + Expr::Agg(_) => panic!("Aggregates are not allowed in WHERE clauses"), + } +} + +pub fn exprs_to_schema(exprs: &[ExprRef], input_schema: SchemaRef) -> DaftResult { + let fields = exprs + .iter() + .map(|e| e.to_field(&input_schema)) + .collect::>()?; + Ok(Arc::new(Schema::new(fields)?)) +} diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index fe7c44f068..c29ca9f779 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -11,18 +11,16 @@ pub mod optimization; mod pyobj_serde; #[cfg(feature = "python")] pub mod python; -mod resolve_expr; mod treenode; pub use common_treenode; pub use expr::{ - binary_op, col, count_actor_pool_udfs, has_agg, is_actor_pool_udf, is_partition_compatible, - AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType, - Subquery, SubqueryPlan, + binary_op, col, count_actor_pool_udfs, estimated_selectivity, exprs_to_schema, has_agg, + is_actor_pool_udf, is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, + Operator, OuterReferenceColumn, SketchType, Subquery, SubqueryPlan, }; pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] use pyo3::prelude::*; -pub use resolve_expr::{check_column_name_validity, ExprResolver}; #[cfg(feature = "python")] pub fn register_modules(parent: &Bound) -> PyResult<()> { @@ -41,10 +39,6 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction!(python::initialize_udfs, parent)?)?; parent.add_function(wrap_pyfunction!(python::get_udf_names, parent)?)?; parent.add_function(wrap_pyfunction!(python::eq, parent)?)?; - parent.add_function(wrap_pyfunction!( - python::check_column_name_validity, - parent - )?)?; Ok(()) } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a12058f4bd..df380bd154 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -257,11 +257,6 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult { Ok(expr1.expr == expr2.expr) } -#[pyfunction] -pub fn check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> { - Ok(crate::check_column_name_validity(name, &schema.schema)?) 
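To make the combination rules in `estimated_selectivity` concrete: with the fixed per-operator estimates above, an equality filter (0.1) combined with a range filter (0.2) gives 0.02 under AND, 0.28 under OR and 0.26 under XOR. A standalone check of that arithmetic (editor's illustration; the constants are the fixed estimates used above):

```rust
fn main() {
    let (p_a, p_b): (f64, f64) = (0.1, 0.2); // e.g. an Eq filter and a Lt filter
    let and = p_a * p_b;                     // P(A and B) = P(A) * P(B)
    let or = p_a + p_b - p_a * p_b;          // P(A or B)  = P(A) + P(B) - P(A and B)
    let xor = p_a + p_b - 2.0 * p_a * p_b;   // P(A xor B) = P(A) + P(B) - 2 * P(A and B)
    assert!((and - 0.02).abs() < 1e-12);
    assert!((or - 0.28).abs() < 1e-12);
    assert!((xor - 0.26).abs() < 1e-12);
}
```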
-} - #[derive(FromPyObject)] pub enum ApproxPercentileInput { Single(f64), @@ -430,6 +425,10 @@ impl PyExpr { Ok(self.expr.clone().fill_null(fill_value.expr.clone()).into()) } + pub fn eq_null_safe(&self, other: &Self) -> PyResult { + Ok(crate::binary_op(crate::Operator::EqNullSafe, self.into(), other.into()).into()) + } + pub fn is_in(&self, other: Vec) -> PyResult { let other = other.into_iter().map(|e| e.into()).collect(); diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index 498c5fb590..03dcc18d19 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -3,7 +3,6 @@ arrow2 = {workspace = true} base64 = {workspace = true} common-error = {path = "../common/error", default-features = false} common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"} -common-io-config = {path = "../common/io-config", default-features = false} common-runtime = {path = "../common/runtime", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} @@ -25,7 +24,6 @@ snafu.workspace = true [features] python = [ "common-error/python", - "common-io-config/python", "daft-core/python", "daft-dsl/python", "daft-image/python", diff --git a/src/daft-functions/src/binary/concat.rs b/src/daft-functions/src/binary/concat.rs new file mode 100644 index 0000000000..2a92334d40 --- /dev/null +++ b/src/daft-functions/src/binary/concat.rs @@ -0,0 +1,122 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + array::ops::as_arrow::AsArrow, + datatypes::{BinaryArray, DataType, Field, FixedSizeBinaryArray}, + prelude::Schema, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct BinaryConcat {} + +#[typetag::serde] +impl ScalarUDF for BinaryConcat { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "concat" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [left, right] => { + let left_field = left.to_field(schema)?; + let right_field = right.to_field(schema)?; + match (&left_field.dtype, &right_field.dtype) { + (DataType::Binary, DataType::Binary) => { + Ok(Field::new(left_field.name, DataType::Binary)) + } + (DataType::Binary, DataType::Null) | (DataType::Null, DataType::Binary) => { + Ok(Field::new(left_field.name, DataType::Binary)) + } + (DataType::FixedSizeBinary(size1), DataType::FixedSizeBinary(size2)) => Ok( + Field::new(left_field.name, DataType::FixedSizeBinary(size1 + size2)), + ), + (DataType::FixedSizeBinary(_), DataType::Binary) + | (DataType::Binary, DataType::FixedSizeBinary(_)) => { + Ok(Field::new(left_field.name, DataType::Binary)) + } + (DataType::FixedSizeBinary(_), DataType::Null) + | (DataType::Null, DataType::FixedSizeBinary(_)) => { + Ok(Field::new(left_field.name, DataType::Binary)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to concat to be binary, but received {} and {}", + format_field_type_for_error(&left_field), + format_field_type_for_error(&right_field), + ))), + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + let result_name = inputs[0].name(); + match (inputs[0].data_type(), inputs[1].data_type()) { + (DataType::Binary, 
DataType::Binary) => { + let left_array = inputs[0].downcast::()?; + let right_array = inputs[1].downcast::()?; + let result = left_array.binary_concat(right_array)?; + Ok( + BinaryArray::from((result_name, Box::new(result.as_arrow().clone()))) + .into_series(), + ) + } + (DataType::FixedSizeBinary(_), DataType::FixedSizeBinary(_)) => { + let left_array = inputs[0].downcast::()?; + let right_array = inputs[1].downcast::()?; + let result = left_array.binary_concat(right_array)?; + Ok( + FixedSizeBinaryArray::from((result_name, Box::new(result.as_arrow().clone()))) + .into_series(), + ) + } + (DataType::FixedSizeBinary(_), DataType::Binary) + | (DataType::Binary, DataType::FixedSizeBinary(_)) => { + let left_array = match inputs[0].data_type() { + DataType::FixedSizeBinary(_) => inputs[0] + .downcast::()? + .cast(&DataType::Binary)?, + _ => inputs[0].downcast::()?.clone().into_series(), + }; + let right_array = match inputs[1].data_type() { + DataType::FixedSizeBinary(_) => inputs[1] + .downcast::()? + .cast(&DataType::Binary)?, + _ => inputs[1].downcast::()?.clone().into_series(), + }; + let result = left_array.binary()?.binary_concat(right_array.binary()?)?; + Ok( + BinaryArray::from((result_name, Box::new(result.as_arrow().clone()))) + .into_series(), + ) + } + (_, DataType::Null) | (DataType::Null, _) => { + let len = inputs[0].len().max(inputs[1].len()); + Ok(Series::full_null(result_name, &DataType::Binary, len)) + } + _ => unreachable!("Type checking done in to_field"), + } + } +} + +pub fn binary_concat(left: ExprRef, right: ExprRef) -> ExprRef { + ScalarFunction::new(BinaryConcat {}, vec![left, right]).into() +} + +fn format_field_type_for_error(field: &Field) -> String { + match field.dtype { + DataType::FixedSizeBinary(_) => format!("{}#Binary", field.name), + _ => format!("{}#{}", field.name, field.dtype), + } +} diff --git a/src/daft-functions/src/binary/length.rs b/src/daft-functions/src/binary/length.rs new file mode 100644 index 0000000000..088a0e782a --- /dev/null +++ b/src/daft-functions/src/binary/length.rs @@ -0,0 +1,60 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::{BinaryArray, DataType, Field, FixedSizeBinaryArray}, + prelude::Schema, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct BinaryLength {} + +#[typetag::serde] +impl ScalarUDF for BinaryLength { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "length" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + let data = &inputs[0]; + match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Binary | DataType::FixedSizeBinary(_) => { + Ok(Field::new(data_field.name, DataType::UInt64)) + } + _ => Err(DaftError::TypeError(format!( + "Expects input to length to be binary, but received {data_field}", + ))), + }, + Err(e) => Err(e), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs[0].data_type() { + DataType::Binary => { + let binary_array = inputs[0].downcast::()?; + let result = binary_array.length()?; + Ok(result.into_series()) + } + DataType::FixedSizeBinary(_size) => { + let binary_array = inputs[0].downcast::()?; + let result = binary_array.length()?; + Ok(result.into_series()) + } + _ => unreachable!("Type checking is done in to_field"), + } + } +} + +#[must_use] +pub fn 
binary_length(input: ExprRef) -> ExprRef { + ScalarFunction::new(BinaryLength {}, vec![input]).into() +} diff --git a/src/daft-functions/src/binary/mod.rs b/src/daft-functions/src/binary/mod.rs new file mode 100644 index 0000000000..da7cc0c59d --- /dev/null +++ b/src/daft-functions/src/binary/mod.rs @@ -0,0 +1,7 @@ +pub mod concat; +pub mod length; +pub mod slice; + +pub use concat::{binary_concat, BinaryConcat}; +pub use length::{binary_length, BinaryLength}; +pub use slice::{binary_slice, BinarySlice}; diff --git a/src/daft-functions/src/binary/slice.rs b/src/daft-functions/src/binary/slice.rs new file mode 100644 index 0000000000..69032c9531 --- /dev/null +++ b/src/daft-functions/src/binary/slice.rs @@ -0,0 +1,96 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::DataType, + prelude::{Field, IntoSeries, Schema}, + series::Series, + with_match_integer_daft_types, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct BinarySlice {} + +#[typetag::serde] +impl ScalarUDF for BinarySlice { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "binary_slice" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, start, length] => { + let data = data.to_field(schema)?; + let start = start.to_field(schema)?; + let length = length.to_field(schema)?; + + match &data.dtype { + DataType::Binary | DataType::FixedSizeBinary(_) => { + if start.dtype.is_integer() && (length.dtype.is_integer() || length.dtype.is_null()) { + Ok(Field::new(data.name, DataType::Binary)) + } else { + Err(DaftError::TypeError(format!( + "Expects inputs to binary_slice to be binary, integer and integer or null but received {}, {} and {}", + data.dtype, start.dtype, length.dtype + ))) + } + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to binary_slice to be binary, integer and integer or null but received {}, {} and {}", + data.dtype, start.dtype, length.dtype + ))), + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 3 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + let data = &inputs[0]; + let start = &inputs[1]; + let length = &inputs[2]; + + match data.data_type() { + DataType::Binary | DataType::FixedSizeBinary(_) => { + let binary_data = match data.data_type() { + DataType::Binary => data.clone(), + _ => data.cast(&DataType::Binary)?, + }; + binary_data.with_binary_array(|arr| { + with_match_integer_daft_types!(start.data_type(), |$T| { + if length.data_type().is_integer() { + with_match_integer_daft_types!(length.data_type(), |$U| { + Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, Some(length.downcast::<<$U as DaftDataType>::ArrayType>()?))?.into_series()) + }) + } else if length.data_type().is_null() { + Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, None::<&DataArray>)?.into_series()) + } else { + Err(DaftError::TypeError(format!( + "slice not implemented for length type {}", + length.data_type() + ))) + } + }) + }) + } + DataType::Null => Ok(data.clone()), + dt => Err(DaftError::TypeError(format!( + "Operation not implemented for type {dt}" + ))), + } + } +} + +#[must_use] +pub fn binary_slice(input: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef { + ScalarFunction::new(BinarySlice {}, vec![input, start, 
length]).into() +} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 20c17e358c..00c787b2af 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -1,4 +1,5 @@ #![feature(async_closure)] +pub mod binary; pub mod coalesce; pub mod count_matches; pub mod distance; diff --git a/src/daft-functions/src/list/list_fill.rs b/src/daft-functions/src/list/list_fill.rs index 39f633ef7c..a9fdadbaf9 100644 --- a/src/daft-functions/src/list/list_fill.rs +++ b/src/daft-functions/src/list/list_fill.rs @@ -125,8 +125,6 @@ mod tests { vec![Some(1), Some(0), Some(10)].into_iter(), ) .into_series(); - let str = Utf8Array::from_iter("s2", vec![None, Some("hello"), Some("world")].into_iter()) - .into_series(); let error = fill.evaluate(&[num.clone()]).unwrap_err(); assert_eq!( diff --git a/src/daft-functions/src/python/binary.rs b/src/daft-functions/src/python/binary.rs new file mode 100644 index 0000000000..5bf2225197 --- /dev/null +++ b/src/daft-functions/src/python/binary.rs @@ -0,0 +1,11 @@ +use daft_dsl::python::PyExpr; +use pyo3::{pyfunction, PyResult}; + +use crate::binary::{ + concat::binary_concat as concat_fn, length::binary_length as length_fn, + slice::binary_slice as slice_fn, +}; + +simple_python_wrapper!(binary_length, length_fn, [input: PyExpr]); +simple_python_wrapper!(binary_concat, concat_fn, [left: PyExpr, right: PyExpr]); +simple_python_wrapper!(binary_slice, slice_fn, [input: PyExpr, start: PyExpr, length: PyExpr]); diff --git a/src/daft-functions/src/python/mod.rs b/src/daft-functions/src/python/mod.rs index 4032ce9d60..b37ef91f9d 100644 --- a/src/daft-functions/src/python/mod.rs +++ b/src/daft-functions/src/python/mod.rs @@ -12,6 +12,7 @@ macro_rules! simple_python_wrapper { }; } +mod binary; mod coalesce; mod distance; mod float; @@ -38,6 +39,9 @@ pub fn register(parent: &Bound) -> PyResult<()> { add!(coalesce::coalesce); add!(distance::cosine_distance); + add!(binary::binary_length); + add!(binary::binary_concat); + add!(binary::binary_slice); add!(float::is_inf); add!(float::is_nan); diff --git a/src/daft-functions/src/python/uri.rs b/src/daft-functions/src/python/uri.rs index d7548c0125..50eb0c33b9 100644 --- a/src/daft-functions/src/python/uri.rs +++ b/src/daft-functions/src/python/uri.rs @@ -2,6 +2,8 @@ use daft_dsl::python::PyExpr; use daft_io::python::IOConfig; use pyo3::{exceptions::PyValueError, pyfunction, PyResult}; +use crate::uri::{self, download::UrlDownloadArgs, upload::UrlUploadArgs}; + #[pyfunction] pub fn url_download( expr: PyExpr, @@ -15,15 +17,13 @@ pub fn url_download( "max_connections must be positive and non_zero: {max_connections}" ))); } - - Ok(crate::uri::download( - expr.into(), + let args = UrlDownloadArgs::new( max_connections as usize, raise_error_on_failure, multi_thread, Some(config.config), - ) - .into()) + ); + Ok(uri::download(expr.into(), Some(args)).into()) } #[pyfunction(signature = ( @@ -49,14 +49,12 @@ pub fn url_upload( "max_connections must be positive and non_zero: {max_connections}" ))); } - Ok(crate::uri::upload( - expr.into(), - folder_location.into(), + let args = UrlUploadArgs::new( max_connections as usize, raise_error_on_failure, multi_thread, is_single_folder, io_config.map(|io_config| io_config.config), - ) - .into()) + ); + Ok(uri::upload(expr.into(), folder_location.into(), Some(args)).into()) } diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 24d3f89d33..79bfd89c10 100644 --- a/src/daft-functions/src/uri/download.rs +++ 
b/src/daft-functions/src/uri/download.rs @@ -11,16 +11,52 @@ use snafu::prelude::*; use crate::InvalidArgumentSnafu; +/// Container for the keyword arguments of `url_download` +/// ex: +/// ```text +/// url_download(input) +/// url_download(input, max_connections=32) +/// url_download(input, on_error='raise') +/// url_download(input, on_error='null') +/// url_download(input, max_connections=32, on_error='raise') +/// ``` #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct DownloadFunction { - pub(super) max_connections: usize, - pub(super) raise_error_on_failure: bool, - pub(super) multi_thread: bool, - pub(super) config: Arc<IOConfig>, +pub struct UrlDownloadArgs { + pub max_connections: usize, + pub raise_error_on_failure: bool, + pub multi_thread: bool, + pub io_config: Arc<IOConfig>, +} + +impl UrlDownloadArgs { + pub fn new( + max_connections: usize, + raise_error_on_failure: bool, + multi_thread: bool, + io_config: Option<IOConfig>, + ) -> Self { + Self { + max_connections, + raise_error_on_failure, + multi_thread, + io_config: io_config.unwrap_or_default().into(), + } + } +} + +impl Default for UrlDownloadArgs { + fn default() -> Self { + Self { + max_connections: 32, + raise_error_on_failure: true, + multi_thread: true, + io_config: IOConfig::default().into(), + } + } } #[typetag::serde] -impl ScalarUDF for DownloadFunction { +impl ScalarUDF for UrlDownloadArgs { fn as_any(&self) -> &dyn std::any::Any { self } @@ -34,7 +70,7 @@ impl ScalarUDF for DownloadFunction { max_connections, raise_error_on_failure, multi_thread, - config, + io_config, } = self; match inputs { @@ -47,7 +83,7 @@ impl ScalarUDF for DownloadFunction { *max_connections, *raise_error_on_failure, *multi_thread, - config.clone(), + io_config.clone(), Some(io_stats), )?; Ok(result.into_series()) diff --git a/src/daft-functions/src/uri/mod.rs index 67418fa1df..541af8ef8d 100644 --- a/src/daft-functions/src/uri/mod.rs +++ b/src/daft-functions/src/uri/mod.rs @@ -1,50 +1,18 @@ -mod download; -mod upload; +pub mod download; +pub mod upload; -use common_io_config::IOConfig; use daft_dsl::{functions::ScalarFunction, ExprRef}; -use download::DownloadFunction; -use upload::UploadFunction; +use download::UrlDownloadArgs; +use upload::UrlUploadArgs; +/// Creates a `url_download` ExprRef from the positional and optional named arguments. #[must_use] -pub fn download( - input: ExprRef, - max_connections: usize, - raise_error_on_failure: bool, - multi_thread: bool, - config: Option<IOConfig>, -) -> ExprRef { - ScalarFunction::new( - DownloadFunction { - max_connections, - raise_error_on_failure, - multi_thread, - config: config.unwrap_or_default().into(), - }, - vec![input], - ) - .into() +pub fn download(input: ExprRef, args: Option<UrlDownloadArgs>) -> ExprRef { + ScalarFunction::new(args.unwrap_or_default(), vec![input]).into() } +/// Creates a `url_upload` ExprRef from the positional and optional named arguments.
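With the argument-struct refactor above, callers either pass `None` to get the defaults (32 connections, raise on failure, multi-threaded, default `IOConfig`) or build an explicit `UrlDownloadArgs`. A hedged usage sketch (assumes the `daft-functions` and `daft-dsl` crates as dependencies; `col("urls")` is a placeholder column):

```rust
use daft_dsl::col;
use daft_functions::uri::{self, download::UrlDownloadArgs};

fn main() {
    // Default behaviour, equivalent to UrlDownloadArgs::default().
    let with_defaults = uri::download(col("urls"), None);

    // Explicit arguments mirror the removed positional parameters.
    let args = UrlDownloadArgs::new(
        16,    // max_connections
        false, // raise_error_on_failure: do not raise on failed downloads
        true,  // multi_thread
        None,  // io_config: falls back to IOConfig::default()
    );
    let tuned = uri::download(col("urls"), Some(args));
    let _ = (with_defaults, tuned);
}
```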
#[must_use] -pub fn upload( - input: ExprRef, - location: ExprRef, - max_connections: usize, - raise_error_on_failure: bool, - multi_thread: bool, - is_single_folder: bool, - config: Option, -) -> ExprRef { - ScalarFunction::new( - UploadFunction { - max_connections, - raise_error_on_failure, - multi_thread, - is_single_folder, - config: config.unwrap_or_default().into(), - }, - vec![input, location], - ) - .into() +pub fn upload(input: ExprRef, location: ExprRef, args: Option) -> ExprRef { + ScalarFunction::new(args.unwrap_or_default(), vec![input, location]).into() } diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index 5b01858b94..530f23f984 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -9,16 +9,46 @@ use futures::{StreamExt, TryStreamExt}; use serde::Serialize; #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct UploadFunction { - pub(super) max_connections: usize, - pub(super) raise_error_on_failure: bool, - pub(super) multi_thread: bool, - pub(super) is_single_folder: bool, - pub(super) config: Arc, +pub struct UrlUploadArgs { + pub max_connections: usize, + pub raise_error_on_failure: bool, + pub multi_thread: bool, + pub is_single_folder: bool, + pub io_config: Arc, +} + +impl UrlUploadArgs { + pub fn new( + max_connections: usize, + raise_error_on_failure: bool, + multi_thread: bool, + is_single_folder: bool, + io_config: Option, + ) -> Self { + Self { + max_connections, + raise_error_on_failure, + multi_thread, + is_single_folder, + io_config: io_config.unwrap_or_default().into(), + } + } +} + +impl Default for UrlUploadArgs { + fn default() -> Self { + Self { + max_connections: 32, + raise_error_on_failure: true, + multi_thread: true, + is_single_folder: false, + io_config: IOConfig::default().into(), + } + } } #[typetag::serde] -impl ScalarUDF for UploadFunction { +impl ScalarUDF for UrlUploadArgs { fn as_any(&self) -> &dyn std::any::Any { self } @@ -29,11 +59,11 @@ impl ScalarUDF for UploadFunction { fn evaluate(&self, inputs: &[Series]) -> DaftResult { let Self { - config, max_connections, raise_error_on_failure, multi_thread, is_single_folder, + io_config, } = self; match inputs { @@ -44,7 +74,7 @@ impl ScalarUDF for UploadFunction { *raise_error_on_failure, *multi_thread, *is_single_folder, - config.clone(), + io_config.clone(), None, ), _ => Err(DaftError::ValueError(format!( diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index e1752eac9a..ef6cdbe93b 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -26,7 +26,7 @@ use common_runtime::{RuntimeRef, RuntimeTask}; use lazy_static::lazy_static; use progress_bar::{OperatorProgressBar, ProgressBarColor, ProgressBarManager}; use resource_manager::MemoryManager; -pub use run::{run_local, ExecutionEngineResult, NativeExecutor}; +pub use run::{ExecutionEngineResult, NativeExecutor}; use runtime_stats::{RuntimeStatsContext, TimedFuture}; use snafu::{futures::TryFutureExt, ResultExt, Snafu}; use tracing::Instrument; @@ -124,7 +124,7 @@ pub(crate) struct ExecutionRuntimeContext { worker_set: TaskSet>, default_morsel_size: usize, memory_manager: Arc, - progress_bar_manager: Option>, + progress_bar_manager: Option>, } impl ExecutionRuntimeContext { @@ -132,7 +132,7 @@ impl ExecutionRuntimeContext { pub fn new( default_morsel_size: usize, memory_manager: Arc, - progress_bar_manager: Option>, + progress_bar_manager: Option>, ) 
-> Self { Self { worker_set: TaskSet::new(), diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 64fb079150..efc1f1c7db 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -287,17 +287,15 @@ pub fn physical_plan_to_pipeline( StatsState::Materialized(left_stats), StatsState::Materialized(right_stats), ) => { - let left_size = left_stats.approx_stats.upper_bound_bytes; - let right_size = right_stats.approx_stats.upper_bound_bytes; - left_size.zip(right_size).map_or(true, |(l, r)| l <= r) + let left_size = left_stats.approx_stats.size_bytes; + let right_size = right_stats.approx_stats.size_bytes; + left_size <= right_size } // If stats are only available on the right side of the join, and the upper bound bytes on the // right are under the broadcast join size threshold, we build on the right instead of the left. (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => { - right_stats - .approx_stats - .upper_bound_bytes - .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold) + right_stats.approx_stats.size_bytes + > cfg.broadcast_join_size_bytes_threshold } _ => true, }, @@ -308,21 +306,15 @@ pub fn physical_plan_to_pipeline( StatsState::Materialized(left_stats), StatsState::Materialized(right_stats), ) => { - let left_size = left_stats.approx_stats.upper_bound_bytes; - let right_size = right_stats.approx_stats.upper_bound_bytes; - left_size - .zip(right_size) - .map_or(false, |(l, r)| (r as f64) >= ((l as f64) * 1.5)) + let left_size = left_stats.approx_stats.size_bytes; + let right_size = right_stats.approx_stats.size_bytes; + right_size as f64 >= left_size as f64 * 1.5 } // If stats are only available on the left side of the join, and the upper bound bytes on the left // are under the broadcast join size threshold, we build on the left instead of the right. (StatsState::Materialized(left_stats), StatsState::NotMaterialized) => { - left_stats - .approx_stats - .upper_bound_bytes - .map_or(false, |size| { - size <= cfg.broadcast_join_size_bytes_threshold - }) + left_stats.approx_stats.size_bytes + <= cfg.broadcast_join_size_bytes_threshold } _ => false, }, @@ -333,19 +325,15 @@ pub fn physical_plan_to_pipeline( StatsState::Materialized(left_stats), StatsState::Materialized(right_stats), ) => { - let left_size = left_stats.approx_stats.upper_bound_bytes; - let right_size = right_stats.approx_stats.upper_bound_bytes; - left_size - .zip(right_size) - .map_or(true, |(l, r)| ((r as f64) * 1.5) >= (l as f64)) + let left_size = left_stats.approx_stats.size_bytes; + let right_size = right_stats.approx_stats.size_bytes; + (right_size as f64 * 1.5) >= left_size as f64 } // If stats are only available on the right side of the join, and the upper bound bytes on the // right are under the broadcast join size threshold, we build on the right instead of the left. 
(StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => { - right_stats - .approx_stats - .upper_bound_bytes - .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold) + right_stats.approx_stats.size_bytes + > cfg.broadcast_join_size_bytes_threshold } _ => true, }, @@ -356,21 +344,15 @@ pub fn physical_plan_to_pipeline( StatsState::Materialized(left_stats), StatsState::Materialized(right_stats), ) => { - let left_size = left_stats.approx_stats.upper_bound_bytes; - let right_size = right_stats.approx_stats.upper_bound_bytes; - left_size - .zip(right_size) - .map_or(false, |(l, r)| (r as f64) > ((l as f64) * 1.5)) + let left_size = left_stats.approx_stats.size_bytes; + let right_size = right_stats.approx_stats.size_bytes; + right_size as f64 > left_size as f64 * 1.5 } // If stats are only available on the left side of the join, and the upper bound bytes on the left // are under the broadcast join size threshold, we build on the left instead of the right. (StatsState::Materialized(left_stats), StatsState::NotMaterialized) => { - left_stats - .approx_stats - .upper_bound_bytes - .map_or(false, |size| { - size <= cfg.broadcast_join_size_bytes_threshold - }) + left_stats.approx_stats.size_bytes + <= cfg.broadcast_join_size_bytes_threshold } // Else, default to building on the right _ => false, @@ -498,15 +480,13 @@ pub fn physical_plan_to_pipeline( // the larger side to stream so that it can be parallelized via an intermediate op. Default to left side. let stream_on_left = match (left_stats_state, right_stats_state) { (StatsState::Materialized(left_stats), StatsState::Materialized(right_stats)) => { - left_stats.approx_stats.upper_bound_bytes - > right_stats.approx_stats.upper_bound_bytes + left_stats.approx_stats.num_rows > right_stats.approx_stats.num_rows } // If stats are only available on the left side of the join, and the upper bound bytes on the // left are under the broadcast join size threshold, we stream on the right. - (StatsState::Materialized(left_stats), StatsState::NotMaterialized) => left_stats - .approx_stats - .upper_bound_bytes - .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold), + (StatsState::Materialized(left_stats), StatsState::NotMaterialized) => { + left_stats.approx_stats.size_bytes > cfg.broadcast_join_size_bytes_threshold + } // If stats are not available, we fall back and stream on the left by default. 
_ => true, }; diff --git a/src/daft-local-execution/src/progress_bar.rs b/src/daft-local-execution/src/progress_bar.rs index 3b42333d49..d865826da5 100644 --- a/src/daft-local-execution/src/progress_bar.rs +++ b/src/daft-local-execution/src/progress_bar.rs @@ -16,7 +16,7 @@ pub trait ProgressBar: Send + Sync { fn close(&self) -> DaftResult<()>; } -pub trait ProgressBarManager { +pub trait ProgressBarManager: std::fmt::Debug + Send + Sync { fn make_new_bar( &self, color: ProgressBarColor, @@ -128,6 +128,7 @@ impl ProgressBar for IndicatifProgressBar { } } +#[derive(Debug)] struct IndicatifProgressBarManager { multi_progress: indicatif::MultiProgress, } @@ -168,19 +169,19 @@ impl ProgressBarManager for IndicatifProgressBarManager { } } -pub fn make_progress_bar_manager() -> Box { +pub fn make_progress_bar_manager() -> Arc { #[cfg(feature = "python")] { if python::in_notebook() { - Box::new(python::TqdmProgressBarManager::new()) + Arc::new(python::TqdmProgressBarManager::new()) } else { - Box::new(IndicatifProgressBarManager::new()) + Arc::new(IndicatifProgressBarManager::new()) } } #[cfg(not(feature = "python"))] { - Box::new(IndicatifProgressBarManager::new()) + Arc::new(IndicatifProgressBarManager::new()) } } @@ -215,7 +216,7 @@ mod python { } } - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct TqdmProgressBarManager { inner: Arc, } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 050e6f34f1..2d7fbaec57 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -9,7 +9,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_tracing::refresh_chrome_trace; -use daft_local_plan::{translate, LocalPhysicalPlan}; +use daft_local_plan::translate; use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::{ partitioning::{InMemoryPartitionSetCache, MicroPartitionSet, PartitionSetCache}, @@ -30,7 +30,7 @@ use { use crate::{ channel::{create_channel, Receiver}, pipeline::{physical_plan_to_pipeline, viz_pipeline}, - progress_bar::make_progress_bar_manager, + progress_bar::{make_progress_bar_manager, ProgressBarManager}, resource_manager::get_or_init_memory_manager, Error, ExecutionRuntimeContext, }; @@ -61,25 +61,28 @@ pub struct PyNativeExecutor { executor: NativeExecutor, } +#[cfg(feature = "python")] +impl Default for PyNativeExecutor { + fn default() -> Self { + Self::new() + } +} + #[cfg(feature = "python")] #[pymethods] impl PyNativeExecutor { - #[staticmethod] - pub fn from_logical_plan_builder( - logical_plan_builder: &PyLogicalPlanBuilder, - py: Python, - ) -> PyResult { - py.allow_threads(|| { - Ok(Self { - executor: NativeExecutor::from_logical_plan_builder(&logical_plan_builder.builder)?, - }) - }) + #[new] + pub fn new() -> Self { + Self { + executor: NativeExecutor::new(), + } } - #[pyo3(signature = (psets, cfg, results_buffer_size=None))] + #[pyo3(signature = (logical_plan_builder, psets, cfg, results_buffer_size=None))] pub fn run<'a>( &self, py: Python<'a>, + logical_plan_builder: &PyLogicalPlanBuilder, psets: HashMap>, cfg: PyDaftExecutionConfig, results_buffer_size: Option, @@ -102,7 +105,12 @@ impl PyNativeExecutor { let psets = InMemoryPartitionSetCache::new(&native_psets); let out = py.allow_threads(|| { self.executor - .run(&psets, cfg.config, results_buffer_size) + .run( + &logical_plan_builder.builder, + &psets, + cfg.config, + results_buffer_size, + ) .map(|res| res.into_iter()) })?; let iter = Box::new(out.map(|part| { @@ -118,37 
+126,134 @@ impl PyNativeExecutor { } } +#[derive(Debug, Clone)] pub struct NativeExecutor { - local_physical_plan: Arc, cancel: CancellationToken, + runtime: Option>, + pb_manager: Option>, + enable_explain_analyze: bool, +} + +impl Default for NativeExecutor { + fn default() -> Self { + Self { + cancel: CancellationToken::new(), + runtime: None, + pb_manager: should_enable_progress_bar().then(make_progress_bar_manager), + enable_explain_analyze: should_enable_explain_analyze(), + } + } } impl NativeExecutor { - pub fn from_logical_plan_builder( - logical_plan_builder: &LogicalPlanBuilder, - ) -> DaftResult { - let logical_plan = logical_plan_builder.build(); - let local_physical_plan = translate(&logical_plan)?; + pub fn new() -> Self { + Self::default() + } - Ok(Self { - local_physical_plan, - cancel: CancellationToken::new(), - }) + pub fn with_runtime(mut self, runtime: Arc) -> Self { + self.runtime = Some(runtime); + self + } + + pub fn with_progress_bar_manager(mut self, pb_manager: Arc) -> Self { + self.pb_manager = Some(pb_manager); + self + } + + pub fn enable_explain_analyze(mut self, b: bool) -> Self { + self.enable_explain_analyze = b; + self } pub fn run( &self, + logical_plan_builder: &LogicalPlanBuilder, psets: &(impl PartitionSetCache> + ?Sized), cfg: Arc, results_buffer_size: Option, ) -> DaftResult { - run_local( - &self.local_physical_plan, - psets, - cfg, - results_buffer_size, - self.cancel.clone(), - ) + let logical_plan = logical_plan_builder.build(); + let physical_plan = translate(&logical_plan)?; + refresh_chrome_trace(); + let cancel = self.cancel.clone(); + let pipeline = physical_plan_to_pipeline(&physical_plan, psets, &cfg)?; + let (tx, rx) = create_channel(results_buffer_size.unwrap_or(0)); + + let rt = self.runtime.clone(); + let pb_manager = self.pb_manager.clone(); + let enable_explain_analyze = self.enable_explain_analyze; + // todo: split this into a run and run_async method + // the run_async should spawn a task instead of a thread like this + let handle = std::thread::spawn(move || { + let runtime = rt.unwrap_or_else(|| { + Arc::new( + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create tokio runtime"), + ) + }); + let execution_task = async { + let memory_manager = get_or_init_memory_manager(); + let mut runtime_handle = ExecutionRuntimeContext::new( + cfg.default_morsel_size, + memory_manager.clone(), + pb_manager, + ); + let receiver = pipeline.start(true, &mut runtime_handle)?; + + while let Some(val) = receiver.recv().await { + if tx.send(val).await.is_err() { + break; + } + } + + while let Some(result) = runtime_handle.join_next().await { + match result { + Ok(Err(e)) => { + runtime_handle.shutdown().await; + return DaftResult::Err(e.into()); + } + Err(e) => { + runtime_handle.shutdown().await; + return DaftResult::Err(Error::JoinError { source: e }.into()); + } + _ => {} + } + } + if enable_explain_analyze { + let curr_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); + let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); + let mut file = File::create(file_name)?; + writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; + } + Ok(()) + }; + + let local_set = tokio::task::LocalSet::new(); + local_set.block_on(&runtime, async { + tokio::select! 
{ + biased; + () = cancel.cancelled() => { + log::info!("Execution engine cancelled"); + Ok(()) + } + _ = tokio::signal::ctrl_c() => { + log::info!("Received Ctrl-C, shutting down execution engine"); + Ok(()) + } + result = execution_task => result, + } + }) + }); + + Ok(ExecutionEngineResult { + handle, + receiver: rx, + }) } } @@ -261,82 +366,3 @@ impl IntoIterator for ExecutionEngineResult { } } } - -pub fn run_local( - physical_plan: &LocalPhysicalPlan, - psets: &(impl PartitionSetCache> + ?Sized), - cfg: Arc, - results_buffer_size: Option, - cancel: CancellationToken, -) -> DaftResult { - refresh_chrome_trace(); - let pipeline = physical_plan_to_pipeline(physical_plan, psets, &cfg)?; - let (tx, rx) = create_channel(results_buffer_size.unwrap_or(0)); - let handle = std::thread::spawn(move || { - let pb_manager = should_enable_progress_bar().then(make_progress_bar_manager); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to create tokio runtime"); - let execution_task = async { - let memory_manager = get_or_init_memory_manager(); - let mut runtime_handle = ExecutionRuntimeContext::new( - cfg.default_morsel_size, - memory_manager.clone(), - pb_manager, - ); - let receiver = pipeline.start(true, &mut runtime_handle)?; - - while let Some(val) = receiver.recv().await { - if tx.send(val).await.is_err() { - break; - } - } - - while let Some(result) = runtime_handle.join_next().await { - match result { - Ok(Err(e)) => { - runtime_handle.shutdown().await; - return DaftResult::Err(e.into()); - } - Err(e) => { - runtime_handle.shutdown().await; - return DaftResult::Err(Error::JoinError { source: e }.into()); - } - _ => {} - } - } - if should_enable_explain_analyze() { - let curr_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(); - let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); - let mut file = File::create(file_name)?; - writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; - } - Ok(()) - }; - - let local_set = tokio::task::LocalSet::new(); - local_set.block_on(&runtime, async { - tokio::select! 
{ - biased; - () = cancel.cancelled() => { - log::info!("Execution engine cancelled"); - Ok(()) - } - _ = tokio::signal::ctrl_c() => { - log::info!("Received Ctrl-C, shutting down execution engine"); - Ok(()) - } - result = execution_task => result, - } - }) - }); - - Ok(ExecutionEngineResult { - handle, - receiver: rx, - }) -} diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index cf70c38998..0ff2d2ac1c 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -21,6 +21,8 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +tokio = {workspace = true} +typed-builder = {workspace = true} uuid = {version = "1", features = ["v4"]} [dev-dependencies] @@ -39,6 +41,7 @@ python = [ "common-io-config/python", "common-daft-config/python", "common-resource-request/python", + "common-partitioning/python", "common-scan-info/python", "daft-core/python", "daft-dsl/python", diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder/mod.rs similarity index 83% rename from src/daft-logical-plan/src/builder.rs rename to src/daft-logical-plan/src/builder/mod.rs index 937fb45f44..9006505fca 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder/mod.rs @@ -1,5 +1,10 @@ +mod resolve_expr; +#[cfg(test)] +mod tests; + use std::{ collections::{HashMap, HashSet}, + future::Future, sync::Arc, }; @@ -12,6 +17,10 @@ use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; use daft_core::join::{JoinStrategy, JoinType}; use daft_dsl::{col, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; +use indexmap::IndexSet; +#[cfg(feature = "python")] +pub use resolve_expr::py_check_column_name_validity; +use resolve_expr::ExprResolver; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, @@ -188,11 +197,19 @@ impl LogicalPlanBuilder { } pub fn select(&self, to_select: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (to_select, _) = expr_resolver.resolve(to_select, &self.schema())?; + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into(); Ok(self.with_new_plan(logical_plan)) } pub fn with_columns(&self, columns: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (columns, _) = expr_resolver.resolve(columns, &self.schema())?; + let fields = &self.schema().fields; let current_col_names = fields .iter() @@ -245,6 +262,10 @@ impl LogicalPlanBuilder { } pub fn filter(&self, predicate: ExprRef) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (predicate, _) = expr_resolver.resolve_single(predicate, &self.schema())?; + let logical_plan: LogicalPlan = ops::Filter::try_new(self.plan.clone(), predicate)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -255,6 +276,10 @@ impl LogicalPlanBuilder { } pub fn explode(&self, to_explode: Vec) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (to_explode, _) = expr_resolver.resolve(to_explode, &self.schema())?; + let logical_plan: LogicalPlan = ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -264,25 +289,24 @@ impl LogicalPlanBuilder { &self, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> DaftResult { + let 
expr_resolver = ExprResolver::default(); + let (values, _) = expr_resolver.resolve(values, &self.schema())?; + let (ids, _) = expr_resolver.resolve(ids, &self.schema())?; + let values = if values.is_empty() { - let ids_set = HashSet::<_>::from_iter(ids.iter()); + let ids_set = IndexSet::<_>::from_iter(ids.iter().cloned()); - self.schema() + let columns_set = self + .schema() .fields - .iter() - .filter_map(|(name, _)| { - let column = col(name.clone()); + .keys() + .map(|name| col(name.clone())) + .collect::>(); - if ids_set.contains(&column) { - None - } else { - Some(column) - } - }) - .collect() + columns_set.difference(&ids_set).cloned().collect() } else { values }; @@ -299,6 +323,10 @@ impl LogicalPlanBuilder { descending: Vec, nulls_first: Vec, ) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (sort_by, _) = expr_resolver.resolve(sort_by, &self.schema())?; + let logical_plan: LogicalPlan = ops::Sort::try_new(self.plan.clone(), sort_by, descending, nulls_first)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -309,28 +337,32 @@ impl LogicalPlanBuilder { num_partitions: Option, partition_by: Vec, ) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let expr_resolver = ExprResolver::default(); + + let (partition_by, _) = expr_resolver.resolve(partition_by, &self.schema())?; + + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)), - )? + ) .into(); Ok(self.with_new_plan(logical_plan)) } pub fn random_shuffle(&self, num_partitions: Option) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::Random(RandomShuffleConfig::new(num_partitions)), - )? + ) .into(); Ok(self.with_new_plan(logical_plan)) } pub fn into_partitions(&self, num_partitions: usize) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::IntoPartitions(IntoPartitionsConfig::new(num_partitions)), - )? 
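When unpivot is called with an empty values list, the builder now derives the value columns as an order-preserving set difference (all schema columns minus the ids) using IndexSet instead of HashSet. A small sketch of that difference, assuming the indexmap crate:

use indexmap::IndexSet;

fn main() {
    let ids: IndexSet<&str> = IndexSet::from_iter(["year"]);
    let columns: IndexSet<&str> = IndexSet::from_iter(["year", "jan", "feb", "mar"]);

    // Unlike HashSet, the difference keeps the schema's column order.
    let values: Vec<&str> = columns.difference(&ids).copied().collect();
    assert_eq!(values, vec!["jan", "feb", "mar"]);
}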
+ ) .into(); Ok(self.with_new_plan(logical_plan)) } @@ -356,6 +388,12 @@ impl LogicalPlanBuilder { agg_exprs: Vec, groupby_exprs: Vec, ) -> DaftResult { + let groupby_resolver = ExprResolver::default(); + let (groupby_exprs, _) = groupby_resolver.resolve(groupby_exprs, &self.schema())?; + + let agg_resolver = ExprResolver::builder().groupby(&groupby_exprs).build(); + let (agg_exprs, _) = agg_resolver.resolve(agg_exprs, &self.schema())?; + let logical_plan: LogicalPlan = ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -369,6 +407,14 @@ impl LogicalPlanBuilder { agg_expr: ExprRef, names: Vec, ) -> DaftResult { + let agg_resolver = ExprResolver::builder().groupby(&group_by).build(); + let (agg_expr, _) = agg_resolver.resolve_single(agg_expr, &self.schema())?; + + let expr_resolver = ExprResolver::default(); + let (group_by, _) = expr_resolver.resolve(group_by, &self.schema())?; + let (pivot_column, _) = expr_resolver.resolve_single(pivot_column, &self.schema())?; + let (value_column, _) = expr_resolver.resolve_single(value_column, &self.schema())?; + let pivot_logical_plan: LogicalPlan = ops::Pivot::try_new( self.plan.clone(), group_by, @@ -438,17 +484,36 @@ impl LogicalPlanBuilder { join_prefix: Option<&str>, keep_join_keys: bool, ) -> DaftResult { + let left_plan = self.plan.clone(); + let right_plan = right.into(); + + let expr_resolver = ExprResolver::default(); + + let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?; + let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?; + + // TODO(kevin): we should do this, but it has not been properly used before and is nondeterministic, which causes some tests to break + // let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on); + + let (right_plan, right_on) = ops::Join::rename_right_columns( + left_plan.clone(), + right_plan, + left_on.clone(), + right_on, + join_type, + join_suffix, + join_prefix, + keep_join_keys, + )?; + let logical_plan: LogicalPlan = ops::Join::try_new( - self.plan.clone(), - right.into(), + left_plan, + right_plan, left_on, right_on, null_equals_nulls, join_type, join_strategy, - join_suffix, - join_prefix, - keep_join_keys, )? 
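Expression resolution now happens in the builder through ExprResolver::builder()...build(), backed by the typed-builder dependency added to Cargo.toml earlier in this patch. An illustrative sketch of that derive; the fields are stand-ins, not the real ExprResolver:

use typed_builder::TypedBuilder;

#[derive(TypedBuilder)]
struct Resolver {
    #[builder(default = false)]
    allow_actor_pool_udf: bool,
    #[builder(default)]
    groupby: Vec<String>,
}

fn main() {
    // Unset fields fall back to their declared defaults.
    let plain = Resolver::builder().build();
    let with_udfs = Resolver::builder().allow_actor_pool_udf(true).build();
    assert!(!plain.allow_actor_pool_udf && plain.groupby.is_empty());
    assert!(with_udfs.allow_actor_pool_udf);
}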
.into(); Ok(self.with_new_plan(logical_plan)) @@ -501,7 +566,7 @@ impl LogicalPlanBuilder { pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult { let logical_plan: LogicalPlan = - ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into(); + ops::MonotonicallyIncreasingId::try_new(self.plan.clone(), column_name)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -513,6 +578,16 @@ impl LogicalPlanBuilder { compression: Option, io_config: Option, ) -> DaftResult { + let partition_cols = partition_cols + .map(|cols| { + let expr_resolver = ExprResolver::default(); + + expr_resolver + .resolve(cols, &self.schema()) + .map(|(resolved_cols, _)| resolved_cols) + }) + .transpose()?; + let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new( root_dir.into(), file_format, @@ -614,19 +689,79 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + /// Async equivalent of `optimize` + /// This is safe to call from a tokio runtime + pub fn optimize_async(&self) -> impl Future> { + let cfg = self.config.clone(); + + // Run LogicalPlan optimizations + let unoptimized_plan = self.build(); + let (tx, rx) = tokio::sync::oneshot::channel(); + + std::thread::spawn(move || { + let optimizer = OptimizerBuilder::default() + .when( + cfg.as_ref() + .map_or(false, |conf| conf.enable_join_reordering), + |builder| builder.reorder_joins(), + ) + .simplify_expressions() + .build(); + + let optimized_plan = optimizer.optimize( + unoptimized_plan, + |new_plan, rule_batch, pass, transformed, seen| { + if transformed { + log::debug!( + "Rule batch {:?} transformed plan on pass {}, and produced {} plan:\n{}", + rule_batch, + pass, + if seen { "an already seen" } else { "a new" }, + new_plan.repr_ascii(true), + ); + } else { + log::debug!( + "Rule batch {:?} did NOT transform plan on pass {} for plan:\n{}", + rule_batch, + pass, + new_plan.repr_ascii(true), + ); + } + }, + ); + tx.send(optimized_plan).unwrap(); + }); + + let cfg = self.config.clone(); + async move { + rx.await + .map_err(|e| { + DaftError::InternalError(format!("Error optimizing logical plan: {:?}", e)) + })? 
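optimize_async above bridges a potentially blocking optimizer pass back into async code by spawning a thread and awaiting a tokio oneshot channel. A stripped-down, runnable sketch of the same pattern, with a String standing in for the real LogicalPlan:

async fn optimize_async(plan: String) -> Result<String, String> {
    let (tx, rx) = tokio::sync::oneshot::channel();

    std::thread::spawn(move || {
        // Stand-in for the (potentially blocking) optimizer pass.
        let optimized = format!("optimized({plan})");
        // The receiver may have been dropped; ignore the send error.
        let _ = tx.send(optimized);
    });

    rx.await
        .map_err(|e| format!("optimizer thread dropped the sender: {e:?}"))
}

#[tokio::main]
async fn main() {
    println!("{:?}", optimize_async("Filter(a > 1) <- Scan(t)".into()).await);
}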
+ .map(|plan| Self::new(plan, cfg)) + } + } + + /// optimize the logical plan + /// + /// **Important**: Do not call this method from the main thread as there is a `block_on` call deep within this method + /// Calling will result in a runtime panic pub fn optimize(&self) -> DaftResult { + // TODO: remove the `block_on` to make this method safe to call from the main thread + + let cfg = self.config.clone(); + + let unoptimized_plan = self.build(); + let optimizer = OptimizerBuilder::default() .when( - self.config - .as_ref() + cfg.as_ref() .map_or(false, |conf| conf.enable_join_reordering), |builder| builder.reorder_joins(), ) .simplify_expressions() .build(); - // Run LogicalPlan optimizations - let unoptimized_plan = self.build(); let optimized_plan = optimizer.optimize( unoptimized_plan, |new_plan, rule_batch, pass, transformed, seen| { @@ -649,7 +784,7 @@ impl LogicalPlanBuilder { }, )?; - let builder = Self::new(optimized_plan, self.config.clone()); + let builder = Self::new(optimized_plan, cfg); Ok(builder) } @@ -752,8 +887,8 @@ impl PyLogicalPlanBuilder { &self, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> PyResult { let ids_exprs = ids .iter() diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-logical-plan/src/builder/resolve_expr.rs similarity index 93% rename from src/daft-dsl/src/resolve_expr/mod.rs rename to src/daft-logical-plan/src/builder/resolve_expr.rs index 35d97bc9a8..cd98930ca7 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-logical-plan/src/builder/resolve_expr.rs @@ -1,6 +1,3 @@ -#[cfg(test)] -mod tests; - use std::{ cmp::Ordering, collections::{BinaryHeap, HashMap, HashSet}, @@ -10,15 +7,16 @@ use std::{ use common_error::{DaftError, DaftResult}; use common_treenode::{Transformed, TransformedResult, TreeNode}; use daft_core::prelude::*; +#[cfg(feature = "python")] +use daft_core::python::PySchema; +use daft_dsl::{col, functions::FunctionExpr, has_agg, is_actor_pool_udf, AggExpr, Expr, ExprRef}; +#[cfg(feature = "python")] +use pyo3::prelude::*; use typed_builder::TypedBuilder; -use crate::{ - col, expr::has_agg, functions::FunctionExpr, is_actor_pool_udf, AggExpr, Expr, ExprRef, -}; - // Calculates all the possible struct get expressions in a schema. // For each sugared string, calculates all possible corresponding expressions, in order of priority. -fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { +pub fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { #[derive(PartialEq, Eq)] struct BfsState<'a> { name: String, @@ -61,7 +59,7 @@ fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { for child in children { pq.push(BfsState { name: format!("{}.{}", name, child.name), - expr: crate::functions::struct_::get(expr.clone(), &child.name), + expr: daft_dsl::functions::struct_::get(expr.clone(), &child.name), field: child, }); } @@ -76,7 +74,7 @@ fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { /// /// For example, if col("a.b.c") could be interpreted as either col("a.b").struct.get("c") /// or col("a").struct.get("b.c"), this function will resolve it to col("a.b").struct.get("c"). -fn transform_struct_gets( +pub fn transform_struct_gets( expr: ExprRef, struct_expr_map: &HashMap>, ) -> DaftResult { @@ -103,7 +101,10 @@ fn transform_struct_gets( // Finds the names of all the wildcard expressions in an expression tree. 
// Needs the schema because column names with stars must not count as wildcards -fn find_wildcards(expr: ExprRef, struct_expr_map: &HashMap>) -> Vec> { +pub fn find_wildcards( + expr: ExprRef, + struct_expr_map: &HashMap>, +) -> Vec> { match expr.as_ref() { Expr::Column(name) => { if name.contains('*') { @@ -346,7 +347,7 @@ impl<'a> ExprResolver<'a> { } } -pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { +fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { let struct_expr_map = calculate_struct_expr_map(schema); let names = if name == "*" || name.ends_with(".*") { @@ -371,3 +372,9 @@ pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> Ok(()) } + +#[cfg(feature = "python")] +#[pyfunction(name = "check_column_name_validity")] +pub fn py_check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> { + Ok(check_column_name_validity(name, &schema.schema)?) +} diff --git a/src/daft-dsl/src/resolve_expr/tests.rs b/src/daft-logical-plan/src/builder/tests.rs similarity index 94% rename from src/daft-dsl/src/resolve_expr/tests.rs rename to src/daft-logical-plan/src/builder/tests.rs index dcb3147207..f8e98526f0 100644 --- a/src/daft-dsl/src/resolve_expr/tests.rs +++ b/src/daft-logical-plan/src/builder/tests.rs @@ -1,4 +1,11 @@ -use super::*; +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::Schema; +use daft_dsl::{col, ExprRef}; +use daft_schema::{dtype::DataType, field::Field}; + +use super::resolve_expr::*; fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { let struct_expr_map = calculate_struct_expr_map(schema); @@ -7,7 +14,7 @@ fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult DaftResult<()> { - use crate::functions::struct_::get as struct_get; + use daft_dsl::functions::struct_::get as struct_get; let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 84db90d273..3958477b77 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -94,7 +94,7 @@ mod test { startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? .limit(1000, false)? - .add_monotonically_increasing_id(None)? + .add_monotonically_increasing_id(Some("id2"))? .distinct()? .sort(vec![col("last_name")], vec![false], vec![false])? .build(); @@ -124,7 +124,7 @@ Filter2["Filter: col(first_name) == lit('hello')"] Join3["Join: Type = Inner Strategy = Auto On = col(id) -Output schema = id#Int32, text#Utf8, first_name#Utf8, last_name#Utf8"] +Output schema = id#Int32, text#Utf8, id2#UInt64, first_name#Utf8, last_name#Utf8"] Filter4["Filter: col(id) == lit(1)"] Source5["PlaceHolder: Source ID = 0 @@ -168,7 +168,7 @@ Project1 --> Limit0 startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? .limit(1000, false)? - .add_monotonically_increasing_id(None)? + .add_monotonically_increasing_id(Some("id2"))? .distinct()? .sort(vec![col("last_name")], vec![false], vec![false])? 
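check_column_name_validity is now surfaced to Python through a thin #[pyfunction] wrapper and registered with wrap_pyfunction!, as shown above and in the lib.rs change below. A minimal sketch of that wiring, assuming pyo3 0.21+ with the Bound API; the module name and validation body here are placeholders:

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

#[pyfunction(name = "check_column_name_validity")]
fn py_check(name: &str) -> PyResult<()> {
    // Placeholder check; the real function validates the name against a schema.
    if name.is_empty() {
        return Err(PyValueError::new_err("column name must not be empty"));
    }
    Ok(())
}

#[pymodule]
fn daft_sketch(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(py_check, m)?)?;
    Ok(())
}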
.build(); diff --git a/src/daft-logical-plan/src/lib.rs b/src/daft-logical-plan/src/lib.rs index 5296d99a23..317a92535e 100644 --- a/src/daft-logical-plan/src/lib.rs +++ b/src/daft-logical-plan/src/lib.rs @@ -41,6 +41,10 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_function(wrap_pyfunction!( + builder::py_check_column_name_validity, + parent + )?)?; Ok(()) } diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 4abd15fcaa..7a3ccd04b2 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -334,12 +334,12 @@ impl LogicalPlan { Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Sort(Sort { sort_by, descending, nulls_first, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone(), nulls_first.clone()).unwrap()), - Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::try_new(input.clone(), scheme_config.clone()).unwrap()), + Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::new(input.clone(), scheme_config.clone())), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), - Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), + Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. 
}) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::try_new(input.clone(), Some(column_name)).unwrap()), Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name.clone(), value_name.clone(), output_schema.clone())), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), @@ -361,9 +361,6 @@ impl LogicalPlan { null_equals_nulls.clone(), *join_type, *join_strategy, - None, // The suffix is already eagerly computed in the constructor - None, // the prefix is already eagerly computed in the constructor - false // this is already eagerly computed in the constructor ).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index d9a2aa0c4b..f76d9b9fc3 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -3,16 +3,15 @@ use std::sync::Arc; use common_error::DaftError; use common_resource_request::ResourceRequest; use daft_dsl::{ - count_actor_pool_udfs, + count_actor_pool_udfs, exprs_to_schema, functions::python::{get_concurrency, get_resource_request, get_udf_names}, - ExprRef, ExprResolver, + ExprRef, }; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_schema::schema::SchemaRef; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Error, Result}, + logical_plan::{Error, Result}, stats::StatsState, LogicalPlan, }; @@ -28,17 +27,12 @@ pub struct ActorPoolProject { impl ActorPoolProject { pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); - let (projection, fields) = expr_resolver - .resolve(projection, input.schema().as_ref()) - .context(CreationSnafu)?; - let num_actor_pool_udfs: usize = count_actor_pool_udfs(&projection); if !num_actor_pool_udfs == 1 { return Err(Error::CreationError { source: DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 actor pool UDF expression but found: {num_actor_pool_udfs}")) }); } - let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let projected_schema = exprs_to_schema(&projection, input.schema())?; Ok(Self { input, diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index 5b99338b1c..826e98c9a8 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use daft_dsl::{ExprRef, ExprResolver}; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_dsl::{exprs_to_schema, ExprRef}; +use daft_schema::schema::SchemaRef; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -36,21 +35,10 @@ impl Aggregate { aggregations: Vec, groupby: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let agg_resolver = ExprResolver::builder().groupby(&groupby).build(); - let (aggregations, aggregation_fields) = agg_resolver - .resolve(aggregations, &upstream_schema) - .context(CreationSnafu)?; - - let groupby_resolver = ExprResolver::default(); - let (groupby, groupby_fields) = groupby_resolver - 
.resolve(groupby, &upstream_schema) - .context(CreationSnafu)?; - - let fields = [groupby_fields, aggregation_fields].concat(); - - let output_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let output_schema = exprs_to_schema( + &[groupby.as_slice(), aggregations.as_slice()].concat(), + input.schema(), + )?; Ok(Self { input, @@ -64,33 +52,19 @@ impl Aggregate { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can use the schema here for better estimations. For now, use the old logic. let input_stats = self.input.materialized_stats(); - let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes - / (input_stats.approx_stats.lower_bound_rows.max(1)); - let est_bytes_per_row_upper = - input_stats - .approx_stats - .upper_bound_bytes - .and_then(|bytes| { - input_stats - .approx_stats - .upper_bound_rows - .map(|rows| bytes / rows.max(1)) - }); + let est_bytes_per_row = + input_stats.approx_stats.size_bytes / (input_stats.approx_stats.num_rows.max(1)); let approx_stats = if self.groupby.is_empty() { ApproxStats { - lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), - upper_bound_rows: Some(1), - lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) - * est_bytes_per_row_lower, - upper_bound_bytes: est_bytes_per_row_upper, + num_rows: 1, + size_bytes: est_bytes_per_row, } } else { + // Assume high cardinality for group by columns, and 80% of rows are unique. + let est_num_groups = input_stats.approx_stats.num_rows * 4 / 5; ApproxStats { - lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), - upper_bound_rows: input_stats.approx_stats.upper_bound_rows, - lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) - * est_bytes_per_row_lower, - upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + num_rows: est_num_groups, + size_bytes: est_bytes_per_row * est_num_groups, } }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); diff --git a/src/daft-logical-plan/src/ops/concat.rs b/src/daft-logical-plan/src/ops/concat.rs index fb18441c4c..207bceffed 100644 --- a/src/daft-logical-plan/src/ops/concat.rs +++ b/src/daft-logical-plan/src/ops/concat.rs @@ -18,14 +18,6 @@ pub struct Concat { } impl Concat { - pub(crate) fn new(input: Arc, other: Arc) -> Self { - Self { - input, - other, - stats_state: StatsState::NotMaterialized, - } - } - pub(crate) fn try_new( input: Arc, other: Arc, @@ -39,6 +31,7 @@ impl Concat { ))) .context(CreationSnafu); } + Ok(Self { input, other, diff --git a/src/daft-logical-plan/src/ops/distinct.rs b/src/daft-logical-plan/src/ops/distinct.rs index 899dab940b..dba86e7e44 100644 --- a/src/daft-logical-plan/src/ops/distinct.rs +++ b/src/daft-logical-plan/src/ops/distinct.rs @@ -23,14 +23,13 @@ impl Distinct { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can simply use NDVs here. For now, do a naive estimation. let input_stats = self.input.materialized_stats(); - let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes - / (input_stats.approx_stats.lower_bound_rows.max(1)); + let est_bytes_per_row = + input_stats.approx_stats.size_bytes / (input_stats.approx_stats.num_rows.max(1)); + // Assume high cardinality, 80% of rows are distinct. 
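The stats refactor replaces the old lower/upper bounds with a single-point estimate (num_rows, size_bytes) plus an assumed "80% of rows are unique" heuristic for Aggregate and Distinct. A toy sketch of that arithmetic; the struct is a stand-in mirroring the field names in the diff, not the real ApproxStats:

#[derive(Debug, Clone, Copy)]
struct ApproxStats {
    num_rows: usize,
    size_bytes: usize,
}

fn estimate_groupby(input: ApproxStats) -> ApproxStats {
    let est_bytes_per_row = input.size_bytes / input.num_rows.max(1);
    // Assume high cardinality: roughly 80% of input rows survive as groups.
    let est_num_groups = input.num_rows * 4 / 5;
    ApproxStats {
        num_rows: est_num_groups,
        size_bytes: est_bytes_per_row * est_num_groups,
    }
}

fn main() {
    let input = ApproxStats { num_rows: 1_000, size_bytes: 64_000 };
    let grouped = estimate_groupby(input);
    assert_eq!(grouped.num_rows, 800);
    assert_eq!(grouped.size_bytes, 800 * 64);
}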
+ let est_distinct_values = input_stats.approx_stats.num_rows * 4 / 5; let approx_stats = ApproxStats { - lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), - upper_bound_rows: input_stats.approx_stats.upper_bound_rows, - lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) - * est_bytes_per_row_lower, - upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + num_rows: est_distinct_values, + size_bytes: est_distinct_values * est_bytes_per_row, }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index 00624102f4..ed214430e8 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::{exprs_to_schema, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -26,35 +25,23 @@ impl Explode { input: Arc, to_explode: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let expr_resolver = ExprResolver::default(); + let exploded_schema = { + let explode_exprs = to_explode + .iter() + .cloned() + .map(daft_functions::list::explode) + .collect::>(); - let (to_explode, _) = expr_resolver - .resolve(to_explode, &upstream_schema) - .context(CreationSnafu)?; + let explode_schema = exprs_to_schema(&explode_exprs, input.schema())?; - let explode_exprs = to_explode - .iter() - .cloned() - .map(daft_functions::list::explode) - .collect::>(); - let exploded_schema = { - let explode_schema = { - let explode_fields = explode_exprs - .iter() - .map(|e| e.to_field(&upstream_schema)) - .collect::>>() - .context(CreationSnafu)?; - Schema::new(explode_fields).context(CreationSnafu)? 
- }; - let fields = upstream_schema + let fields = input + .schema() .fields .iter() .map(|(name, field)| explode_schema.fields.get(name).unwrap_or(field)) .cloned() .collect::>(); - Schema::new(fields).context(CreationSnafu)?.into() + Schema::new(fields)?.into() }; Ok(Self { @@ -67,11 +54,10 @@ impl Explode { pub(crate) fn with_materialized_stats(mut self) -> Self { let input_stats = self.input.materialized_stats(); + let est_num_exploded_rows = input_stats.approx_stats.num_rows * 4; let approx_stats = ApproxStats { - lower_bound_rows: input_stats.approx_stats.lower_bound_rows, - upper_bound_rows: None, - lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, - upper_bound_bytes: None, + num_rows: est_num_exploded_rows, + size_bytes: input_stats.approx_stats.size_bytes, }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 62bb34a46a..a8f6507641 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -2,11 +2,11 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::{estimated_selectivity, ExprRef}; use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Result}, + logical_plan::{self, CreationSnafu}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -21,17 +21,16 @@ pub struct Filter { } impl Filter { - pub(crate) fn try_new(input: Arc, predicate: ExprRef) -> Result { - let expr_resolver = ExprResolver::default(); + pub(crate) fn try_new( + input: Arc, + predicate: ExprRef, + ) -> logical_plan::Result { + let dtype = predicate.to_field(&input.schema())?.dtype; - let (predicate, field) = expr_resolver - .resolve_single(predicate, &input.schema()) - .context(CreationSnafu)?; - - if !matches!(field.dtype, DataType::Boolean) { + if !matches!(dtype, DataType::Boolean) { return Err(DaftError::ValueError(format!( "Expected expression {predicate} to resolve to type Boolean, but received: {}", - field.dtype + dtype ))) .context(CreationSnafu); } @@ -46,13 +45,12 @@ impl Filter { // Assume no row/column pruning in cardinality-affecting operations. // TODO(desmond): We can do better estimations here. For now, reuse the old logic. 
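The Filter stats just below scale both row count and byte size by estimated_selectivity(&predicate, &schema). A toy sketch of that scaling, with a fixed selectivity standing in for the real predicate-driven estimate:

fn apply_selectivity(num_rows: usize, size_bytes: usize, selectivity: f64) -> (usize, usize) {
    // Round up so a non-zero selectivity never estimates an empty output.
    (
        (num_rows as f64 * selectivity).ceil() as usize,
        (size_bytes as f64 * selectivity).ceil() as usize,
    )
}

fn main() {
    let (rows, bytes) = apply_selectivity(10_000, 1_000_000, 0.25);
    assert_eq!(rows, 2_500);
    assert_eq!(bytes, 250_000);
}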
let input_stats = self.input.materialized_stats(); - let upper_bound_rows = input_stats.approx_stats.upper_bound_rows; - let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes; + let estimated_selectivity = estimated_selectivity(&self.predicate, &self.input.schema()); let approx_stats = ApproxStats { - lower_bound_rows: 0, - upper_bound_rows, - lower_bound_bytes: 0, - upper_bound_bytes, + num_rows: (input_stats.approx_stats.num_rows as f64 * estimated_selectivity).ceil() + as usize, + size_bytes: (input_stats.approx_stats.size_bytes as f64 * estimated_selectivity).ceil() + as usize, }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index 5484a5c701..f7ad07737c 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -9,7 +9,7 @@ use daft_dsl::{ col, join::{get_common_join_keys, infer_join_schema}, optimization::replace_columns_with_expressions, - Expr, ExprRef, ExprResolver, + Expr, ExprRef, }; use itertools::Itertools; use snafu::ResultExt; @@ -19,7 +19,7 @@ use crate::{ logical_plan::{self, CreationSnafu}, ops::Project, stats::{ApproxStats, PlanStats, StatsState}, - LogicalPlan, + LogicalPlan, LogicalPlanRef, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -51,30 +51,6 @@ impl std::hash::Hash for Join { } impl Join { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - left: Arc, - right: Arc, - left_on: Vec, - right_on: Vec, - null_equals_nulls: Option>, - join_type: JoinType, - join_strategy: Option, - output_schema: SchemaRef, - ) -> Self { - Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - } - } - #[allow(clippy::too_many_arguments)] pub(crate) fn try_new( left: Arc, @@ -84,45 +60,11 @@ impl Join { null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - // if true, then duplicate column names will be kept - // ex: select * from a left join b on a.id = b.id - // if true, then the resulting schema will have two columns named id (id, and b.id) - // In SQL the join column is always kept, while in dataframes it is not - keep_join_keys: bool, ) -> logical_plan::Result { - let expr_resolver = ExprResolver::default(); - - let (left_on, _) = expr_resolver - .resolve(left_on, &left.schema()) - .context(CreationSnafu)?; - let (right_on, _) = expr_resolver - .resolve(right_on, &right.schema()) - .context(CreationSnafu)?; - - let (unique_left_on, unique_right_on) = - Self::rename_join_keys(left_on.clone(), right_on.clone()); - - let left_fields: Vec = unique_left_on - .iter() - .map(|e| e.to_field(&left.schema())) - .collect::>>() - .context(CreationSnafu)?; - - let right_fields: Vec = unique_right_on - .iter() - .map(|e| e.to_field(&right.schema())) - .collect::>>() - .context(CreationSnafu)?; - - for (on_exprs, on_fields) in [ - (&unique_left_on, &left_fields), - (&unique_right_on, &right_fields), - ] { - for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + for (on_exprs, side) in [(&left_on, &left), (&right_on, &right)] { + for expr in on_exprs { // Null type check for both fields and expressions - if matches!(field.dtype, DataType::Null) { + if matches!(expr.to_field(&side.schema())?.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" ))) @@ -141,22 
+83,42 @@ impl Join { } } - if matches!(join_type, JoinType::Anti | JoinType::Semi) { - // The output schema is the same as the left input schema for anti and semi joins. + let output_schema = infer_join_schema( + &left.schema(), + &right.schema(), + &left_on, + &right_on, + join_type, + )?; - let output_schema = left.schema(); + Ok(Self { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state: StatsState::NotMaterialized, + }) + } - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + /// Add a project under the right side plan when necessary in order to resolve naming conflicts + /// between left and right side columns. + #[allow(clippy::too_many_arguments)] + pub(crate) fn rename_right_columns( + left: LogicalPlanRef, + right: LogicalPlanRef, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + keep_join_keys: bool, + ) -> DaftResult<(LogicalPlanRef, Vec)> { + if matches!(join_type, JoinType::Anti | JoinType::Semi) { + Ok((right, right_on)) } else { let common_join_keys: HashSet<_> = get_common_join_keys(left_on.as_slice(), right_on.as_slice()) @@ -202,8 +164,8 @@ impl Join { }) .collect(); - let (right, right_on) = if right_rename_mapping.is_empty() { - (right, right_on) + if right_rename_mapping.is_empty() { + Ok((right, right_on)) } else { // projection to update the right side with the new column names let new_right_projection: Vec<_> = right_names @@ -230,29 +192,8 @@ impl Join { .map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map)) .collect::>(); - (new_right.into(), new_right_on) - }; - - let output_schema = infer_join_schema( - &left.schema(), - &right.schema(), - &left_on, - &right_on, - join_type, - ) - .context(CreationSnafu)?; - - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + Ok((new_right.into(), new_right_on)) + } } } @@ -282,8 +223,8 @@ impl Join { /// ``` /// /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). 
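rename_right_columns above resolves naming conflicts by projecting renamed right-side columns before the join. A simplified sketch of the renaming rule it applies (prefix handling and join-key retention are omitted):

use std::collections::HashSet;

fn rename_right(left: &[&str], right: &[&str], join_keys: &[&str], suffix: &str) -> Vec<String> {
    let left: HashSet<_> = left.iter().copied().collect();
    let keys: HashSet<_> = join_keys.iter().copied().collect();
    right
        .iter()
        .map(|name| {
            // Only non-key columns that collide with the left schema are renamed.
            if left.contains(name) && !keys.contains(name) {
                format!("{name}{suffix}")
            } else {
                (*name).to_string()
            }
        })
        .collect()
}

fn main() {
    let renamed = rename_right(&["id", "score"], &["id", "score", "note"], &["id"], "_right");
    assert_eq!(renamed, vec!["id", "score_right", "note"]);
}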
- - fn rename_join_keys( + #[allow(dead_code)] + pub(crate) fn rename_join_keys( left_exprs: Vec>, right_exprs: Vec>, ) -> (Vec>, Vec>) { @@ -317,16 +258,14 @@ impl Join { let left_stats = self.left.materialized_stats(); let right_stats = self.right.materialized_stats(); let approx_stats = ApproxStats { - lower_bound_rows: 0, - upper_bound_rows: left_stats + num_rows: left_stats .approx_stats - .upper_bound_rows - .and_then(|l| right_stats.approx_stats.upper_bound_rows.map(|r| l.max(r))), - lower_bound_bytes: 0, - upper_bound_bytes: left_stats + .num_rows + .max(right_stats.approx_stats.num_rows), + size_bytes: left_stats .approx_stats - .upper_bound_bytes - .and_then(|l| right_stats.approx_stats.upper_bound_bytes.map(|r| l.max(r))), + .size_bytes + .max(right_stats.approx_stats.size_bytes), }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/ops/limit.rs b/src/daft-logical-plan/src/ops/limit.rs index fdb2ecab7c..5b8176fa6a 100644 --- a/src/daft-logical-plan/src/ops/limit.rs +++ b/src/daft-logical-plan/src/ops/limit.rs @@ -30,29 +30,15 @@ impl Limit { pub(crate) fn with_materialized_stats(mut self) -> Self { let input_stats = self.input.materialized_stats(); let limit = self.limit as usize; - let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes - / input_stats.approx_stats.lower_bound_rows.max(1); - let est_bytes_per_row_upper = - input_stats - .approx_stats - .upper_bound_bytes - .and_then(|bytes| { - input_stats - .approx_stats - .upper_bound_rows - .map(|rows| bytes / rows.max(1)) - }); - let new_lower_rows = input_stats.approx_stats.lower_bound_rows.min(limit); - let new_upper_rows = input_stats - .approx_stats - .upper_bound_rows - .map(|ub| ub.min(limit)) - .unwrap_or(limit); let approx_stats = ApproxStats { - lower_bound_rows: new_lower_rows, - upper_bound_rows: Some(new_upper_rows), - lower_bound_bytes: new_lower_rows * est_bytes_per_row_lower, - upper_bound_bytes: est_bytes_per_row_upper.map(|x| x * new_upper_rows), + num_rows: limit.min(input_stats.approx_stats.num_rows), + size_bytes: if input_stats.approx_stats.num_rows > limit { + let est_bytes_per_row = + input_stats.approx_stats.size_bytes / input_stats.approx_stats.num_rows.max(1); + limit * est_bytes_per_row + } else { + input_stats.approx_stats.size_bytes + }, }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index 170296fa2a..ea288ab446 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -2,7 +2,11 @@ use std::sync::Arc; use daft_core::prelude::*; -use crate::{stats::StatsState, LogicalPlan}; +use crate::{ + logical_plan::{self}, + stats::StatsState, + LogicalPlan, +}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] pub struct MonotonicallyIncreasingId { @@ -13,25 +17,23 @@ pub struct MonotonicallyIncreasingId { } impl MonotonicallyIncreasingId { - pub(crate) fn new(input: Arc, column_name: Option<&str>) -> Self { + pub(crate) fn try_new( + input: Arc, + column_name: Option<&str>, + ) -> logical_plan::Result { let column_name = column_name.unwrap_or("id"); - let mut schema_with_id_index_map = input.schema().fields.clone(); - schema_with_id_index_map.shift_insert( - 0, - column_name.to_string(), - Field::new(column_name, DataType::UInt64), - ); - let schema_with_id 
= Schema { - fields: schema_with_id_index_map, - }; - - Self { + let fields_with_id = std::iter::once(Field::new(column_name, DataType::UInt64)) + .chain(input.schema().fields.values().cloned()) + .collect(); + let schema_with_id = Schema::new(fields_with_id)?; + + Ok(Self { input, schema: Arc::new(schema_with_id), column_name: column_name.to_string(), stats_state: StatsState::NotMaterialized, - } + }) } pub(crate) fn with_materialized_stats(mut self) -> Self { diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index 57ee3bb1c5..cb24e47232 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -1,14 +1,13 @@ use std::sync::Arc; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; -use daft_dsl::{AggExpr, Expr, ExprRef, ExprResolver}; +use daft_dsl::{AggExpr, Expr, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::StatsState, LogicalPlan, }; @@ -34,24 +33,6 @@ impl Pivot { aggregation: ExprRef, names: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let agg_resolver = ExprResolver::builder().groupby(&group_by).build(); - let (aggregation, _) = agg_resolver - .resolve_single(aggregation, &upstream_schema) - .context(CreationSnafu)?; - - let expr_resolver = ExprResolver::default(); - let (group_by, group_by_fields) = expr_resolver - .resolve(group_by, &upstream_schema) - .context(CreationSnafu)?; - let (pivot_column, _) = expr_resolver - .resolve_single(pivot_column, &upstream_schema) - .context(CreationSnafu)?; - let (value_column, value_col_field) = expr_resolver - .resolve_single(value_column, &upstream_schema) - .context(CreationSnafu)?; - let Expr::Agg(agg_expr) = aggregation.as_ref() else { return Err(DaftError::ValueError(format!( "Pivot only supports using top level aggregation expressions, received {aggregation}", @@ -60,16 +41,22 @@ impl Pivot { }; let output_schema = { - let value_col_dtype = value_col_field.dtype; + let value_col_dtype = value_column.to_field(&input.schema())?.dtype; let pivot_value_fields = names .iter() .map(|f| Field::new(f, value_col_dtype.clone())) .collect::>(); + + let group_by_fields = group_by + .iter() + .map(|expr| expr.to_field(&input.schema())) + .collect::>>()?; + let fields = group_by_fields .into_iter() .chain(pivot_value_fields) .collect::>(); - Schema::new(fields).context(CreationSnafu)?.into() + Schema::new(fields)?.into() }; Ok(Self { diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 165d989a09..171899203c 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -1,14 +1,14 @@ use std::sync::Arc; +use common_error::DaftResult; use common_treenode::Transformed; use daft_core::prelude::*; -use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef, ExprResolver}; +use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Result}, + logical_plan::{self}, stats::StatsState, LogicalPlan, }; @@ -23,18 +23,20 @@ pub struct Project { } impl Project { - pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); 
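MonotonicallyIncreasingId::try_new above now builds its output schema by prepending the id field to the input fields with iter::once(..).chain(..). The same idea on plain column names:

fn prepend_id_column(existing: Vec<String>, id_name: &str) -> Vec<String> {
    std::iter::once(id_name.to_string()).chain(existing).collect()
}

fn main() {
    let cols = prepend_id_column(vec!["a".into(), "b".into()], "id");
    assert_eq!(cols, vec!["id", "a", "b"]);
}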
- - let (projection, fields) = expr_resolver - .resolve(projection, &input.schema()) - .context(CreationSnafu)?; - + pub(crate) fn try_new( + input: Arc, + projection: Vec, + ) -> logical_plan::Result { // Factor the projection and see if there are any substitutions to factor out. let (factored_input, factored_projection) = Self::try_factor_subexpressions(input, projection)?; - let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let fields = factored_projection + .iter() + .map(|expr| expr.to_field(&factored_input.schema())) + .collect::>()?; + + let projected_schema = Schema::new(fields)?.into(); Ok(Self { input: factored_input, @@ -45,7 +47,10 @@ impl Project { } /// Create a new Projection using the specified output schema - pub(crate) fn new_from_schema(input: Arc, schema: SchemaRef) -> Result { + pub(crate) fn new_from_schema( + input: Arc, + schema: SchemaRef, + ) -> logical_plan::Result { let expr: Vec = schema .names() .into_iter() @@ -75,7 +80,7 @@ impl Project { fn try_factor_subexpressions( input: Arc, projection: Vec, - ) -> Result<(Arc, Vec)> { + ) -> logical_plan::Result<(Arc, Vec)> { // Given construction parameters for a projection, // see if we can factor out common subexpressions. // Returns a new set of projection parameters diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index ac12970c49..d67ccd86f7 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -1,13 +1,6 @@ use std::sync::Arc; -use common_error::DaftResult; -use daft_dsl::ExprResolver; - -use crate::{ - partitioning::{HashRepartitionConfig, RepartitionSpec}, - stats::StatsState, - LogicalPlan, -}; +use crate::{partitioning::RepartitionSpec, stats::StatsState, LogicalPlan}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Repartition { @@ -18,28 +11,12 @@ pub struct Repartition { } impl Repartition { - pub(crate) fn try_new( - input: Arc, - repartition_spec: RepartitionSpec, - ) -> DaftResult { - let repartition_spec = match repartition_spec { - RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => { - let expr_resolver = ExprResolver::default(); - - let (resolved_by, _) = expr_resolver.resolve(by, &input.schema())?; - RepartitionSpec::Hash(HashRepartitionConfig { - num_partitions, - by: resolved_by, - }) - } - RepartitionSpec::Random(_) | RepartitionSpec::IntoPartitions(_) => repartition_spec, - }; - - Ok(Self { + pub(crate) fn new(input: Arc, repartition_spec: RepartitionSpec) -> Self { + Self { input, repartition_spec, stats_state: StatsState::NotMaterialized, - }) + } } pub(crate) fn with_materialized_stats(mut self) -> Self { diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 42009182b6..64521f4ed9 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -47,9 +47,6 @@ fn intersect_or_except_plan( Some(vec![true; left_on_size]), join_type, None, - None, - None, - false, ); join.map(|j| Distinct::new(j.into()).into()) } @@ -303,8 +300,7 @@ impl Union { } else { (self.lhs.clone(), self.rhs.clone()) }; - // we don't want to use `try_new` as we have already checked the schema - let concat = LogicalPlan::Concat(Concat::new(lhs, rhs)); + let concat = LogicalPlan::Concat(Concat::try_new(lhs, rhs)?); if self.is_all { Ok(concat) } else { diff --git a/src/daft-logical-plan/src/ops/sink.rs b/src/daft-logical-plan/src/ops/sink.rs 
index e5eb9f3f2e..46aa17b1dd 100644 --- a/src/daft-logical-plan/src/ops/sink.rs +++ b/src/daft-logical-plan/src/ops/sink.rs @@ -1,15 +1,14 @@ use std::sync::Arc; -use common_error::DaftResult; use daft_core::prelude::*; -use daft_dsl::ExprResolver; #[cfg(feature = "python")] use crate::sink_info::CatalogType; use crate::{ + logical_plan::{self}, sink_info::SinkInfo, stats::{PlanStats, StatsState}, - LogicalPlan, OutputFileInfo, + LogicalPlan, }; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -23,41 +22,12 @@ pub struct Sink { } impl Sink { - pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { + pub(crate) fn try_new( + input: Arc, + sink_info: Arc, + ) -> logical_plan::Result { let schema = input.schema(); - // replace partition columns with resolved columns - let sink_info = match sink_info.as_ref() { - SinkInfo::OutputFileInfo(OutputFileInfo { - root_dir, - file_format, - partition_cols, - compression, - io_config, - }) => { - let expr_resolver = ExprResolver::default(); - - let resolved_partition_cols = partition_cols - .clone() - .map(|cols| { - expr_resolver - .resolve(cols, &schema) - .map(|(resolved_cols, _)| resolved_cols) - }) - .transpose()?; - - Arc::new(SinkInfo::OutputFileInfo(OutputFileInfo { - root_dir: root_dir.clone(), - file_format: *file_format, - partition_cols: resolved_partition_cols, - compression: compression.clone(), - io_config: io_config.clone(), - })) - } - #[cfg(feature = "python")] - SinkInfo::CatalogInfo(_) => sink_info, - }; - let fields = match sink_info.as_ref() { SinkInfo::OutputFileInfo(output_file_info) => { let mut fields = vec![Field::new("path", DataType::Utf8)]; diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 9c2cd046fd..b5196c617c 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::{exprs_to_schema, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -32,15 +32,10 @@ impl Sort { .context(CreationSnafu); } - let expr_resolver = ExprResolver::default(); + // TODO(Kevin): make sort by expression names unique so that we can do things like sort(col("a"), col("a") + col("b")) + let sort_by_schema = exprs_to_schema(&sort_by, input.schema())?; - let (sort_by, sort_by_fields) = expr_resolver - .resolve(sort_by, &input.schema()) - .context(CreationSnafu)?; - - let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?; - - for (field, expr) in sort_by_resolved_schema.fields.values().zip(sort_by.iter()) { + for (field, expr) in sort_by_schema.fields.values().zip(sort_by.iter()) { // Disallow sorting by null, binary, and boolean columns. // TODO(Clark): This is a port of an existing constraint, we should look at relaxing this. if let dt @ (DataType::Null | DataType::Binary) = &field.dtype { diff --git a/src/daft-logical-plan/src/ops/source.rs b/src/daft-logical-plan/src/ops/source.rs index 4044e08c72..e2c67fa2ac 100644 --- a/src/daft-logical-plan/src/ops/source.rs +++ b/src/daft-logical-plan/src/ops/source.rs @@ -60,10 +60,8 @@ impl Source { num_rows, .. 
}) => ApproxStats { - lower_bound_rows: *num_rows, - upper_bound_rows: Some(*num_rows), - lower_bound_bytes: *size_bytes, - upper_bound_bytes: Some(*size_bytes), + num_rows: *num_rows, + size_bytes: *size_bytes, }, SourceInfo::Physical(physical_scan_info) => match &physical_scan_info.scan_state { ScanState::Operator(_) => { @@ -72,23 +70,13 @@ impl Source { ScanState::Tasks(scan_tasks) => { let mut approx_stats = ApproxStats::empty(); for st in scan_tasks.iter() { - approx_stats.lower_bound_rows += st.num_rows().unwrap_or(0); - let in_memory_size = st.estimate_in_memory_size_bytes(None); - approx_stats.lower_bound_bytes += in_memory_size.unwrap_or(0); - if let Some(st_ub) = st.upper_bound_rows() { - if let Some(ub) = approx_stats.upper_bound_rows { - approx_stats.upper_bound_rows = Some(ub + st_ub); - } else { - approx_stats.upper_bound_rows = st.upper_bound_rows(); - } - } - if let Some(st_ub) = in_memory_size { - if let Some(ub) = approx_stats.upper_bound_bytes { - approx_stats.upper_bound_bytes = Some(ub + st_ub); - } else { - approx_stats.upper_bound_bytes = in_memory_size; - } + if let Some(num_rows) = st.num_rows() { + approx_stats.num_rows += num_rows; + } else if let Some(approx_num_rows) = st.approx_num_rows(None) { + approx_stats.num_rows += approx_num_rows as usize; } + approx_stats.size_bytes += + st.estimate_in_memory_size_bytes(None).unwrap_or(0); } approx_stats } diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index 46a7071bf5..e6dda83bef 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -1,8 +1,8 @@ use std::sync::Arc; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::{prelude::*, utils::supertype::try_get_supertype}; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::ExprRef; use itertools::Itertools; use snafu::ResultExt; @@ -48,8 +48,8 @@ impl Unpivot { input: Arc, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> logical_plan::Result { if values.is_empty() { return Err(DaftError::ValueError( @@ -58,40 +58,29 @@ impl Unpivot { .context(CreationSnafu); } - let expr_resolver = ExprResolver::default(); - - let input_schema = input.schema(); - let (values, values_fields) = expr_resolver - .resolve(values, &input_schema) - .context(CreationSnafu)?; - - let value_dtype = values_fields + let value_dtype = values .iter() - .map(|f| f.dtype.clone()) - .try_reduce(|a, b| try_get_supertype(&a, &b)) - .context(CreationSnafu)? 
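The Unpivot change above folds the value columns' dtypes into a single supertype with reduce over Results. A self-contained sketch, with a toy supertype function standing in for try_get_supertype:

#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    Int32,
    Int64,
    Float64,
}

// Toy stand-in for `try_get_supertype`.
fn supertype(a: DType, b: DType) -> Result<DType, String> {
    use DType::*;
    Ok(match (a, b) {
        (Float64, _) | (_, Float64) => Float64,
        (Int64, _) | (_, Int64) => Int64,
        _ => Int32,
    })
}

fn main() -> Result<(), String> {
    let value_dtypes = [DType::Int32, DType::Int64, DType::Int32];
    let value_dtype = value_dtypes
        .iter()
        .copied()
        .map(|dt| Ok::<DType, String>(dt))
        .reduce(|a, b| supertype(a?, b?))
        .expect("unpivot requires at least one value column")?;
    assert_eq!(value_dtype, DType::Int64);
    Ok(())
}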
- .unwrap(); + .map(|expr| Ok(expr.to_field(&input.schema())?.dtype)) + .reduce(|a, b| try_get_supertype(&a?, &b?)) + .unwrap()?; - let variable_field = Field::new(variable_name, DataType::Utf8); - let value_field = Field::new(value_name, value_dtype); + let variable_field = Field::new(&variable_name, DataType::Utf8); + let value_field = Field::new(&value_name, value_dtype); - let (ids, ids_fields) = expr_resolver - .resolve(ids, &input_schema) - .context(CreationSnafu)?; - - let output_fields = ids_fields - .into_iter() - .chain([variable_field, value_field]) - .collect::>(); + let output_fields = ids + .iter() + .map(|id| id.to_field(&input.schema())) + .chain([Ok(variable_field), Ok(value_field)]) + .collect::>>()?; - let output_schema = Schema::new(output_fields).context(CreationSnafu)?.into(); + let output_schema = Schema::new(output_fields)?.into(); Ok(Self { input, ids, values, - variable_name: variable_name.to_string(), - value_name: value_name.to_string(), + variable_name, + value_name, output_schema, stats_state: StatsState::NotMaterialized, }) @@ -101,13 +90,8 @@ impl Unpivot { let input_stats = self.input.materialized_stats(); let num_values = self.values.len(); let approx_stats = ApproxStats { - lower_bound_rows: input_stats.approx_stats.lower_bound_rows * num_values, - upper_bound_rows: input_stats - .approx_stats - .upper_bound_rows - .map(|v| v * num_values), - lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, - upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + num_rows: input_stats.approx_stats.num_rows * num_values, + size_bytes: input_stats.approx_stats.size_bytes, }; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); self diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index cea66fbf13..d1fea88019 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -106,11 +106,16 @@ impl Default for OptimizerBuilder { vec![Box::new(SimplifyExpressionsRule::new())], RuleExecutionStrategy::FixedPoint(Some(3)), ), + // --- Filter out null join keys --- + // This rule should be run once, before any filter pushdown rules. + RuleBatch::new( + vec![Box::new(FilterNullJoinKey::new())], + RuleExecutionStrategy::Once, + ), // --- Bulk of our rules --- RuleBatch::new( vec![ Box::new(DropRepartition::new()), - Box::new(FilterNullJoinKey::new()), Box::new(PushDownFilter::new()), Box::new(PushDownProjection::new()), Box::new(EliminateCrossJoin::new()), diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index e9e3a2e524..cd192f0df9 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -303,12 +303,10 @@ fn find_inner_join( if !join_keys.is_empty() { all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); - let join_schema = left_input - .schema() - .non_distinct_union(right_input.schema().as_ref()); let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); - return Ok(LogicalPlan::Join(Join::new( + + return Ok(LogicalPlan::Join(Join::try_new( left_input, right_input, left_keys, @@ -316,8 +314,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) 
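The optimizer change above moves FilterNullJoinKey into its own RuleBatch with RuleExecutionStrategy::Once, so it runs a single time before the filter-pushdown batch. A toy sketch of batch execution strategies over boxed rules; all types here are stand-ins for the real optimizer:

trait Rule {
    // Returns true if the rule changed the plan.
    fn apply(&self, plan: &mut Vec<&'static str>) -> bool;
}

struct DropNoop;

impl Rule for DropNoop {
    fn apply(&self, plan: &mut Vec<&'static str>) -> bool {
        let before = plan.len();
        plan.retain(|op| *op != "noop");
        plan.len() != before
    }
}

enum Strategy {
    Once,
    FixedPoint(usize),
}

fn run_batch(rules: &[Box<dyn Rule>], strategy: Strategy, plan: &mut Vec<&'static str>) {
    let max_passes = match strategy {
        Strategy::Once => 1,
        Strategy::FixedPoint(n) => n,
    };
    for _ in 0..max_passes {
        let mut changed = false;
        for rule in rules {
            if rule.apply(plan) {
                changed = true;
            }
        }
        // A fixed-point batch stops early once no rule fires.
        if !changed {
            break;
        }
    }
}

fn main() {
    let rules: Vec<Box<dyn Rule>> = vec![Box::new(DropNoop)];
    let mut plan = vec!["scan", "noop", "filter", "noop", "project"];
    run_batch(&rules, Strategy::Once, &mut plan);
    run_batch(&rules, Strategy::FixedPoint(3), &mut plan);
    assert_eq!(plan, vec!["scan", "filter", "project"]);
}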
.arced()); } } @@ -325,11 +322,8 @@ fn find_inner_join( // no matching right plan had any join keys, cross join with the first right // plan let right = rights.remove(0); - let join_schema = left_input - .schema() - .non_distinct_union(right.schema().as_ref()); - Ok(LogicalPlan::Join(Join::new( + Ok(LogicalPlan::Join(Join::try_new( left_input, right, vec![], @@ -337,8 +331,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) .arced()) } diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 442d8120dc..70a8815080 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -356,6 +356,7 @@ mod tests { use common_scan_info::Pushdowns; use daft_core::prelude::*; use daft_dsl::{col, lit}; + use daft_functions::uri::download::UrlDownloadArgs; use rstest::rstest; use crate::{ @@ -435,7 +436,10 @@ mod tests { /// Tests that we can't pushdown a filter into a ScanOperator if it has an udf-ish expression. #[test] fn filter_with_udf_not_pushed_down_into_scan() -> DaftResult<()> { - let pred = daft_functions::uri::download(col("a"), 1, true, true, None); + let pred = daft_functions::uri::download( + col("a"), + Some(UrlDownloadArgs::new(1, true, true, None)), + ); let plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index 00f93f536a..da314858d4 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -55,7 +55,6 @@ pub(super) trait JoinOrderer { pub(super) struct JoinNode { relation_name: String, plan: LogicalPlanRef, - final_name: String, } // TODO(desmond): We should also take into account user provided values for: @@ -65,21 +64,15 @@ pub(super) struct JoinNode { /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column /// that's being accessed from the relation, and the final name of the column in the output. impl JoinNode { - pub(super) fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { + pub(super) fn new(relation_name: String, plan: LogicalPlanRef) -> Self { Self { relation_name, plan, - final_name, } } fn simple_repr(&self) -> String { - format!( - "{}#{}({})", - self.final_name, - self.plan.name(), - self.relation_name - ) + format!("{}({})", self.plan.name(), self.relation_name) } } @@ -168,8 +161,8 @@ impl JoinAdjList { fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { let join_condition = JoinCondition { - left_on: left.final_name.clone(), - right_on: right.final_name.clone(), + left_on: left.relation_name.clone(), + right_on: right.relation_name.clone(), }; let left_id = self.get_or_create_plan_id(&left.plan); let right_id = self.get_or_create_plan_id(&right.plan); @@ -448,78 +441,140 @@ impl JoinGraphBuilder { /// Joins that added conditions to `join_conds_to_resolve` will pop them off the stack after they have been resolved. 
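The updated push_down_filter test above passes Some(UrlDownloadArgs::new(..)) instead of four positional arguments. An illustrative sketch of that options-struct style; this DownloadArgs is hypothetical, not the real UrlDownloadArgs:

#[derive(Debug, Clone)]
struct DownloadArgs {
    max_connections: usize,
    raise_error_on_failure: bool,
    multi_thread: bool,
}

impl Default for DownloadArgs {
    fn default() -> Self {
        Self { max_connections: 32, raise_error_on_failure: true, multi_thread: true }
    }
}

fn download(url: &str, args: Option<DownloadArgs>) -> String {
    // `None` means "use the defaults", keeping simple call sites terse.
    let args = args.unwrap_or_default();
    format!(
        "GET {url} (connections={}, raise_on_failure={}, multi_thread={})",
        args.max_connections, args.raise_error_on_failure, args.multi_thread
    )
}

fn main() {
    println!("{}", download("s3://bucket/key", None));
    let single = DownloadArgs { max_connections: 1, ..Default::default() };
    println!("{}", download("s3://bucket/key", Some(single)));
}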
/// Combining each of their resolved `left_on` conditions with their respective resolved `right_on` conditions produces /// a join edge between the relation used in the left condition and the relation used in the right condition. - fn process_node(&mut self, plan: &LogicalPlanRef) { - let schema = plan.schema(); - for (name, node, done) in &mut self.join_conds_to_resolve { - if !*done && schema.has_field(name) { - *node = plan.clone(); + fn process_node(&mut self, mut plan: &LogicalPlanRef) { + // Go down the linear chain of Projects and Filters until we hit a join or an unreorderable operator. + // If we hit a join, we should process all the Projects and Filters that we encountered before the join. + // If we hit an unreorderable operator, the root plan at the top of this linear chain becomes a relation for + // join ordering. + // + // For example, consider this query tree: + // + // InnerJoin (a = d) + // / \ + // / Project + // / (d, quad <- double + double) + // / \ + // InnerJoin (a = b) InnerJoin (c = d) + // / \ / \ + // Scan(a) Scan(b) Filter Scan(d) + // (c < 5) + // | + // Project + // (c <- c_prime, double <- c_prime + c_prime) + // | + // Filter + // (c_prime > 0) + // | + // Scan(c_prime) + // + // In between InnerJoin(c=d) and Scan(c_prime) there are Filter and Project nodes. Since there is no join below InnerJoin(c=d), + // we take the Filter(c<5) operator as the relation to pass into the join (as opposed to using Scan(c_prime) and pushing up + // the Projects and Filters above it). + let root_plan = plan; + loop { + match &**plan { + // TODO(desmond): There are potentially more reorderable nodes. For example, we can move repartitions around. + LogicalPlan::Project(Project { + input, projection, .. + }) => { + let projection_input_mapping = projection + .iter() + .filter_map(|e| e.input_mapping().map(|s| (e.name().to_string(), s))) + .collect::>(); + // To be able to reorder through the current projection, all unresolved columns must either have a + // zero-computation projection, or must not be projected by the current Project node (i.e. it should be + // resolved from some other branch in the query tree). + let schema = plan.schema(); + let reorderable_project = + self.join_conds_to_resolve.iter().all(|(name, _, done)| { + *done + || !schema.has_field(name) + || projection_input_mapping.contains_key(name.as_str()) + }); + if reorderable_project { + plan = input; + } else { + // Encountered a non-reorderable Project. Add the root plan at the top of the current linear chain as a relation to join. + self.add_relation(root_plan); + break; + } + } + LogicalPlan::Filter(Filter { input, .. }) => plan = input, + // Since we hit a join, we need to process the linear chain of Projects and Filters that were encountered starting + // from the plan at the root of the linear chain to the current plan. + LogicalPlan::Join(Join { + left_on, join_type, .. + }) if *join_type == JoinType::Inner && !left_on.is_empty() => { + self.process_linear_chain(root_plan, plan); + break; + } + _ => { + // Encountered a non-reorderable node. Add the root plan at the top of the current linear chain as a relation to join. + self.add_relation(root_plan); + break; + } } } - match &**plan { - LogicalPlan::Project(Project { - input, projection, .. - }) => { - // Get the mapping from input->output for projections that don't need computation. 
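The reorderability check in the rewritten `process_node` above hinges on `Expr::input_mapping`, which (per its use in this hunk) yields the input column name for zero-computation projections and `None` for computed expressions. A minimal standalone sketch of the rule, using a toy expression type rather than the real `daft_dsl::Expr`:

```rust
use std::collections::{HashMap, HashSet};

// Toy stand-in for daft_dsl::Expr, just enough to model "pass-through vs computed".
enum Expr {
    Col(String),           // SELECT a
    Alias(String, String), // SELECT a AS b  (output name, input column)
    Computed(String),      // SELECT a + a AS double  (output name only)
}

impl Expr {
    fn name(&self) -> &str {
        match self {
            Expr::Col(n) | Expr::Computed(n) => n,
            Expr::Alias(out, _) => out,
        }
    }
    // Analogue of Expr::input_mapping(): Some(input column) iff no compute is needed.
    fn input_mapping(&self) -> Option<String> {
        match self {
            Expr::Col(n) => Some(n.clone()),
            Expr::Alias(_, input) => Some(input.clone()),
            Expr::Computed(_) => None,
        }
    }
}

// A Project is reorderable iff every unresolved join key is either not produced by this
// projection at all, or is produced without computation (so it can be pushed through).
fn is_reorderable_project(
    unresolved_keys: &[&str],
    schema: &HashSet<String>,
    projection: &[Expr],
) -> bool {
    let input_mapping: HashMap<&str, String> = projection
        .iter()
        .filter_map(|e| e.input_mapping().map(|input| (e.name(), input)))
        .collect();
    unresolved_keys
        .iter()
        .all(|key| !schema.contains(*key) || input_mapping.contains_key(*key))
}

fn main() {
    let projection = vec![
        Expr::Alias("a_beta".into(), "a_alpha".into()),
        Expr::Computed("double".into()),
    ];
    let schema: HashSet<String> = ["a_beta".to_string(), "double".to_string()].into();
    // `a_beta` is a pure rename, so a join key named `a_beta` can be pushed through this Project.
    assert!(is_reorderable_project(&["a_beta"], &schema, &projection));
    // `double` needs computation, so a join key named `double` blocks reordering here.
    assert!(!is_reorderable_project(&["double"], &schema, &projection));
}
```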
- let projection_input_mapping = projection - .iter() - .filter_map(|e| e.input_mapping().map(|s| (e.name().to_string(), col(s)))) - .collect::>(); - // To be able to reorder through the current projection, all unresolved columns must either have a - // zero-computation projection, or must not be projected by the current Project node (i.e. it should be - // resolved from some other branch in the query tree). - let reorderable_project = - self.join_conds_to_resolve.iter().all(|(name, _, done)| { - *done - || !schema.has_field(name) - || projection_input_mapping.contains_key(name.as_str()) - }); - if reorderable_project { - let mut non_join_names: HashSet = schema.names().into_iter().collect(); - for (name, _, done) in &mut self.join_conds_to_resolve { - if !*done { - if let Some(new_expr) = projection_input_mapping.get(name) { - // Remove the current name from the list of schema names so that we can produce - // a set of non-join-key names for the current Project's schema. - non_join_names.remove(name); - // If we haven't updated the corresponding entry in the final name map, do so now. - if let Some(final_name) = self.final_name_map.remove(name) { - self.final_name_map - .insert(new_expr.name().to_string(), final_name); - } - *name = new_expr.name().to_string(); + } + + /// `process_linear_chain` is a helper function that pushes up the Projects and Filters from `starting_node` to + /// `ending_node`. `ending_node` MUST be an inner join. + /// + /// After pushing up Projects and Filters, `process_linear_chain` will call `process_node` on the left and right + /// children of the Join node in `ending_node`. + fn process_linear_chain( + &mut self, + starting_node: &LogicalPlanRef, + ending_node: &LogicalPlanRef, + ) { + let mut cur_node = starting_node; + while !Arc::ptr_eq(cur_node, ending_node) { + match &**cur_node { + LogicalPlan::Project(Project { + input, projection, .. + }) => { + // Get the mapping from input->output for projections that don't need computation. + let mut compute_projections = vec![]; + let projection_input_mapping = projection + .iter() + .filter_map(|e| { + let input_mapping = e.input_mapping(); + if input_mapping.is_none() { + compute_projections.push(e.clone()); } + input_mapping.map(|s| (e.name().to_string(), s)) + }) + .collect::>(); + for (output, input) in &projection_input_mapping { + if let Some(final_name) = self.final_name_map.remove(output) { + self.final_name_map.insert(input.clone(), final_name); + } else { + self.final_name_map + .insert(input.clone(), col(output.clone())); } } - // Keep track of non-join-key projections so that we can reapply them once we've reordered the query tree. - let non_join_key_projections = projection - .iter() - .filter(|e| non_join_names.contains(e.name())) - .map(|e| replace_columns_with_expressions(e.clone(), &self.final_name_map)) - .collect::>(); - if !non_join_key_projections.is_empty() { + if !compute_projections.is_empty() { self.final_projections_and_filters - .push(ProjectionOrFilter::Projection(non_join_key_projections)); + .push(ProjectionOrFilter::Projection(compute_projections.clone())); } // Continue to children. - self.process_node(input); - } else { - for (name, _, done) in &mut self.join_conds_to_resolve { - if schema.has_field(name) { - *done = true; - } - } + cur_node = input; } + LogicalPlan::Filter(Filter { + input, predicate, .. 
+ }) => { + let new_predicate = + replace_columns_with_expressions(predicate.clone(), &self.final_name_map); + self.final_projections_and_filters + .push(ProjectionOrFilter::Filter(new_predicate)); + // Continue to children. + cur_node = input; + } + _ => unreachable!("process_linear_chain is only called with a linear chain of Project and Filters that end with a Join"), } - LogicalPlan::Filter(Filter { - input, predicate, .. - }) => { - let new_predicate = - replace_columns_with_expressions(predicate.clone(), &self.final_name_map); - self.final_projections_and_filters - .push(ProjectionOrFilter::Filter(new_predicate)); - self.process_node(input); - } - // Only reorder inner joins with non-empty join conditions. + } + match &**cur_node { + // The cur_node is now at the ending_node which MUST be a join node. LogicalPlan::Join(Join { left, right, @@ -530,11 +585,17 @@ impl JoinGraphBuilder { }) if *join_type == JoinType::Inner && !left_on.is_empty() => { for l in left_on { let name = l.name(); - if !self.final_name_map.contains_key(name) { + let final_name = if let Some(final_name) = self.final_name_map.get(name) { + final_name.name() + } else { self.final_name_map.insert(name.to_string(), col(name)); - } - self.join_conds_to_resolve - .push((name.to_string(), plan.clone(), false)); + name + }; + self.join_conds_to_resolve.push(( + final_name.to_string(), + cur_node.clone(), + false, + )); } self.process_node(left); let mut ready_left = vec![]; @@ -543,11 +604,17 @@ impl JoinGraphBuilder { } for r in right_on { let name = r.name(); - if !self.final_name_map.contains_key(name) { + let final_name = if let Some(final_name) = self.final_name_map.get(name) { + final_name.name() + } else { self.final_name_map.insert(name.to_string(), col(name)); - } - self.join_conds_to_resolve - .push((name.to_string(), plan.clone(), false)); + name + }; + self.join_conds_to_resolve.push(( + final_name.to_string(), + cur_node.clone(), + false, + )); } self.process_node(right); let mut ready_right = vec![]; @@ -558,64 +625,67 @@ impl JoinGraphBuilder { ready_left.into_iter().zip(ready_right.into_iter()) { if ldone && rdone { - let node1 = JoinNode::new( - lname.clone(), - lnode.clone(), - self.final_name_map.get(&lname).unwrap().name().to_string(), - ); - let node2 = JoinNode::new( - rname.clone(), - rnode.clone(), - self.final_name_map.get(&rname).unwrap().name().to_string(), - ); + let node1 = JoinNode::new(lname.clone(), lnode.clone()); + let node2 = JoinNode::new(rname.clone(), rnode.clone()); self.adj_list.add_bidirectional_edge(node1, node2); } else { panic!("Join conditions were unresolved"); } } } - // TODO(desmond): There are potentially more reorderable nodes. For example, we can move repartitions around. _ => { - // This is an unreorderable node. All unresolved columns coming out of this node should be marked as resolved. - // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this - // node as the root node. We can do this once we add the optimizer rule. 
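The `final_name_map` bookkeeping above re-keys entries as renames are walked from the join down toward the scan, so a column's original input name always resolves to its final output name. A standalone sketch of that re-keying over a rename chain such as `a -> a_alpha -> a_beta` (names borrowed from the multiple-renames test below; plain strings are used here instead of expression refs for brevity):

```rust
use std::collections::HashMap;

// Each projection layer maps output name -> input name (pass-through renames only).
// Walking from the topmost join toward the scan, re-key the final-name map so the current
// input name points at the original (final) output name, mirroring the logic in this hunk.
fn push_renames_down(final_name_map: &mut HashMap<String, String>, layer: &[(String, String)]) {
    for (output, input) in layer {
        let final_name = final_name_map
            .remove(output)
            .unwrap_or_else(|| output.clone());
        final_name_map.insert(input.clone(), final_name);
    }
}

fn main() {
    // The join key is referenced as `a_beta` at the top of the plan.
    let mut final_name_map = HashMap::from([("a_beta".to_string(), "a_beta".to_string())]);
    // Project(a_beta <- a_alpha), then Project(a_alpha <- a).
    push_renames_down(&mut final_name_map, &[("a_beta".into(), "a_alpha".into())]);
    push_renames_down(&mut final_name_map, &[("a_alpha".into(), "a".into())]);
    // The scan column `a` now resolves to its final output name `a_beta`.
    assert_eq!(final_name_map.get("a").map(String::as_str), Some("a_beta"));
}
```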
- let mut projections = vec![]; - let mut needs_projection = false; - let mut seen_names = HashSet::new(); - for (name, _, done) in &mut self.join_conds_to_resolve { - if schema.has_field(name) && !*done && !seen_names.contains(name) { - if let Some(final_name) = self.final_name_map.get(name) { - let final_name = final_name.name().to_string(); - if final_name != *name { - needs_projection = true; - projections.push(col(name.clone()).alias(final_name)); - } else { - projections.push(col(name.clone())); - } - } else { - projections.push(col(name.clone())); - } - seen_names.insert(name); - } - } - // Apply projections and return the new plan as the relation for the appropriate join conditions. - let projected_plan = if needs_projection { - let projected_plan = LogicalPlanBuilder::from(plan.clone()) - .select(projections) - .expect("Computed projections could not be applied to relation") - .build(); - Arc::new(Arc::unwrap_or_clone(projected_plan).with_materialized_stats()) + panic!("Expected a join node") + } + } + } + + /// `process_leaf_relation` is a helper function that processes an unreorderable node that sits below some + /// Join node(s). `plan` will become one of the relations involved in join ordering. + fn add_relation(&mut self, plan: &LogicalPlanRef) { + // All unresolved columns coming out of this node should be marked as resolved. + // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this + // node as the root node. We can do this once we add the optimizer rule. + let schema = plan.schema(); + + let mut projections = vec![]; + let names = schema.names(); + let mut seen_names = HashSet::new(); + let mut needs_projection = false; + for (input, final_name) in &self.final_name_map { + if names.contains(input) { + seen_names.insert(input); + let final_name = final_name.name().to_string(); + if final_name != *input { + projections.push(col(input.clone()).alias(final_name)); + needs_projection = true; } else { - plan.clone() - }; - for (name, node, done) in &mut self.join_conds_to_resolve { - if schema.has_field(name) && !*done { - *done = true; - *node = projected_plan.clone(); - } + projections.push(col(input.clone())); } } } + // Apply projections and return the new plan as the relation for the appropriate join conditions. + let projected_plan = if needs_projection { + // Add the non-join-key columns to the projection. + for name in &schema.names() { + if !seen_names.contains(name) { + projections.push(col(name.clone())); + } + } + let projected_plan = LogicalPlanBuilder::from(plan.clone()) + .select(projections) + .expect("Computed projections could not be applied to relation") + .build(); + Arc::new(Arc::unwrap_or_clone(projected_plan).with_materialized_stats()) + } else { + plan.clone() + }; + let projected_schema = projected_plan.schema(); + for (name, node, done) in &mut self.join_conds_to_resolve { + if projected_schema.has_field(name) && !*done { + *done = true; + *node = projected_plan.clone(); + } + } } } @@ -781,12 +851,12 @@ mod tests { #[test] fn test_create_join_graph_multiple_renames() { - // InnerJoin (a_beta = b) + // InnerJoin (a_beta = c) // / \ // Project Scan(c) // (a_beta <- a_alpha) // / - // InnerJoin (a = c) + // InnerJoin (a_alpha = b) // / \ // Project Scan(b) // (a_alpha <- a) @@ -912,10 +982,6 @@ mod tests { "Project(c) <-> Source(d)", "Source(a) <-> Source(d)" ])); - // Check for non-join projections at the end. 
- // `c_prime` gets renamed to `c` in the final projection - let double_proj = col("c").add(col("c")).alias("double"); - assert!(join_graph.contains_projections_and_filters(vec![&double_proj])); } #[test] @@ -1003,19 +1069,12 @@ mod tests { assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edges(vec![ "Source(a) <-> Source(b)", - "Project(c) <-> Source(d)", + "Filter(c) <-> Source(d)", "Source(a) <-> Source(d)", ])); // Check for non-join projections and filters at the end. - // `c_prime` gets renamed to `c` in the final projection - let double_proj = col("c").add(col("c")).alias("double"); - let filter_c_prime = col("c").gt(Arc::new(Expr::Literal(LiteralValue::Int64(0)))); - assert!(join_graph.contains_projections_and_filters(vec![ - &quad_proj, - &filter_c, - &double_proj, - &filter_c_prime, - ])); + // The join graph should only keep track of projections and filters that sit between joins. + assert!(join_graph.contains_projections_and_filters(vec![&quad_proj,])); } #[test] diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs index 9d7d26a8b5..67b659838c 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs @@ -80,7 +80,7 @@ mod tests { .iter() .map(|name| { let scan_node = create_scan_node(name, Some(100)); - JoinNode::new(name.to_string(), scan_node, name.to_string()) + JoinNode::new(name.to_string(), scan_node) }) .collect(); let graph = create_join_graph_with_edges(nodes.clone(), $edges); diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 5039cc9767..18e05a8218 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -119,35 +119,33 @@ impl UnnestScalarSubquery { let (decorrelated_subquery, subquery_on, input_on) = pull_up_correlated_cols(subquery_plan)?; - if subquery_on.is_empty() { - // uncorrelated scalar subquery - Ok(Arc::new(LogicalPlan::Join(Join::try_new( - curr_input, - decorrelated_subquery, - vec![], - vec![], - None, - JoinType::Inner, - None, - None, - None, - false, - )?))) + // use inner join when uncorrelated so that filter can be pushed into join and other optimizations + let join_type = if subquery_on.is_empty() { + JoinType::Inner } else { - // correlated scalar subquery - Ok(Arc::new(LogicalPlan::Join(Join::try_new( - curr_input, - decorrelated_subquery, - input_on, - subquery_on, - None, - JoinType::Left, - None, - None, - None, - false, - )?))) - } + JoinType::Left + }; + + let (decorrelated_subquery, subquery_on) = Join::rename_right_columns( + curr_input.clone(), + decorrelated_subquery, + input_on.clone(), + subquery_on, + join_type, + None, + None, + false, + )?; + + Ok(Arc::new(LogicalPlan::Join(Join::try_new( + curr_input, + decorrelated_subquery, + input_on, + subquery_on, + None, + join_type, + None, + )?))) })?; Ok(Transformed::yes((new_input, new_exprs))) @@ -335,9 +333,6 @@ impl OptimizerRule for UnnestPredicateSubquery { None, join_type, None, - None, - None, - false )?))) })?; diff --git a/src/daft-logical-plan/src/stats.rs b/src/daft-logical-plan/src/stats.rs index 22c3f85198..0faad9f184 100644 --- a/src/daft-logical-plan/src/stats.rs +++ 
b/src/daft-logical-plan/src/stats.rs @@ -46,11 +46,8 @@ impl Display for PlanStats { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{ Lower bound rows = {}, Upper bound rows = {}, Lower bound bytes = {}, Upper bound bytes = {} }}", - self.approx_stats.lower_bound_rows, - self.approx_stats.upper_bound_rows.map_or("None".to_string(), |v| v.to_string()), - self.approx_stats.lower_bound_bytes, - self.approx_stats.upper_bound_bytes.map_or("None".to_string(), |v| v.to_string()), + "{{ Approx num rows = {}, Approx num bytes = {} }}", + self.approx_stats.num_rows, self.approx_stats.size_bytes, ) } } @@ -103,27 +100,21 @@ impl Display for AlwaysSame { #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] pub struct ApproxStats { - pub lower_bound_rows: usize, - pub upper_bound_rows: Option, - pub lower_bound_bytes: usize, - pub upper_bound_bytes: Option, + pub num_rows: usize, + pub size_bytes: usize, } impl ApproxStats { pub fn empty() -> Self { Self { - lower_bound_rows: 0, - upper_bound_rows: None, - lower_bound_bytes: 0, - upper_bound_bytes: None, + num_rows: 0, + size_bytes: 0, } } pub fn apply usize>(&self, f: F) -> Self { Self { - lower_bound_rows: f(self.lower_bound_rows), - upper_bound_rows: self.upper_bound_rows.map(&f), - lower_bound_bytes: f(self.lower_bound_rows), - upper_bound_bytes: self.upper_bound_bytes.map(&f), + num_rows: f(self.num_rows), + size_bytes: f(self.size_bytes), } } } @@ -133,14 +124,8 @@ impl Add for &ApproxStats { type Output = ApproxStats; fn add(self, rhs: Self) -> Self::Output { ApproxStats { - lower_bound_rows: self.lower_bound_rows + rhs.lower_bound_rows, - upper_bound_rows: self - .upper_bound_rows - .and_then(|l_ub| rhs.upper_bound_rows.map(|v| v + l_ub)), - lower_bound_bytes: self.lower_bound_bytes + rhs.lower_bound_bytes, - upper_bound_bytes: self - .upper_bound_bytes - .and_then(|l_ub| rhs.upper_bound_bytes.map(|v| v + l_ub)), + num_rows: self.num_rows + rhs.num_rows, + size_bytes: self.size_bytes + rhs.size_bytes, } } } diff --git a/src/daft-micropartition/src/partitioning.rs b/src/daft-micropartition/src/partitioning.rs index 76667a8618..a2d8d19c00 100644 --- a/src/daft-micropartition/src/partitioning.rs +++ b/src/daft-micropartition/src/partitioning.rs @@ -25,6 +25,7 @@ impl Partition for MicroPartition { pub struct MicroPartitionSet { pub partitions: DashMap, } + impl From> for MicroPartitionSet { fn from(value: Vec) -> Self { let partitions = value diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 6ed01c7a7a..53a302dc3e 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -1,6 +1,7 @@ use std::sync::{Arc, Mutex}; use common_error::DaftResult; +use common_partitioning::Partition; use daft_core::{ join::JoinSide, prelude::*, @@ -25,9 +26,9 @@ use crate::{ }; #[pyclass(module = "daft.daft", frozen)] -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PyMicroPartition { - inner: Arc, + pub inner: Arc, } #[pymethods] diff --git a/src/daft-physical-plan/src/ops/filter.rs b/src/daft-physical-plan/src/ops/filter.rs index af8c35342c..7977726daf 100644 --- a/src/daft-physical-plan/src/ops/filter.rs +++ b/src/daft-physical-plan/src/ops/filter.rs @@ -9,11 +9,20 @@ pub struct Filter { pub input: PhysicalPlanRef, // The Boolean expression to filter on. 
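With the lower/upper bounds collapsed into single point estimates, combining stats becomes plain arithmetic. A small standalone mirror of the simplified struct (kept local so the example compiles on its own) showing how `Add` serves concatenation-like operators and `apply` serves uniform scaling; note that the old `apply` accidentally derived `lower_bound_bytes` from the row count, a slip the simplified form removes:

```rust
// Local mirror of the simplified daft_logical_plan::stats::ApproxStats, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ApproxStats {
    num_rows: usize,
    size_bytes: usize,
}

impl std::ops::Add for ApproxStats {
    type Output = ApproxStats;
    fn add(self, rhs: Self) -> Self::Output {
        ApproxStats {
            num_rows: self.num_rows + rhs.num_rows,
            size_bytes: self.size_bytes + rhs.size_bytes,
        }
    }
}

impl ApproxStats {
    fn apply<F: Fn(usize) -> usize>(&self, f: F) -> Self {
        ApproxStats {
            num_rows: f(self.num_rows),
            size_bytes: f(self.size_bytes),
        }
    }
}

fn main() {
    let left = ApproxStats { num_rows: 1_000, size_bytes: 10 << 20 };
    let right = ApproxStats { num_rows: 500, size_bytes: 5 << 20 };
    // e.g. Concat: both sides contribute rows and bytes.
    assert_eq!(left + right, ApproxStats { num_rows: 1_500, size_bytes: 15 << 20 });
    // e.g. Sample with a 10% fraction: scale both estimates uniformly.
    assert_eq!(left.apply(|v| v / 10), ApproxStats { num_rows: 100, size_bytes: 1 << 20 });
}
```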
pub predicate: ExprRef, + pub estimated_selectivity: f64, } impl Filter { - pub(crate) fn new(input: PhysicalPlanRef, predicate: ExprRef) -> Self { - Self { input, predicate } + pub(crate) fn new( + input: PhysicalPlanRef, + predicate: ExprRef, + estimated_selectivity: f64, + ) -> Self { + Self { + input, + predicate, + estimated_selectivity, + } } pub fn multiline_display(&self) -> Vec { diff --git a/src/daft-physical-plan/src/physical_planner/planner.rs b/src/daft-physical-plan/src/physical_planner/planner.rs index 1e5b191563..149cae0c29 100644 --- a/src/daft-physical-plan/src/physical_planner/planner.rs +++ b/src/daft-physical-plan/src/physical_planner/planner.rs @@ -119,7 +119,7 @@ impl TreeNodeRewriter for QueryStagePhysicalPlanTranslator { let left_stats = left.approximate_stats(); let right_stats = right.approximate_stats(); - if left_stats.lower_bound_bytes <= right_stats.lower_bound_bytes { + if left_stats.size_bytes <= right_stats.size_bytes { RunNext::Left } else { RunNext::Right diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 838426253f..c9c0c5f24c 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -10,8 +10,8 @@ use common_file_formats::FileFormat; use common_scan_info::{PhysicalScanInfo, ScanState, SPLIT_AND_MERGE_PASS}; use daft_core::{join::JoinSide, prelude::*}; use daft_dsl::{ - col, functions::agg::merge_mean, is_partition_compatible, AggExpr, ApproxPercentileParams, - Expr, ExprRef, SketchType, + col, estimated_selectivity, functions::agg::merge_mean, is_partition_compatible, AggExpr, + ApproxPercentileParams, Expr, ExprRef, SketchType, }; use daft_functions::{list::unique_count, numeric::sqrt}; use daft_logical_plan::{ @@ -116,9 +116,17 @@ pub(super) fn translate_single_logical_node( )?) .arced()) } - LogicalPlan::Filter(LogicalFilter { predicate, .. }) => { + LogicalPlan::Filter(LogicalFilter { + predicate, input, .. + }) => { let input_physical = physical_children.pop().expect("requires 1 input"); - Ok(PhysicalPlan::Filter(Filter::new(input_physical, predicate.clone())).arced()) + let estimated_selectivity = estimated_selectivity(predicate, &input.schema()); + Ok(PhysicalPlan::Filter(Filter::new( + input_physical, + predicate.clone(), + estimated_selectivity, + )) + .arced()) } LogicalPlan::Limit(LogicalLimit { limit, eager, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); @@ -472,17 +480,10 @@ pub(super) fn translate_single_logical_node( // For broadcast joins, ensure that the left side of the join is the smaller side. 
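The hunk that follows rewrites the broadcast-side selection in terms of the new point estimates. As a standalone sketch of that decision (the threshold value is illustrative; the real one comes from `cfg.broadcast_join_size_bytes_threshold`, and the join-type check on whether the smaller side may be broadcast is omitted here):

```rust
#[derive(Clone, Copy)]
struct SideStats {
    size_bytes: usize,
}

#[derive(Debug, PartialEq)]
enum BroadcastSide {
    Left,
    Right,
}

// Pick the smaller side by estimated bytes; broadcast it only if it is under the threshold
// and the larger side is not already partitioned on the join keys.
fn choose_broadcast_side(
    left: SideStats,
    right: SideStats,
    is_larger_partitioned: bool,
    threshold_bytes: usize,
) -> Option<BroadcastSide> {
    let (smaller_size_bytes, left_is_larger) = if right.size_bytes < left.size_bytes {
        (right.size_bytes, true)
    } else {
        (left.size_bytes, false)
    };
    if !is_larger_partitioned && smaller_size_bytes <= threshold_bytes {
        Some(if left_is_larger { BroadcastSide::Right } else { BroadcastSide::Left })
    } else {
        None
    }
}

fn main() {
    let left = SideStats { size_bytes: 4 << 30 };  // ~4 GiB
    let right = SideStats { size_bytes: 8 << 20 }; // ~8 MiB
    // With a 10 MiB threshold, the right side is small enough to broadcast.
    assert_eq!(
        choose_broadcast_side(left, right, false, 10 << 20),
        Some(BroadcastSide::Right)
    );
}
```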
let (smaller_size_bytes, left_is_larger) = - match (left_stats.upper_bound_bytes, right_stats.upper_bound_bytes) { - (Some(left_size_bytes), Some(right_size_bytes)) => { - if right_size_bytes < left_size_bytes { - (Some(right_size_bytes), true) - } else { - (Some(left_size_bytes), false) - } - } - (Some(left_size_bytes), None) => (Some(left_size_bytes), false), - (None, Some(right_size_bytes)) => (Some(right_size_bytes), true), - (None, None) => (None, false), + if right_stats.size_bytes < left_stats.size_bytes { + (right_stats.size_bytes, true) + } else { + (left_stats.size_bytes, false) }; let is_larger_partitioned = if left_is_larger { is_left_hash_partitioned || is_left_sort_partitioned @@ -518,7 +519,6 @@ pub(super) fn translate_single_logical_node( // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold AND we are not broadcasting the side we are outer joining by, use broadcast join. if !is_larger_partitioned - && let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold && smaller_side_is_broadcastable { diff --git a/src/daft-physical-plan/src/plan.rs b/src/daft-physical-plan/src/plan.rs index f3fb6fc45e..876b8be111 100644 --- a/src/daft-physical-plan/src/plan.rs +++ b/src/daft-physical-plan/src/plan.rs @@ -187,69 +187,49 @@ impl PhysicalPlan { pub fn approximate_stats(&self) -> ApproxStats { match self { Self::InMemoryScan(InMemoryScan { in_memory_info, .. }) => ApproxStats { - lower_bound_rows: in_memory_info.num_rows, - upper_bound_rows: Some(in_memory_info.num_rows), - lower_bound_bytes: in_memory_info.size_bytes, - upper_bound_bytes: Some(in_memory_info.size_bytes), + num_rows: in_memory_info.num_rows, + size_bytes: in_memory_info.size_bytes, }, Self::TabularScan(TabularScan { scan_tasks, .. }) => { - let mut stats = ApproxStats::empty(); + let mut approx_stats = ApproxStats::empty(); for st in scan_tasks.iter() { - stats.lower_bound_rows += st.num_rows().unwrap_or(0); - let in_memory_size = st.estimate_in_memory_size_bytes(None); - stats.lower_bound_bytes += in_memory_size.unwrap_or(0); - if let Some(st_ub) = st.upper_bound_rows() { - if let Some(ub) = stats.upper_bound_rows { - stats.upper_bound_rows = Some(ub + st_ub); - } else { - stats.upper_bound_rows = st.upper_bound_rows(); - } - } - if let Some(st_ub) = in_memory_size { - if let Some(ub) = stats.upper_bound_bytes { - stats.upper_bound_bytes = Some(ub + st_ub); - } else { - stats.upper_bound_bytes = in_memory_size; - } - } + approx_stats.num_rows += st + .num_rows() + .unwrap_or_else(|| st.approx_num_rows(None).unwrap_or(0.0) as usize); + approx_stats.size_bytes += st.estimate_in_memory_size_bytes(None).unwrap_or(0); } - stats + approx_stats } Self::EmptyScan(..) => ApproxStats { - lower_bound_rows: 0, - upper_bound_rows: Some(0), - lower_bound_bytes: 0, - upper_bound_bytes: Some(0), + num_rows: 0, + size_bytes: 0, }, // Assume no row/column pruning in cardinality-affecting operations. // TODO(Clark): Estimate row/column pruning to get a better size approximation. - Self::Filter(Filter { input, .. }) => { + Self::Filter(Filter { + input, + estimated_selectivity, + .. 
+ }) => { let input_stats = input.approximate_stats(); ApproxStats { - lower_bound_rows: 0, - upper_bound_rows: input_stats.upper_bound_rows, - lower_bound_bytes: 0, - upper_bound_bytes: input_stats.upper_bound_bytes, + num_rows: (input_stats.num_rows as f64 * estimated_selectivity).ceil() as usize, + size_bytes: (input_stats.size_bytes as f64 * estimated_selectivity).ceil() + as usize, } } Self::Limit(Limit { input, limit, .. }) => { - let limit = *limit as usize; let input_stats = input.approximate_stats(); - let est_bytes_per_row_lower = - input_stats.lower_bound_bytes / (input_stats.lower_bound_rows.max(1)); - let est_bytes_per_row_upper = input_stats - .upper_bound_bytes - .and_then(|bytes| input_stats.upper_bound_rows.map(|rows| bytes / rows.max(1))); - let new_lower_rows = input_stats.lower_bound_rows.min(limit); - let new_upper_rows = input_stats - .upper_bound_rows - .map(|ub| ub.min(limit)) - .unwrap_or(limit); + let limit = *limit as usize; ApproxStats { - lower_bound_rows: new_lower_rows, - upper_bound_rows: Some(new_upper_rows), - lower_bound_bytes: new_lower_rows * est_bytes_per_row_lower, - upper_bound_bytes: est_bytes_per_row_upper.map(|x| x * new_upper_rows), + num_rows: limit.min(input_stats.num_rows), + size_bytes: if input_stats.num_rows > limit { + let est_bytes_per_row = + input_stats.size_bytes / input_stats.num_rows.max(1); + limit * est_bytes_per_row + } else { + input_stats.size_bytes + }, } } Self::Project(Project { input, .. }) @@ -265,11 +245,10 @@ impl PhysicalPlan { .apply(|v| ((v as f64) * fraction) as usize), Self::Explode(Explode { input, .. }) => { let input_stats = input.approximate_stats(); + let est_num_exploded_rows = input_stats.num_rows * 4; ApproxStats { - lower_bound_rows: input_stats.lower_bound_rows, - upper_bound_rows: None, - lower_bound_bytes: input_stats.lower_bound_bytes, - upper_bound_bytes: None, + num_rows: est_num_exploded_rows, + size_bytes: input_stats.size_bytes, } } // Propagate child approximation for operations that don't affect cardinality. @@ -294,53 +273,35 @@ impl PhysicalPlan { let right_stats = right.approximate_stats(); ApproxStats { - lower_bound_rows: 0, - upper_bound_rows: left_stats - .upper_bound_rows - .and_then(|l| right_stats.upper_bound_rows.map(|r| l.max(r))), - lower_bound_bytes: 0, - upper_bound_bytes: left_stats - .upper_bound_bytes - .and_then(|l| right_stats.upper_bound_bytes.map(|r| l.max(r))), + num_rows: left_stats.num_rows.max(right_stats.num_rows), + size_bytes: left_stats.size_bytes.max(right_stats.size_bytes), } } // TODO(Clark): Approximate post-aggregation sizes via grouping estimates + aggregation type. Self::Aggregate(Aggregate { input, groupby, .. }) => { let input_stats = input.approximate_stats(); // TODO we should use schema inference here - let est_bytes_per_row_lower = - input_stats.lower_bound_bytes / (input_stats.lower_bound_rows.max(1)); - let est_bytes_per_row_upper = input_stats - .upper_bound_bytes - .and_then(|bytes| input_stats.upper_bound_rows.map(|rows| bytes / rows.max(1))); + let est_bytes_per_row = input_stats.size_bytes / (input_stats.num_rows.max(1)); if groupby.is_empty() { ApproxStats { - lower_bound_rows: input_stats.lower_bound_rows.min(1), - upper_bound_rows: Some(1), - lower_bound_bytes: input_stats.lower_bound_bytes.min(1) - * est_bytes_per_row_lower, - upper_bound_bytes: est_bytes_per_row_upper, + num_rows: 1, + size_bytes: est_bytes_per_row, } } else { - // we should use the new schema here + // Assume high cardinality for group by columns, and 80% of rows are unique. 
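Putting the estimators in this hunk together, here is a worked numeric sketch of how Filter, Limit, and a grouped Aggregate scale the point estimates (the numbers are illustrative; in the planner the selectivity comes from `estimated_selectivity` on the predicate, and the grouped estimate continues just below with the 80%-unique assumption):

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ApproxStats {
    num_rows: usize,
    size_bytes: usize,
}

fn filter(input: ApproxStats, estimated_selectivity: f64) -> ApproxStats {
    ApproxStats {
        num_rows: (input.num_rows as f64 * estimated_selectivity).ceil() as usize,
        size_bytes: (input.size_bytes as f64 * estimated_selectivity).ceil() as usize,
    }
}

fn limit(input: ApproxStats, limit: usize) -> ApproxStats {
    ApproxStats {
        num_rows: limit.min(input.num_rows),
        size_bytes: if input.num_rows > limit {
            let est_bytes_per_row = input.size_bytes / input.num_rows.max(1);
            limit * est_bytes_per_row
        } else {
            input.size_bytes
        },
    }
}

fn grouped_aggregate(input: ApproxStats) -> ApproxStats {
    // Assume high cardinality: ~80% of input rows survive as distinct groups.
    // (With no group-by keys, the estimate collapses to a single row instead.)
    let est_num_groups = input.num_rows * 4 / 5;
    let est_bytes_per_row = input.size_bytes / input.num_rows.max(1);
    ApproxStats { num_rows: est_num_groups, size_bytes: est_bytes_per_row * est_num_groups }
}

fn main() {
    let scan = ApproxStats { num_rows: 1_000_000, size_bytes: 100_000_000 }; // ~100 bytes/row
    assert_eq!(filter(scan, 0.25), ApproxStats { num_rows: 250_000, size_bytes: 25_000_000 });
    assert_eq!(limit(scan, 1_000), ApproxStats { num_rows: 1_000, size_bytes: 100_000 });
    assert_eq!(grouped_aggregate(scan), ApproxStats { num_rows: 800_000, size_bytes: 80_000_000 });
}
```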
+ let est_num_groups = input_stats.num_rows * 4 / 5; ApproxStats { - lower_bound_rows: input_stats.lower_bound_rows.min(1), - upper_bound_rows: input_stats.upper_bound_rows, - lower_bound_bytes: input_stats.lower_bound_bytes.min(1) - * est_bytes_per_row_lower, - upper_bound_bytes: input_stats.upper_bound_bytes, + num_rows: est_num_groups, + size_bytes: est_bytes_per_row * est_num_groups, } } } Self::Unpivot(Unpivot { input, values, .. }) => { let input_stats = input.approximate_stats(); let num_values = values.len(); - // the number of bytes should be the name but nows should be multiplied by num_values ApproxStats { - lower_bound_rows: input_stats.lower_bound_rows * num_values, - upper_bound_rows: input_stats.upper_bound_rows.map(|v| v * num_values), - lower_bound_bytes: input_stats.lower_bound_bytes, - upper_bound_bytes: input_stats.upper_bound_bytes, + num_rows: input_stats.num_rows * num_values, + size_bytes: input_stats.size_bytes, } } // Post-write DataFrame will contain paths to files that were written. @@ -408,7 +369,7 @@ impl PhysicalPlan { ).unwrap()), Self::ActorPoolProject(ActorPoolProject {projection, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone()).unwrap()), - Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(input.clone(), predicate.clone())), + Self::Filter(Filter { predicate, estimated_selectivity,.. }) => Self::Filter(Filter::new(input.clone(), predicate.clone(), *estimated_selectivity)), Self::Limit(Limit { limit, eager, num_partitions, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager, *num_partitions)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Unpivot(Unpivot { ids, values, variable_name, value_name, .. }) => Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name, value_name)), diff --git a/src/daft-ray-execution/Cargo.toml b/src/daft-ray-execution/Cargo.toml new file mode 100644 index 0000000000..98b689c865 --- /dev/null +++ b/src/daft-ray-execution/Cargo.toml @@ -0,0 +1,22 @@ +[dependencies] +common-error = {workspace = true} +daft-logical-plan = {workspace = true} +daft-micropartition = {workspace = true} +pyo3 = {workspace = true, optional = true} + +[features] +default = ["python"] +python = [ + "dep:pyo3", + "common-error/python", + "daft-logical-plan/python", + "daft-micropartition/python" +] + +[lints] +workspace = true + +[package] +name = "daft-ray-execution" +edition.workspace = true +version.workspace = true diff --git a/src/daft-ray-execution/src/lib.rs b/src/daft-ray-execution/src/lib.rs new file mode 100644 index 0000000000..2180a54e45 --- /dev/null +++ b/src/daft-ray-execution/src/lib.rs @@ -0,0 +1,74 @@ +//! 
Wrapper around the python RayRunner class +#[cfg(feature = "python")] +use common_error::{DaftError, DaftResult}; +#[cfg(feature = "python")] +use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +#[cfg(feature = "python")] +use daft_micropartition::{python::PyMicroPartition, MicroPartitionRef}; +#[cfg(feature = "python")] +use pyo3::{ + intern, + prelude::*, + types::{PyDict, PyIterator}, +}; + +#[cfg(feature = "python")] +pub struct RayEngine { + ray_runner: PyObject, +} + +#[cfg(feature = "python")] +impl RayEngine { + pub fn try_new( + address: Option, + max_task_backlog: Option, + force_client_mode: Option, + ) -> DaftResult { + Python::with_gil(|py| { + let ray_runner_module = py.import(intern!(py, "daft.runners.ray_runner"))?; + let ray_runner = ray_runner_module.getattr(intern!(py, "RayRunner"))?; + let kwargs = PyDict::new(py); + kwargs.set_item(intern!(py, "address"), address)?; + kwargs.set_item(intern!(py, "max_task_backlog"), max_task_backlog)?; + kwargs.set_item(intern!(py, "force_client_mode"), force_client_mode)?; + + let instance = ray_runner.call((), Some(&kwargs))?; + let instance = instance.unbind(); + + Ok(Self { + ray_runner: instance, + }) + }) + } + + pub fn run_iter_impl( + &self, + py: Python<'_>, + lp: LogicalPlanBuilder, + results_buffer_size: Option, + ) -> DaftResult>> { + let py_lp = PyLogicalPlanBuilder::new(lp); + let builder = py.import(intern!(py, "daft.logical.builder"))?; + let builder = builder.getattr(intern!(py, "LogicalPlanBuilder"))?; + let builder = builder.call((py_lp,), None)?; + let result = self.ray_runner.call_method( + py, + intern!(py, "run_iter_tables"), + (builder, results_buffer_size), + None, + )?; + + let result = result.bind(py); + let iter = PyIterator::from_object(result)?; + + let iter = iter.map(|item| { + let item = item?; + let partition = item.getattr(intern!(py, "_micropartition"))?; + let partition = partition.extract::()?; + let partition = partition.inner; + Ok::<_, DaftError>(partition) + }); + + Ok(iter.collect()) + } +} diff --git a/src/daft-scan/src/builder.rs b/src/daft-scan/src/builder.rs index 68c67cbd20..30b6da294a 100644 --- a/src/daft-scan/src/builder.rs +++ b/src/daft-scan/src/builder.rs @@ -1,7 +1,9 @@ use std::{collections::BTreeMap, sync::Arc}; use common_error::DaftResult; -use common_file_formats::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig}; +use common_file_formats::{ + CsvSourceConfig, FileFormatConfig, JsonSourceConfig, ParquetSourceConfig, +}; use common_io_config::IOConfig; use common_scan_info::ScanOperatorRef; use daft_core::prelude::TimeUnit; @@ -263,6 +265,101 @@ impl CsvScanBuilder { } } +/// An argument builder for a JSON scan operator. 
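A usage sketch for the `JsonScanBuilder` defined in the hunk that follows. The glob path and chunk size here are made up, and it is assumed that a plain string path is accepted through the same glob-path conversion that `CsvScanBuilder` uses; `finish` is async and yields a `LogicalPlanBuilder`, as in the code below:

```rust
use common_error::DaftResult;
use daft_logical_plan::LogicalPlanBuilder;
use daft_scan::builder::JsonScanBuilder;

// Hypothetical helper: plan a scan over newline-delimited JSON files with schema inference.
async fn plan_json_scan() -> DaftResult<LogicalPlanBuilder> {
    JsonScanBuilder::new("s3://my-bucket/logs/*.jsonl") // assumed: &str converts into glob paths
        .infer_schema(true)
        .chunk_size(64 * 1024)
        .finish()
        .await
}
```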
+pub struct JsonScanBuilder { + pub glob_paths: Vec, + pub infer_schema: bool, + pub io_config: Option, + pub schema: Option, + pub file_path_column: Option, + pub hive_partitioning: bool, + pub schema_hints: Option, + pub buffer_size: Option, + pub chunk_size: Option, +} + +impl JsonScanBuilder { + pub fn new(glob_paths: T) -> Self { + let glob_paths = glob_paths.into_glob_path(); + Self::new_impl(glob_paths) + } + + fn new_impl(glob_paths: Vec) -> Self { + Self { + glob_paths, + infer_schema: true, + schema: None, + io_config: None, + file_path_column: None, + hive_partitioning: false, + buffer_size: None, + chunk_size: None, + schema_hints: None, + } + } + + pub fn infer_schema(mut self, infer_schema: bool) -> Self { + self.infer_schema = infer_schema; + self + } + + pub fn io_config(mut self, io_config: IOConfig) -> Self { + self.io_config = Some(io_config); + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn file_path_column(mut self, file_path_column: String) -> Self { + self.file_path_column = Some(file_path_column); + self + } + + pub fn hive_partitioning(mut self, hive_partitioning: bool) -> Self { + self.hive_partitioning = hive_partitioning; + self + } + + pub fn schema_hints(mut self, schema_hints: SchemaRef) -> Self { + self.schema_hints = Some(schema_hints); + self + } + + pub fn buffer_size(mut self, buffer_size: usize) -> Self { + self.buffer_size = Some(buffer_size); + self + } + + pub fn chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = Some(chunk_size); + self + } + + /// Creates a logical table scan backed by a JSON scan operator. + pub async fn finish(self) -> DaftResult { + let cfg = JsonSourceConfig { + buffer_size: self.buffer_size, + chunk_size: self.chunk_size, + }; + let operator = Arc::new( + GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Json(cfg)), + Arc::new(StorageConfig::new_internal(false, self.io_config)), + self.infer_schema, + self.schema, + self.file_path_column, + self.hive_partitioning, + ) + .await?, + ); + LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) + } +} + #[cfg(feature = "python")] pub fn delta_scan>( glob_path: T, @@ -302,3 +399,46 @@ pub fn delta_scan( ) -> DaftResult { panic!("Delta Lake scan requires the 'python' feature to be enabled.") } + +/// Creates a logical scan operator from a Python IcebergScanOperator. 
+/// ex: +/// ```python +/// iceberg_table = pyiceberg.table.StaticTable.from_metadata(metadata_location) +/// iceberg_scan = daft.iceberg.iceberg_scan.IcebergScanOperator(iceberg_table, snapshot_id, storage_config) +/// ``` +#[cfg(feature = "python")] +pub fn iceberg_scan>( + metadata_location: T, + snapshot_id: Option, + io_config: Option, +) -> DaftResult { + use pyo3::IntoPyObjectExt; + let storage_config: StorageConfig = io_config.unwrap_or_default().into(); + let scan_operator = Python::with_gil(|py| -> DaftResult { + // iceberg_table = pyiceberg.table.StaticTable.from_metadata(metadata_location) + let iceberg_table_module = PyModule::import(py, "pyiceberg.table")?; + let iceberg_static_table = iceberg_table_module.getattr("StaticTable")?; + let iceberg_table = + iceberg_static_table.call_method1("from_metadata", (metadata_location.as_ref(),))?; + // iceberg_scan = daft.iceberg.iceberg_scan.IcebergScanOperator(iceberg_table, snapshot_id, storage_config) + let iceberg_scan_module = PyModule::import(py, "daft.iceberg.iceberg_scan")?; + let iceberg_scan_class = iceberg_scan_module.getattr("IcebergScanOperator")?; + let iceberg_scan = iceberg_scan_class + .call1((iceberg_table, snapshot_id, storage_config))? + .into_py_any(py)?; + Ok(ScanOperatorHandle::from_python_scan_operator( + iceberg_scan, + py, + )?) + })?; + LogicalPlanBuilder::table_scan(scan_operator.into(), None) +} + +#[cfg(not(feature = "python"))] +pub fn iceberg_scan>( + uri: T, + snapshot_id: Option, + io_config: Option, +) -> DaftResult { + panic!("Iceberg scan requires the 'python' feature to be enabled.") +} diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 2f984dc213..2e3b52a692 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -686,7 +686,9 @@ impl ScanTask { if self.pushdowns.filters.is_some() { // HACK: This might not be a good idea? We could also just return None here // Assume that filters filter out about 80% of the data - approx_total_num_rows_before_pushdowns / 5. + let estimated_selectivity = + self.pushdowns.estimated_selectivity(self.schema.as_ref()); + approx_total_num_rows_before_pushdowns * estimated_selectivity } else if let Some(limit) = self.pushdowns.limit { (limit as f64).min(approx_total_num_rows_before_pushdowns) } else { diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs index 270964b705..b4f8a2cceb 100644 --- a/src/daft-scan/src/storage_config.rs +++ b/src/daft-scan/src/storage_config.rs @@ -65,6 +65,12 @@ impl Default for StorageConfig { } } +impl From for StorageConfig { + fn from(io_config: IOConfig) -> Self { + Self::new_internal(true, Some(io_config)) + } +} + #[cfg(feature = "python")] #[pymethods] impl StorageConfig { diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 56662dac31..6e57f9c5b4 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -268,7 +268,6 @@ fn physical_plan_to_partition_tasks( ) -> PyResult { use daft_dsl::Expr; use daft_physical_plan::ops::{CrossJoin, ShuffleExchange, ShuffleExchangeStrategy}; - match physical_plan { PhysicalPlan::InMemoryScan(InMemoryScan { in_memory_info: InMemoryInfo { cache_key, .. }, @@ -355,7 +354,9 @@ fn physical_plan_to_partition_tasks( Ok(py_iter.into()) } - PhysicalPlan::Filter(Filter { input, predicate }) => { + PhysicalPlan::Filter(Filter { + input, predicate, .. 
+ }) => { let upstream_iter = physical_plan_to_partition_tasks(input, py, psets, actor_pool_manager)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index a402235011..42b7163666 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -4,6 +4,7 @@ common-error = {path = "../common/error"} common-io-config = {path = "../common/io-config", default-features = false} common-runtime = {workspace = true} daft-algebra = {path = "../daft-algebra"} +daft-catalog = {path = "../daft-catalog"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} @@ -20,7 +21,14 @@ snafu.workspace = true rstest = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-functions/python", "daft-functions-json/python", "daft-scan/python"] +python = [ + "dep:pyo3", + "common-error/python", + "daft-functions/python", + "daft-functions-json/python", + "daft-scan/python", + "daft-catalog/python" +] [lints] workspace = true diff --git a/src/daft-sql/src/catalog.rs b/src/daft-sql/src/catalog.rs deleted file mode 100644 index 0b5634da36..0000000000 --- a/src/daft-sql/src/catalog.rs +++ /dev/null @@ -1,43 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use daft_logical_plan::{LogicalPlan, LogicalPlanRef}; - -/// A simple map of table names to logical plans -#[derive(Debug, Clone)] -pub struct SQLCatalog { - tables: HashMap>, -} - -impl SQLCatalog { - /// Create an empty catalog - #[must_use] - pub fn new() -> Self { - Self { - tables: HashMap::new(), - } - } - - /// Register a table with the catalog - pub fn register_table(&mut self, name: &str, plan: LogicalPlanRef) { - self.tables.insert(name.to_string(), plan); - } - - /// Get a table from the catalog - #[must_use] - pub fn get_table(&self, name: &str) -> Option { - self.tables.get(name).cloned() - } - - /// Copy from another catalog, using tables from other in case of conflict - pub fn copy_from(&mut self, other: &Self) { - for (name, plan) in &other.tables { - self.tables.insert(name.clone(), plan.clone()); - } - } -} - -impl Default for SQLCatalog { - fn default() -> Self { - Self::new() - } -} diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index d75f090072..79e92ea2a1 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -14,7 +14,7 @@ use crate::{ coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs, - SQLModuleTemporal, SQLModuleUtf8, + SQLModuleTemporal, SQLModuleUri, SQLModuleUtf8, }, planner::SQLPlanner, unsupported_sql_err, @@ -36,6 +36,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::(); functions.register::(); functions.register::(); functions.add_fn("coalesce", SQLCoalesce {}); @@ -91,6 +92,7 @@ pub trait SQLFunction: Send + Sync { .collect::>>() } + // nit cleanup: argument consistency with SQLTableFunction fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult; /// Produce the docstrings for this SQL function, parametrized by an alias which is the function name to invoke this in SQL @@ -375,3 +377,31 @@ impl<'a> SQLPlanner<'a> { } } } + +/// A namespace for function argument 
parsing helpers. +pub(crate) mod args { + use common_io_config::IOConfig; + + use super::SQLFunctionArguments; + use crate::{error::PlannerError, modules::config::expr_to_iocfg, unsupported_sql_err}; + + /// Parses on_error => Literal['raise', 'null'] = 'raise' or err. + pub(crate) fn parse_on_error(args: &SQLFunctionArguments) -> Result { + match args.try_get_named::("on_error")?.as_deref() { + None => Ok(true), + Some("raise") => Ok(true), + Some("null") => Ok(false), + Some(other) => { + unsupported_sql_err!("Expected on_error to be 'raise' or 'null', found '{other}'") + } + } + } + + /// Parses io_config which is used in several SQL functions. + pub(crate) fn parse_io_config(args: &SQLFunctionArguments) -> Result { + args.get_named("io_config") + .map(expr_to_iocfg) + .transpose() + .map(|op| op.unwrap_or_default()) + } +} diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 211973f0d3..82c4cc6f93 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -1,6 +1,5 @@ #![feature(let_chains)] -pub mod catalog; pub mod error; pub mod functions; mod modules; @@ -28,7 +27,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { mod tests { use std::sync::Arc; - use catalog::SQLCatalog; + use daft_catalog::DaftCatalog; use daft_core::prelude::*; use daft_dsl::{col, lit, Expr, OuterReferenceColumn, Subquery}; use daft_logical_plan::{ @@ -113,7 +112,7 @@ mod tests { #[fixture] fn planner() -> SQLPlanner<'static> { - let mut catalog = SQLCatalog::new(); + let mut catalog = DaftCatalog::default(); catalog.register_table("tbl1", tbl_1()); catalog.register_table("tbl2", tbl_2()); diff --git a/src/daft-sql/src/modules/image/decode.rs b/src/daft-sql/src/modules/image/decode.rs index a896c67a05..92f2954ca0 100644 --- a/src/daft-sql/src/modules/image/decode.rs +++ b/src/daft-sql/src/modules/image/decode.rs @@ -4,7 +4,7 @@ use sqlparser::ast::FunctionArg; use crate::{ error::{PlannerError, SQLPlannerResult}, - functions::{SQLFunction, SQLFunctionArguments}, + functions::{self, SQLFunction, SQLFunctionArguments}, unsupported_sql_err, }; @@ -21,20 +21,7 @@ impl TryFrom for ImageDecode { _ => unsupported_sql_err!("Expected mode to be a string"), }) .transpose()?; - - let raise_on_error = args - .get_named("on_error") - .map(|arg| match arg.as_ref() { - Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { - "raise" => Ok(true), - "null" => Ok(false), - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), - }, - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), - }) - .transpose()? - .unwrap_or(true); - + let raise_on_error = functions::args::parse_on_error(&args)?; Ok(Self { mode, raise_on_error, diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 30195dc52f..d46a25f55e 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -15,6 +15,7 @@ pub mod python; pub mod sketch; pub mod structs; pub mod temporal; +pub mod uri; pub mod utf8; pub use aggs::SQLModuleAggs; @@ -30,6 +31,7 @@ pub use python::SQLModulePython; pub use sketch::SQLModuleSketch; pub use structs::SQLModuleStructs; pub use temporal::SQLModuleTemporal; +pub use uri::SQLModuleUri; pub use utf8::SQLModuleUtf8; /// A [SQLModule] is a collection of SQL functions that can be registered with a [SQLFunctions] instance. 
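The `on_error` convention these helpers encode ('raise' fails the query, 'null' produces nulls, default 'raise') can be shown in isolation. A minimal stand-in for `args::parse_on_error`, working over a plain `Option<&str>` instead of `SQLFunctionArguments`:

```rust
// Mirrors the mapping in args::parse_on_error: the returned bool is "raise_error_on_failure".
fn parse_on_error(on_error: Option<&str>) -> Result<bool, String> {
    match on_error {
        None | Some("raise") => Ok(true),
        Some("null") => Ok(false),
        Some(other) => Err(format!(
            "Expected on_error to be 'raise' or 'null', found '{other}'"
        )),
    }
}

fn main() {
    assert_eq!(parse_on_error(None), Ok(true));
    assert_eq!(parse_on_error(Some("null")), Ok(false));
    assert!(parse_on_error(Some("ignore")).is_err());
}
```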
diff --git a/src/daft-sql/src/modules/uri/mod.rs b/src/daft-sql/src/modules/uri/mod.rs new file mode 100644 index 0000000000..5c748c9ece --- /dev/null +++ b/src/daft-sql/src/modules/uri/mod.rs @@ -0,0 +1,14 @@ +use super::SQLModule; +use crate::functions::SQLFunctions; + +mod url_download; +mod url_upload; + +pub struct SQLModuleUri; + +impl SQLModule for SQLModuleUri { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("url_download", url_download::SqlUrlDownload); + parent.add_fn("url_upload", url_upload::SqlUrlUpload); + } +} diff --git a/src/daft-sql/src/modules/uri/url_download.rs b/src/daft-sql/src/modules/uri/url_download.rs new file mode 100644 index 0000000000..9b7d97fa5b --- /dev/null +++ b/src/daft-sql/src/modules/uri/url_download.rs @@ -0,0 +1,58 @@ +use daft_dsl::ExprRef; +use daft_functions::uri::{self, download::UrlDownloadArgs}; +use sqlparser::ast::FunctionArg; + +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{self, SQLFunction, SQLFunctionArguments}, + unsupported_sql_err, SQLPlanner, +}; + +/// The Daft-SQL `url_download` definition. +pub struct SqlUrlDownload; + +impl TryFrom for UrlDownloadArgs { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let max_connections: usize = args.try_get_named("max_connections")?.unwrap_or(32); + let raise_error_on_failure = functions::args::parse_on_error(&args)?; + let io_config = functions::args::parse_io_config(&args)?; + Ok(Self { + max_connections, + raise_error_on_failure, + multi_thread: true, // TODO always true + io_config: io_config.into(), + }) + } +} + +impl SQLFunction for SqlUrlDownload { + fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(uri::download(input, None)) + } + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args = planner.plan_function_args( + args, + &["max_connections", "on_error", "io_config"], + 0, + )?; + Ok(uri::download(input, Some(args))) + } + _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), + } + } + + fn docstrings(&self, _alias: &str) -> String { + "Treats each string as a URL, and downloads the bytes contents as a bytes column." + .to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "max_connections", "on_error", "io_config"] + } +} diff --git a/src/daft-sql/src/modules/uri/url_upload.rs b/src/daft-sql/src/modules/uri/url_upload.rs new file mode 100644 index 0000000000..a7ac62f7c3 --- /dev/null +++ b/src/daft-sql/src/modules/uri/url_upload.rs @@ -0,0 +1,73 @@ +use daft_dsl::{Expr, ExprRef, LiteralValue}; +use daft_functions::uri::{self, upload::UrlUploadArgs}; +use sqlparser::ast::FunctionArg; + +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{self, SQLFunction, SQLFunctionArguments}, + unsupported_sql_err, SQLPlanner, +}; + +/// The Daft-SQL `url_upload` definition. 
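For reference, the expression that the SQL `url_download` call lowers to can also be built directly with the DSL. A sketch only: the `UrlDownloadArgs::new` argument order (max_connections, raise_error_on_failure, multi_thread, io_config) is inferred from the test hunk earlier in this diff, so treat it as an assumption rather than a documented signature:

```rust
use daft_dsl::{col, ExprRef};
use daft_functions::uri::{self, download::UrlDownloadArgs};

// Roughly what `SELECT url_download(urls, max_connections => 8, on_error => 'null')` plans to;
// `on_error => 'null'` becomes raise_error_on_failure = false via args::parse_on_error above.
fn download_expr() -> ExprRef {
    uri::download(
        col("urls"),
        Some(UrlDownloadArgs::new(8, false, true, None)), // argument order assumed, see lead-in
    )
}
```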
+pub struct SqlUrlUpload; + +impl TryFrom for UrlUploadArgs { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let max_connections: usize = args.try_get_named("max_connections")?.unwrap_or(32); + let raise_error_on_failure = functions::args::parse_on_error(&args)?; + let io_config = functions::args::parse_io_config(&args)?; + Ok(Self { + max_connections, + raise_error_on_failure, + multi_thread: true, // TODO always true + is_single_folder: true, + io_config: io_config.into(), + }) + } +} + +impl SQLFunction for SqlUrlUpload { + fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { + match inputs { + [input, location] => { + let input = planner.plan_function_arg(input)?; + let location = planner.plan_function_arg(location)?; + Ok(uri::upload(input, location, None)) + } + [input, location, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let location = planner.plan_function_arg(location)?; + let mut args: UrlUploadArgs = planner.plan_function_args( + args, + &["max_connections", "on_error", "io_config"], + 0, + )?; + // TODO consider moving the calculation of "is_single_folder" deeper so that both the SQL and Python sides don't compute this independently. + // is_single_folder = true iff the location is a string. + // in python, this is isinstance(location, str) + // We perform the check here (with mut args) because TryFrom does not have access to `location`. + args.is_single_folder = + matches!(location.as_ref(), Expr::Literal(LiteralValue::Utf8(_))); + Ok(uri::upload(input, location, Some(args))) + } + _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"), + } + } + + fn docstrings(&self, _alias: &str) -> String { + "Uploads a column of binary data to the provided location(s) (also supports S3, local etc)." 
+ .to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[ + "input", + "location", + "max_connections", + "on_error", + "io_config", + ] + } +} diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index ce2ef703a3..87cdece093 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1,12 +1,14 @@ use std::{ cell::{Ref, RefCell, RefMut}, collections::{HashMap, HashSet}, + path::Path, rc::Rc, sync::Arc, }; use common_error::{DaftError, DaftResult}; use daft_algebra::boolean::combine_conjunction; +use daft_catalog::DaftCatalog; use daft_core::prelude::*; use daft_dsl::{ col, @@ -21,10 +23,11 @@ use daft_functions::{ use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct, - ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, - Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, - Value, WildcardAdditionalOptions, With, + self, ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct, + ExactNumberInfo, ExcludeSelectItem, FunctionArg, FunctionArgExpr, GroupByExpr, Ident, + ObjectName, Query, SelectItem, SetExpr, Statement, StructField, Subscript, TableAlias, + TableFunctionArgs, TableWithJoins, TimezoneInfo, UnaryOperator, Value, + WildcardAdditionalOptions, With, }, dialect::GenericDialect, parser::{Parser, ParserOptions}, @@ -32,8 +35,7 @@ use sqlparser::{ }; use crate::{ - catalog::SQLCatalog, column_not_found_err, error::*, invalid_operation_err, - table_not_found_err, unsupported_sql_err, + column_not_found_err, error::*, invalid_operation_err, table_not_found_err, unsupported_sql_err, }; /// A named logical plan @@ -71,20 +73,12 @@ impl Relation { } /// Context that is shared across a query and its subqueries +#[derive(Default)] struct PlannerContext { - catalog: SQLCatalog, + catalog: DaftCatalog, cte_map: HashMap, } -impl Default for PlannerContext { - fn default() -> Self { - Self { - catalog: SQLCatalog::new(), - cte_map: Default::default(), - } - } -} - #[derive(Default)] pub struct SQLPlanner<'a> { current_relation: Option, @@ -98,7 +92,7 @@ pub struct SQLPlanner<'a> { } impl<'a> SQLPlanner<'a> { - pub fn new(catalog: SQLCatalog) -> Self { + pub fn new(catalog: DaftCatalog) -> Self { let context = Rc::new(RefCell::new(PlannerContext { catalog, ..Default::default() @@ -144,7 +138,7 @@ impl<'a> SQLPlanner<'a> { Ref::map(self.context.borrow(), |i| &i.cte_map) } - fn catalog(&self) -> Ref<'_, SQLCatalog> { + fn catalog(&self) -> Ref<'_, DaftCatalog> { Ref::map(self.context.borrow(), |i| &i.catalog) } @@ -1002,21 +996,11 @@ impl<'a> SQLPlanner<'a> { alias, .. } => { - let table_name = name.to_string(); - let Some(rel) = self - .table_map - .get(&table_name) - .cloned() - .or_else(|| self.cte_map().get(&table_name).cloned()) - .or_else(|| { - self.catalog() - .get_table(&table_name) - .map(|table| Relation::new(table.into(), table_name.clone())) - }) - else { - table_not_found_err!(table_name) + let rel = if is_table_path(name) { + self.plan_relation_path(name)? + } else { + self.plan_relation_table(name)? }; - (rel, alias.clone()) } sqlparser::ast::TableFactor::Derived { @@ -1066,6 +1050,46 @@ impl<'a> SQLPlanner<'a> { } } + /// Plan a `FROM ` table factor by rewriting to relevant table-value function. 
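The two planner hunks that follow implement this rewrite: a lone single-quoted identifier is treated as a file path, and its extension picks the table-valued function to call. A standalone sketch of the dispatch rule:

```rust
use std::path::Path;

// Maps `FROM 'some/path.ext'` to the table function that reads it, mirroring
// plan_relation_path below. Unknown or missing extensions are an error.
fn reader_for_path(path: &str) -> Result<&'static str, String> {
    match Path::new(path).extension().and_then(|e| e.to_str()) {
        Some(ext) if ext.eq_ignore_ascii_case("csv") => Ok("read_csv"),
        Some(ext) if ext.eq_ignore_ascii_case("json") || ext.eq_ignore_ascii_case("jsonl") => {
            Ok("read_json")
        }
        Some(ext) if ext.eq_ignore_ascii_case("parquet") => Ok("read_parquet"),
        Some(_) => Err(format!("unsupported file path extension: {path}")),
        None => Err(format!("unsupported file path, no extension: {path}")),
    }
}

fn main() {
    assert_eq!(reader_for_path("data/events.jsonl"), Ok("read_json"));
    assert_eq!(reader_for_path("s3://bucket/part-0.parquet"), Ok("read_parquet"));
    assert!(reader_for_path("README").is_err());
}
```

In effect, `SELECT * FROM 'data/events.jsonl'` plans the same way as `SELECT * FROM read_json('data/events.jsonl')`, since the path is handed to the table function as a single-quoted string argument.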
+ fn plan_relation_path(&self, name: &ObjectName) -> SQLPlannerResult { + let path = name.0[0].value.as_str(); + let func = match Path::new(path).extension() { + Some(ext) if ext.eq_ignore_ascii_case("csv") => "read_csv", + Some(ext) if ext.eq_ignore_ascii_case("json") => "read_json", + Some(ext) if ext.eq_ignore_ascii_case("jsonl") => "read_json", + Some(ext) if ext.eq_ignore_ascii_case("parquet") => "read_parquet", + Some(_) => invalid_operation_err!("unsupported file path extension: {}", name), + None => invalid_operation_err!("unsupported file path, no extension: {}", name), + }; + let args = TableFunctionArgs { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( + ast::Expr::Value(Value::SingleQuotedString(path.to_string())), + ))], + settings: None, + }; + self.plan_table_function(func, &args) + } + + /// Plan a `FROM
` table factor. + fn plan_relation_table(&self, name: &ObjectName) -> SQLPlannerResult { + let table_name = name.to_string(); + let Some(rel) = self + .table_map + .get(&table_name) + .cloned() + .or_else(|| self.cte_map().get(&table_name).cloned()) + .or_else(|| { + self.catalog() + .read_table(&table_name) + .ok() + .map(|table| Relation::new(table, table_name.clone())) + }) + else { + table_not_found_err!(table_name) + }; + Ok(rel) + } + fn plan_identifier(&self, idents: &[Ident]) -> SQLPlannerResult { // if the current relation is not resolved (e.g. in a `sql_expr` call, simply wrap identifier in a col) if self.current_relation.is_none() { @@ -1791,6 +1815,7 @@ impl<'a> SQLPlanner<'a> { BinaryOperator::Multiply => Ok(Operator::Multiply), BinaryOperator::Divide => Ok(Operator::TrueDivide), BinaryOperator::Eq => Ok(Operator::Eq), + BinaryOperator::Spaceship => Ok(Operator::EqNullSafe), BinaryOperator::Modulo => Ok(Operator::Modulus), BinaryOperator::Gt => Ok(Operator::Gt), BinaryOperator::Lt => Ok(Operator::Lt), @@ -2208,6 +2233,24 @@ fn idents_to_str(idents: &[Ident]) -> String { .join(".") } +/// Returns true iff the ObjectName is a string literal (single-quoted identifier e.g. 'path/to/file.extension'). +/// +/// # Examples +/// +/// ``` +/// 'file.ext' -> true +/// 'path/to/file.ext' -> true +/// 'a'.'b'.'c' -> false (multiple identifiers) +/// "path/to/file.ext" -> false (double-quotes) +/// hello -> false (not single-quoted) +/// ``` +fn is_table_path(name: &ObjectName) -> bool { + if name.0.len() != 1 { + return false; + } + matches!(name.0[0].quote_style, Some('\'')) +} + /// unresolves an alias in a projection /// Example: /// ```sql @@ -2248,8 +2291,9 @@ fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult Vec { #[pyclass(module = "daft.daft")] #[derive(Debug, Clone)] pub struct PyCatalog { - catalog: SQLCatalog, + catalog: DaftCatalog, } #[pymethods] @@ -77,14 +78,19 @@ impl PyCatalog { #[staticmethod] pub fn new() -> Self { Self { - catalog: SQLCatalog::new(), + catalog: DaftCatalog::default(), } } /// Register a table with the catalog. 
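Aside (illustration only, not part of the patch): the `BinaryOperator::Spaceship => Operator::EqNullSafe` arm above surfaces SQL's null-safe equality operator `<=>`, which yields TRUE or FALSE even when an operand is NULL, unlike plain `=`, which yields NULL. A sketch of exercising it end to end, assuming `daft.sql` and the small.jsonl asset from this diff:

```python
import daft

# `x <=> 42` is never NULL: NULL <=> 42 evaluates to FALSE (and NULL <=> NULL
# would be TRUE), whereas `x = 42` propagates NULL for NULL inputs.
df = daft.sql("SELECT x, x <=> 42 AS x_is_42 FROM 'tests/assets/json-data/small.jsonl'")
df.show()
```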
- pub fn register_table(&mut self, name: &str, dataframe: &mut PyLogicalPlanBuilder) { + pub fn register_table( + &mut self, + name: &str, + dataframe: &mut PyLogicalPlanBuilder, + ) -> PyResult<()> { let plan = dataframe.builder.build(); - self.catalog.register_table(name, plan); + self.catalog.register_table(name, plan)?; + Ok(()) } /// Copy from another catalog, using tables from other in case of conflict diff --git a/src/daft-sql/src/table_provider/mod.rs b/src/daft-sql/src/table_provider/mod.rs index f32ea01ab9..fabce9bdc8 100644 --- a/src/daft-sql/src/table_provider/mod.rs +++ b/src/daft-sql/src/table_provider/mod.rs @@ -1,10 +1,17 @@ -pub mod read_csv; -pub mod read_parquet; +mod read_csv; +mod read_deltalake; +mod read_iceberg; +mod read_json; +mod read_parquet; + use std::{collections::HashMap, sync::Arc}; use daft_logical_plan::LogicalPlanBuilder; use once_cell::sync::Lazy; use read_csv::ReadCsvFunction; +use read_deltalake::ReadDeltalakeFunction; +use read_iceberg::SqlReadIceberg; +use read_json::ReadJsonFunction; use read_parquet::ReadParquetFunction; use sqlparser::ast::TableFunctionArgs; @@ -17,11 +24,11 @@ use crate::{ pub(crate) static SQL_TABLE_FUNCTIONS: Lazy = Lazy::new(|| { let mut functions = SQLTableFunctions::new(); - functions.add_fn("read_parquet", ReadParquetFunction); functions.add_fn("read_csv", ReadCsvFunction); - #[cfg(feature = "python")] functions.add_fn("read_deltalake", ReadDeltalakeFunction); - + functions.add_fn("read_iceberg", SqlReadIceberg); + functions.add_fn("read_json", ReadJsonFunction); + functions.add_fn("read_parquet", ReadParquetFunction); functions }); @@ -68,6 +75,7 @@ impl<'a> SQLPlanner<'a> { } } +// nit cleanup: switch param order and rename to `to_logical_plan` for consistency with SQLFunction. pub(crate) trait SQLTableFunction: Send + Sync { fn plan( &self, @@ -75,43 +83,3 @@ pub(crate) trait SQLTableFunction: Send + Sync { args: &TableFunctionArgs, ) -> SQLPlannerResult; } - -pub struct ReadDeltalakeFunction; - -#[cfg(feature = "python")] -impl SQLTableFunction for ReadDeltalakeFunction { - fn plan( - &self, - planner: &SQLPlanner, - args: &TableFunctionArgs, - ) -> SQLPlannerResult { - let (uri, io_config) = match args.args.as_slice() { - [uri] => (uri, None), - [uri, io_config] => { - let args = planner.parse_function_args(&[io_config.clone()], &["io_config"], 0)?; - let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; - - (uri, io_config) - } - _ => unsupported_sql_err!("Expected one or two arguments"), - }; - let uri = planner.plan_function_arg(uri)?; - - let Some(uri) = uri.as_literal().and_then(|lit| lit.as_str()) else { - unsupported_sql_err!("Expected a string literal for the first argument"); - }; - - daft_scan::builder::delta_scan(uri, io_config, true).map_err(From::from) - } -} - -#[cfg(not(feature = "python"))] -impl SQLTableFunction for ReadDeltalakeFunction { - fn plan( - &self, - planner: &SQLPlanner, - args: &TableFunctionArgs, - ) -> SQLPlannerResult { - unsupported_sql_err!("`read_deltalake` function is not supported. 
Enable the `python` feature to use this function.") - } -} diff --git a/src/daft-sql/src/table_provider/read_csv.rs b/src/daft-sql/src/table_provider/read_csv.rs index 0ced7ea5a9..e27c5873a0 100644 --- a/src/daft-sql/src/table_provider/read_csv.rs +++ b/src/daft-sql/src/table_provider/read_csv.rs @@ -16,6 +16,10 @@ impl TryFrom for CsvScanBuilder { type Error = PlannerError; fn try_from(args: SQLFunctionArguments) -> Result { + // TODO validations (unsure if should carry over from python API) + // - schema_hints is deprecated + // - ensure infer_schema is true if schema is None. + let delimiter = args.try_get_named("delimiter")?; let has_headers: bool = args.try_get_named("has_headers")?.unwrap_or(true); let double_quote: bool = args.try_get_named("double_quote")?.unwrap_or(true); diff --git a/src/daft-sql/src/table_provider/read_deltalake.rs b/src/daft-sql/src/table_provider/read_deltalake.rs new file mode 100644 index 0000000000..142ce8ec89 --- /dev/null +++ b/src/daft-sql/src/table_provider/read_deltalake.rs @@ -0,0 +1,44 @@ +use daft_logical_plan::LogicalPlanBuilder; +use sqlparser::ast::TableFunctionArgs; + +use super::{expr_to_iocfg, SQLTableFunction}; +use crate::{error::SQLPlannerResult, unsupported_sql_err, SQLPlanner}; + +pub(super) struct ReadDeltalakeFunction; + +#[cfg(feature = "python")] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let (uri, io_config) = match args.args.as_slice() { + [uri] => (uri, None), + [uri, io_config] => { + let args = planner.parse_function_args(&[io_config.clone()], &["io_config"], 0)?; + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + (uri, io_config) + } + _ => unsupported_sql_err!("Expected one or two arguments"), + }; + let uri = planner.plan_function_arg(uri)?; + + let Some(uri) = uri.as_literal().and_then(|lit| lit.as_str()) else { + unsupported_sql_err!("Expected a string literal for the first argument"); + }; + + daft_scan::builder::delta_scan(uri, io_config, true).map_err(From::from) + } +} + +#[cfg(not(feature = "python"))] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + unsupported_sql_err!("`read_deltalake` function is not supported. Enable the `python` feature to use this function.") + } +} diff --git a/src/daft-sql/src/table_provider/read_iceberg.rs b/src/daft-sql/src/table_provider/read_iceberg.rs new file mode 100644 index 0000000000..b8b9e30d8a --- /dev/null +++ b/src/daft-sql/src/table_provider/read_iceberg.rs @@ -0,0 +1,74 @@ +use common_io_config::IOConfig; +use daft_logical_plan::LogicalPlanBuilder; +use sqlparser::ast::TableFunctionArgs; + +use super::SQLTableFunction; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{self, SQLFunctionArguments}, + SQLPlanner, +}; + +/// The Daft-SQL `read_iceberg` table-value function. +pub(super) struct SqlReadIceberg; + +/// The Daft-SQL `read_iceberg` table-value function arguments. +struct SqlReadIcebergArgs { + metadata_location: String, + snapshot_id: Option, + io_config: Option, +} + +impl SqlReadIcebergArgs { + /// Like a TryFrom but from TalbeFunctionArgs directly and passing the planner. 
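Aside (call shape only, not part of the patch): per the argument struct above, `read_iceberg` takes a positional metadata location plus optional `snapshot_id` and `io_config`; `read_deltalake` takes a table URI. The snippet below only sketches the SQL surface with hypothetical locations, since executing it requires a real Iceberg or Delta table and the `python` feature:

```python
# Hypothetical URIs; run the commented daft.sql calls against a real table.
iceberg_query = "SELECT * FROM read_iceberg('s3://bucket/tbl/metadata/v1.metadata.json')"
delta_query = "SELECT * FROM read_deltalake('s3://bucket/events/')"
# import daft
# daft.sql(iceberg_query).show()
# daft.sql(delta_query).show()
```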
+ fn try_from(planner: &SQLPlanner, args: &TableFunctionArgs) -> SQLPlannerResult { + planner.plan_function_args(&args.args, &["snapshot_id", "io_config"], 1) + } +} + +impl TryFrom for SqlReadIcebergArgs { + type Error = PlannerError; + + /// This is required to use `planner.plan_function_args` + fn try_from(args: SQLFunctionArguments) -> Result { + let metadata_location: String = args + .try_get_positional(0)? + .expect("read_iceberg requires a path"); + let snapshot_id: Option = args.try_get_named("snapshot_id")?; + let io_config: Option = functions::args::parse_io_config(&args)?.into(); + Ok(Self { + metadata_location, + snapshot_id, + io_config, + }) + } +} + +/// Translates the `read_iceberg` table-value function to a logical scan operator. +#[cfg(feature = "python")] +impl SQLTableFunction for SqlReadIceberg { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let args = SqlReadIcebergArgs::try_from(planner, args)?; + Ok(daft_scan::builder::iceberg_scan( + args.metadata_location, + args.snapshot_id, + args.io_config, + )?) + } +} + +/// Translates the `read_iceberg` table-value function to a logical scan operator (errors without python feature). +#[cfg(not(feature = "python"))] +impl SQLTableFunction for SqlReadIceberg { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + crate::unsupported_sql_err!("`read_iceberg` function is not supported. Enable the `python` feature to use this function.") + } +} diff --git a/src/daft-sql/src/table_provider/read_json.rs b/src/daft-sql/src/table_provider/read_json.rs new file mode 100644 index 0000000000..f892901691 --- /dev/null +++ b/src/daft-sql/src/table_provider/read_json.rs @@ -0,0 +1,68 @@ +use daft_scan::builder::JsonScanBuilder; + +use super::{expr_to_iocfg, SQLTableFunction}; +use crate::{error::PlannerError, functions::SQLFunctionArguments}; + +pub(super) struct ReadJsonFunction; + +impl SQLTableFunction for ReadJsonFunction { + fn plan( + &self, + planner: &crate::SQLPlanner, + args: &sqlparser::ast::TableFunctionArgs, + ) -> crate::error::SQLPlannerResult { + let builder: JsonScanBuilder = planner.plan_function_args( + args.args.as_slice(), + &[ + "path", + "infer_schema", + // "schema" + "io_config", + "file_path_column", + "hive_partitioning", + // "schema_hints", + "buffer_size", + "chunk_size", + ], + 1, // (path) + )?; + let runtime = common_runtime::get_io_runtime(true); + let result = runtime.block_on(builder.finish())??; + Ok(result) + } +} + +impl TryFrom for JsonScanBuilder { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + // TODO validations (unsure if should carry over from python API) + // - schema_hints is deprecated + // - ensure infer_schema is true if schema is None. + + let glob_paths: String = args + .try_get_positional(0)? 
+ .ok_or_else(|| PlannerError::invalid_operation("path is required for `read_json`"))?; + + let infer_schema = args.try_get_named("infer_schema")?.unwrap_or(true); + let chunk_size = args.try_get_named("chunk_size")?; + let buffer_size = args.try_get_named("buffer_size")?; + let file_path_column = args.try_get_named("file_path_column")?; + let hive_partitioning = args.try_get_named("hive_partitioning")?.unwrap_or(false); + let schema = None; // TODO + let schema_hints = None; // TODO + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + + Ok(Self { + glob_paths: vec![glob_paths], + infer_schema, + schema, + io_config, + file_path_column, + hive_partitioning, + schema_hints, + buffer_size, + chunk_size, + }) + } +} diff --git a/src/daft-stats/src/column_stats/comparison.rs b/src/daft-stats/src/column_stats/comparison.rs index 7e2021744c..33cc0ed505 100644 --- a/src/daft-stats/src/column_stats/comparison.rs +++ b/src/daft-stats/src/column_stats/comparison.rs @@ -44,6 +44,10 @@ impl DaftCompare<&Self> for ColumnRangeStatistics { self.equal(rhs)?.not() } + fn eq_null_safe(&self, rhs: &Self) -> Self::Output { + self.equal(rhs) + } + fn gt(&self, rhs: &Self) -> Self::Output { // lower_bound: True greater (self.lower > rhs.upper) // upper_bound: some value that can be greater (self.upper > rhs.lower) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index f57a17a5bc..f2adfda35e 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -568,6 +568,7 @@ impl Table { Lt => Ok(lhs.lt(&rhs)?.into_series()), LtEq => Ok(lhs.lte(&rhs)?.into_series()), Eq => Ok(lhs.equal(&rhs)?.into_series()), + EqNullSafe => Ok(lhs.eq_null_safe(&rhs)?.into_series()), NotEq => Ok(lhs.not_equal(&rhs)?.into_series()), GtEq => Ok(lhs.gte(&rhs)?.into_series()), Gt => Ok(lhs.gt(&rhs)?.into_series()), diff --git a/tests/assets/json-data/small.jsonl b/tests/assets/json-data/small.jsonl new file mode 100644 index 0000000000..617f566255 --- /dev/null +++ b/tests/assets/json-data/small.jsonl @@ -0,0 +1,25 @@ +{ "x": 42, "y": "apple", "z": true } +{ "x": 17, "y": "banana", "z": false } +{ "x": 89, "y": "cherry", "z": true } +{ "x": 3, "y": "date", "z": false } +{ "x": 156, "y": "elderberry", "z": true } +{ "x": 23, "y": "fig", "z": true } +{ "x": 777, "y": "grape", "z": false } +{ "x": 444, "y": "honeydew", "z": true } +{ "x": 91, "y": "kiwi", "z": false } +{ "x": 12, "y": "lemon", "z": true } +{ "x": 365, "y": "mango", "z": false } +{ "x": 55, "y": "nectarine", "z": true } +{ "x": 888, "y": "orange", "z": false } +{ "x": 247, "y": "papaya", "z": true } +{ "x": 33, "y": "quince", "z": false } +{ "x": 159, "y": "raspberry", "z": true } +{ "x": 753, "y": "strawberry", "z": false } +{ "x": 951, "y": "tangerine", "z": true } +{ "x": 426, "y": "ugli fruit", "z": false } +{ "x": 87, "y": "vanilla", "z": true } +{ "x": 234, "y": "watermelon", "z": false } +{ "x": 567, "y": "xigua", "z": true } +{ "x": 111, "y": "yuzu", "z": false } +{ "x": 999, "y": "zucchini", "z": true } +{ "x": 123, "y": "apricot", "z": false } diff --git a/tests/connect/test_alias.py b/tests/connect/test_alias.py deleted file mode 100644 index 94efb35fc2..0000000000 --- a/tests/connect/test_alias.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col - - -def test_alias(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Simply rename the 'id' column to 'my_number' - df_renamed = df.select(col("id").alias("my_number")) 
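Aside (expression-level sketch, not part of the patch): the `EqNullSafe` arm wired into `Table::eval` above is what the Python expression `Expression.eq_null_safe`, exercised by the new tests later in this diff, lowers to. A minimal DataFrame-level example, assuming a Daft build that includes this change:

```python
import daft
from daft.expressions import col, lit

df = daft.from_pydict({"value": [10, None, 20]})
out = df.select(col("value").eq_null_safe(lit(10)))
# Expected semantics: 10 <=> 10 -> True, NULL <=> 10 -> False, 20 <=> 10 -> False
print(out.to_pydict())  # e.g. {"value": [True, False, False]}
```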
- - # Verify the alias was set correctly - assert df_renamed.schema != df.schema, "Schema should be changed after alias" - - # Verify the data is unchanged but column name is different - df_rows = df.collect() - df_renamed_rows = df_renamed.collect() - assert [row.id for row in df_rows] == [ - row.my_number for row in df_renamed_rows - ], "Data should be unchanged after alias" diff --git a/tests/connect/test_analyze_plan.py b/tests/connect/test_analyze_plan.py deleted file mode 100644 index 492de7e53c..0000000000 --- a/tests/connect/test_analyze_plan.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -import pytest - - -@pytest.mark.skip( - reason="Currently an issue in the spark connect code. It always passes the inferred schema instead of the supplied schema." -) -def test_analyze_plan(spark_session): - data = [[1000, 99]] - df1 = spark_session.createDataFrame(data, schema="Value int, Total int") - s = df1.schema - - # todo: this is INCORRECT but it is an issue with pyspark client - # right now it is assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])" - assert ( - str(s) == "StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])" - ) diff --git a/tests/connect/test_basic_column.py b/tests/connect/test_basic_column.py deleted file mode 100644 index 95dcb5cdd0..0000000000 --- a/tests/connect/test_basic_column.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col -from pyspark.sql.types import StringType - - -def test_column_alias(spark_session): - df = spark_session.range(10) - df_alias = df.select(col("id").alias("my_number")) - assert "my_number" in df_alias.columns, "alias should rename column" - assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged" - - -def test_column_cast(spark_session): - df = spark_session.range(10) - df_cast = df.select(col("id").cast(StringType())) - assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type" - assert df_cast.toPandas()["id"].dtype == "object", "cast should change pandas dtype to object/string" - - -def test_column_null_checks(spark_session): - df = spark_session.range(10) - df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null")) - assert df_null.toPandas()["not_null"].iloc[0], "isNotNull should be True for non-null values" - assert not df_null.toPandas()["is_null"].iloc[0], "isNull should be False for non-null values" - - -def test_column_name(spark_session): - df = spark_session.range(10) - df_name = df.select(col("id").name("renamed_id")) - assert "renamed_id" in df_name.columns, "name should rename column" - assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged" - - -# TODO: Uncomment when https://github.com/Eventual-Inc/Daft/issues/3433 is fixed -# def test_column_desc(spark_session): -# df = spark_session.range(10) -# df_attr = df.select(col("id").desc()) -# assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order" - - -# TODO: Add test when extract value is implemented -# def test_column_getitem(spark_session): -# df = spark_session.range(10) -# df_item = df.select(col("id")[0]) -# assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element" diff --git a/tests/connect/test_basics.py b/tests/connect/test_basics.py new file mode 100644 index 0000000000..44fd560b5b --- 
/dev/null +++ b/tests/connect/test_basics.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +import pytest +from pyspark.sql import functions as F +from pyspark.sql.functions import col +from pyspark.sql.types import LongType, StringType, StructField, StructType + + +def test_alias(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Simply rename the 'id' column to 'my_number' + df_renamed = df.select(col("id").alias("my_number")) + + # Verify the alias was set correctly + assert df_renamed.schema != df.schema, "Schema should be changed after alias" + + # Verify the data is unchanged but column name is different + df_rows = df.collect() + df_renamed_rows = df_renamed.collect() + assert [row.id for row in df_rows] == [ + row.my_number for row in df_renamed_rows + ], "Data should be unchanged after alias" + + +@pytest.mark.skip( + reason="Currently an issue in the spark connect code. It always passes the inferred schema instead of the supplied schema. see: https://issues.apache.org/jira/browse/SPARK-50627" +) +def test_analyze_plan(spark_session): + data = [[1000, 99]] + df1 = spark_session.createDataFrame(data, schema="Value int, Total int") + s = df1.schema + + # todo: this is INCORRECT but it is an issue with pyspark client + # right now it is assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])" + assert ( + str(s) == "StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])" + ) + + +def test_column_alias(spark_session): + df = spark_session.range(10) + df_alias = df.select(col("id").alias("my_number")) + assert "my_number" in df_alias.columns, "alias should rename column" + assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged" + + +def test_column_cast(spark_session): + df = spark_session.range(10) + df_cast = df.select(col("id").cast(StringType())) + assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type" + assert df_cast.toPandas()["id"].dtype == "object", "cast should change pandas dtype to object/string" + + +def test_column_null_checks(spark_session): + df = spark_session.range(10) + df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null")) + assert df_null.toPandas()["not_null"].iloc[0], "isNotNull should be True for non-null values" + assert not df_null.toPandas()["is_null"].iloc[0], "isNull should be False for non-null values" + + +def test_column_name(spark_session): + df = spark_session.range(10) + df_name = df.select(col("id").name("renamed_id")) + assert "renamed_id" in df_name.columns, "name should rename column" + assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged" + + +def test_range_operation(spark_session): + # Create a range using Spark + # For example, creating a range from 0 to 9 + spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9 + + # Convert to Pandas DataFrame + pandas_df = spark_range.toPandas() + + # Verify the DataFrame has expected values + assert len(pandas_df) == 10, "DataFrame should have 10 rows" + assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9" + + +def test_range_collect(spark_session): + # Create a range using Spark + # For example, creating a range from 0 to 9 + spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9 + + # Collect the data + collected_rows = 
spark_range.collect() + + # Verify the collected data has expected values + assert len(collected_rows) == 10, "Should have 10 rows" + assert [row["id"] for row in collected_rows] == list(range(10)), "Should contain values 0-9" + + +def test_drop(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Drop the 'id' column + df_dropped = df.drop("id") + + # Verify the drop was successful + assert "id" not in df_dropped.columns, "Column 'id' should be dropped" + assert len(df_dropped.columns) == len(df.columns) - 1, "Should have one less column after drop" + + # Verify the DataFrame has no columns after dropping all columns" + assert len(df_dropped.columns) == 0, "DataFrame should have no columns after dropping 'id'" + + +def test_range_first(spark_session): + spark_range = spark_session.range(10) + first_row = spark_range.first() + assert first_row["id"] == 0, "First row should have id=0" + + +def test_range_limit(spark_session): + spark_range = spark_session.range(10) + limited_df = spark_range.limit(5).toPandas() + assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" + assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4" + + +def test_filter(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Filter for values less than 5 + df_filtered = df.filter(col("id") < 5) + + # Verify the schema is unchanged after filter + assert df_filtered.schema == df.schema, "Schema should be unchanged after filter" + + # Verify the filtered data is correct + df_filtered_pandas = df_filtered.toPandas() + assert len(df_filtered_pandas) == 5, "Should have 5 rows after filtering < 5" + assert all(df_filtered_pandas["id"] < 5), "All values should be less than 5" + + +def test_get_attr(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Get column using df[...] + # df.get_attr("id") is equivalent to df["id"] + df_col = df["id"] + + # Check that column values match expected range + values = df.select(df_col).collect() # Changed to select column first + assert len(values) == 10 + assert [row[0] for row in values] == list(range(10)) # Need to extract values from Row objects + + +def test_group_by(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Add a column that will have repeated values for grouping + df = df.withColumn("group", col("id") % 3) + + # Group by the new column and sum the ids in each group + df_grouped = df.groupBy("group").sum("id") + + # Convert to pandas to verify the sums + df_grouped_pandas = df_grouped.toPandas() + + # Sort by group to ensure consistent order for comparison + df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True) + + # Verify the expected sums for each group + # group id + # 0 2 15 + # 1 1 12 + # 2 0 18 + expected = { + "group": [0, 1, 2], + "id": [18, 12, 15], # todo(correctness): should this be "id" for value here? 
+ } + + assert df_grouped_pandas["group"].tolist() == expected["group"] + assert df_grouped_pandas["id"].tolist() == expected["id"] + + +def test_schema(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Define the expected schema + # in reality should be nullable=False, but daft has all our structs as nullable=True + expected_schema = StructType([StructField("id", LongType(), nullable=True)]) + + # Verify the schema is as expected + assert df.schema == expected_schema, "Schema should match the expected schema" + + +def test_select(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Select just the 'id' column + df_selected = df.select(col("id")) + + # Verify the schema is unchanged since we selected same column + assert df_selected.schema == df.schema, "Schema should be unchanged after selecting same column" + assert len(df_selected.collect()) == 10, "Row count should be unchanged after select" + + # Verify the data is unchanged + df_data = [row["id"] for row in df.collect()] + df_selected_data = [row["id"] for row in df_selected.collect()] + assert df_data == df_selected_data, "Data should be unchanged after select" + + +def test_show(spark_session, capsys): + df = spark_session.range(10) + df.show() + captured = capsys.readouterr() + expected = ( + "╭───────╮\n" + "│ id │\n" + "│ --- │\n" + "│ Int64 │\n" + "╞═══════╡\n" + "│ 0 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 1 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 2 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 3 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 4 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 5 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 6 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 7 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 8 │\n" + "├╌╌╌╌╌╌╌┤\n" + "│ 9 │\n" + "╰───────╯\n" + ) + assert captured.out == expected + + +def test_take(spark_session): + # Create DataFrame with 10 rows + df = spark_session.range(10) + + # Take first 5 rows and collect + result = df.take(5) + + # Verify the expected values + expected = df.limit(5).collect() + + assert result == expected + + # Test take with more rows than exist + result_large = df.take(20) + expected_large = df.collect() + assert result_large == expected_large # Should return all existing rows + + +def test_numeric_equals(spark_session): + """Test numeric equality comparison with NULL handling.""" + data = [(1, 10), (2, None)] + df = spark_session.createDataFrame(data, ["id", "value"]) + + result = df.withColumn("equals_20", F.col("value") == F.lit(20)).collect() + + assert result[0].equals_20 is False # 10 == 20 + assert result[1].equals_20 is None # NULL == 20 + + +def test_string_equals(spark_session): + """Test string equality comparison with NULL handling.""" + data = [(1, "apple"), (2, None)] + df = spark_session.createDataFrame(data, ["id", "text"]) + + result = df.withColumn("equals_banana", F.col("text") == F.lit("banana")).collect() + + assert result[0].equals_banana is False # apple == banana + assert result[1].equals_banana is None # NULL == banana + + +@pytest.mark.skip(reason="We believe null-safe equals are not yet implemented") +def test_null_safe_equals(spark_session): + """Test null-safe equality comparison.""" + data = [(1, 10), (2, None)] + df = spark_session.createDataFrame(data, ["id", "value"]) + + result = df.withColumn("null_safe_equals", F.col("value").eqNullSafe(F.lit(10))).collect() + + assert result[0].null_safe_equals is True # 10 <=> 10 + assert result[1].null_safe_equals is False # NULL <=> 10 + + +def test_not(spark_session): + """Test logical NOT operation with NULL handling.""" + data = [(True,), (False,), (None,)] 
+ df = spark_session.createDataFrame(data, ["value"]) + + result = df.withColumn("not_value", ~F.col("value")).collect() + + assert result[0].not_value is False # NOT True + assert result[1].not_value is True # NOT False + assert result[2].not_value is None # NOT NULL + + +def test_with_column(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Add a new column that's a boolean indicating if id > 2 + df_with_col = df.withColumn("double_id", col("id") > 2) + + # Verify the schema has both columns + assert "id" in df_with_col.schema.names, "Original column should still exist" + assert "double_id" in df_with_col.schema.names, "New column should be added" + + # Verify the data is correct + df_pandas = df_with_col.toPandas() + assert (df_pandas["double_id"] == (df_pandas["id"] > 2)).all(), "New column should be greater than 2 comparison" + + +def test_with_columns_renamed(spark_session): + # Test withColumnRenamed + df = spark_session.range(5) + renamed_df = df.withColumnRenamed("id", "number") + + collected = renamed_df.collect() + assert len(collected) == 5 + assert "number" in renamed_df.columns + assert "id" not in renamed_df.columns + assert [row["number"] for row in collected] == list(range(5)) + + # todo: this edge case is a spark connect bug; it will only send rename of id -> character over protobuf + # # Test withColumnsRenamed + # df = spark_session.range(2) + # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"}) + # + # collected = renamed_df.collect() + # assert len(collected) == 2 + # assert set(renamed_df.columns) == {"number", "character"} + # assert "id" not in renamed_df.columns + # assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)] diff --git a/tests/connect/test_collect.py b/tests/connect/test_collect.py deleted file mode 100644 index 0a9387dd0b..0000000000 --- a/tests/connect/test_collect.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - - -def test_range_collect(spark_session): - # Create a range using Spark - # For example, creating a range from 0 to 9 - spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9 - - # Collect the data - collected_rows = spark_range.collect() - - # Verify the collected data has expected values - assert len(collected_rows) == 10, "Should have 10 rows" - assert [row["id"] for row in collected_rows] == list(range(10)), "Should contain values 0-9" diff --git a/tests/connect/test_config_simple.py b/tests/connect/test_config.py similarity index 100% rename from tests/connect/test_config_simple.py rename to tests/connect/test_config.py diff --git a/tests/connect/test_create_df.py b/tests/connect/test_create.py similarity index 100% rename from tests/connect/test_create_df.py rename to tests/connect/test_create.py diff --git a/tests/connect/test_distinct.py b/tests/connect/test_distinct.py new file mode 100644 index 0000000000..110a0dcc03 --- /dev/null +++ b/tests/connect/test_distinct.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pyspark.sql import Row + + +def test_distinct(spark_session): + # Create simple DataFrame with single column + data = [(1,), (2,), (1,)] + df = spark_session.createDataFrame(data, ["id"]).distinct() + + assert df.count() == 2, "DataFrame should have 2 rows" + + assert df.sort().collect() == [Row(id=1), Row(id=2)], "DataFrame should contain expected values" diff --git a/tests/connect/test_drop.py b/tests/connect/test_drop.py deleted file mode 100644 index 
11f640fc82..0000000000 --- a/tests/connect/test_drop.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - - -def test_drop(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Drop the 'id' column - df_dropped = df.drop("id") - - # Verify the drop was successful - assert "id" not in df_dropped.columns, "Column 'id' should be dropped" - assert len(df_dropped.columns) == len(df.columns) - 1, "Should have one less column after drop" - - # Verify the DataFrame has no columns after dropping all columns" - assert len(df_dropped.columns) == 0, "DataFrame should have no columns after dropping 'id'" diff --git a/tests/connect/test_filter.py b/tests/connect/test_filter.py deleted file mode 100644 index 1586c7e7b5..0000000000 --- a/tests/connect/test_filter.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col - - -def test_filter(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Filter for values less than 5 - df_filtered = df.filter(col("id") < 5) - - # Verify the schema is unchanged after filter - assert df_filtered.schema == df.schema, "Schema should be unchanged after filter" - - # Verify the filtered data is correct - df_filtered_pandas = df_filtered.toPandas() - assert len(df_filtered_pandas) == 5, "Should have 5 rows after filtering < 5" - assert all(df_filtered_pandas["id"] < 5), "All values should be less than 5" diff --git a/tests/connect/test_get_attr.py b/tests/connect/test_get_attr.py deleted file mode 100644 index c48f114091..0000000000 --- a/tests/connect/test_get_attr.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - - -def test_get_attr(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Get column using df[...] - # df.get_attr("id") is equivalent to df["id"] - df_col = df["id"] - - # Check that column values match expected range - values = df.select(df_col).collect() # Changed to select column first - assert len(values) == 10 - assert [row[0] for row in values] == list(range(10)) # Need to extract values from Row objects diff --git a/tests/connect/test_group_by.py b/tests/connect/test_group_by.py deleted file mode 100644 index 1a83526732..0000000000 --- a/tests/connect/test_group_by.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col - - -def test_group_by(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Add a column that will have repeated values for grouping - df = df.withColumn("group", col("id") % 3) - - # Group by the new column and sum the ids in each group - df_grouped = df.groupBy("group").sum("id") - - # Convert to pandas to verify the sums - df_grouped_pandas = df_grouped.toPandas() - - # Sort by group to ensure consistent order for comparison - df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True) - - # Verify the expected sums for each group - # group id - # 0 2 15 - # 1 1 12 - # 2 0 18 - expected = { - "group": [0, 1, 2], - "id": [18, 12, 15], # todo(correctness): should this be "id" for value here? 
- } - - assert df_grouped_pandas["group"].tolist() == expected["group"] - assert df_grouped_pandas["id"].tolist() == expected["id"] diff --git a/tests/connect/test_csv.py b/tests/connect/test_io.py similarity index 76% rename from tests/connect/test_csv.py rename to tests/connect/test_io.py index 7e957dd394..c82cf9a01b 100644 --- a/tests/connect/test_csv.py +++ b/tests/connect/test_io.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import tempfile import pytest @@ -86,3 +87,27 @@ def test_write_csv_with_compression(spark_session, tmp_path): df_pandas = df.toPandas() df_read_pandas = df_read.toPandas() assert df_pandas["id"].equals(df_read_pandas["id"]) + + +def test_write_parquet(spark_session): + with tempfile.TemporaryDirectory() as temp_dir: + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Write DataFrame to parquet directory + parquet_dir = os.path.join(temp_dir, "test.parquet") + df.write.parquet(parquet_dir) + + # List all files in the parquet directory + parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith(".parquet")] + + # Assert there is at least one parquet file + assert len(parquet_files) > 0, "Expected at least one parquet file to be written" + + # Read back from the parquet directory (not specific file) + df_read = spark_session.read.parquet(parquet_dir) + + # Verify the data is unchanged + df_pandas = df.toPandas() + df_read_pandas = df_read.toPandas() + assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read" diff --git a/tests/connect/test_limit_simple.py b/tests/connect/test_limit_simple.py deleted file mode 100644 index d5f2c97dae..0000000000 --- a/tests/connect/test_limit_simple.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - - -def test_range_first(spark_session): - spark_range = spark_session.range(10) - first_row = spark_range.first() - assert first_row["id"] == 0, "First row should have id=0" - - -def test_range_limit(spark_session): - spark_range = spark_session.range(10) - limited_df = spark_range.limit(5).toPandas() - assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" - assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4" diff --git a/tests/connect/test_parquet.py b/tests/connect/test_parquet.py deleted file mode 100644 index 153af01ead..0000000000 --- a/tests/connect/test_parquet.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import os -import tempfile - - -def test_write_parquet(spark_session): - with tempfile.TemporaryDirectory() as temp_dir: - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Write DataFrame to parquet directory - parquet_dir = os.path.join(temp_dir, "test.parquet") - df.write.parquet(parquet_dir) - - # List all files in the parquet directory - parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith(".parquet")] - - # Assert there is at least one parquet file - assert len(parquet_files) > 0, "Expected at least one parquet file to be written" - - # Read back from the parquet directory (not specific file) - df_read = spark_session.read.parquet(parquet_dir) - - # Verify the data is unchanged - df_pandas = df.toPandas() - df_read_pandas = df_read.toPandas() - assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read" diff --git a/tests/connect/test_print_schema.py b/tests/connect/test_print_schema.py index 60c98e85b6..38afd193fd 100644 --- a/tests/connect/test_print_schema.py +++ 
b/tests/connect/test_print_schema.py @@ -8,7 +8,7 @@ def test_print_schema_range(spark_session, capsys) -> None: df.printSchema() captured = capsys.readouterr() - expected = "root\n" " |-- id: long (nullable = true)\n\n" + expected = "root\n |-- id: long (nullable = true)\n\n" assert captured.out == expected @@ -18,7 +18,7 @@ def test_print_schema_simple_df(spark_session, capsys) -> None: df.printSchema() captured = capsys.readouterr() - expected = "root\n" " |-- value: long (nullable = true)\n\n" + expected = "root\n |-- value: long (nullable = true)\n\n" assert captured.out == expected @@ -43,7 +43,7 @@ def test_print_schema_floating_point(spark_session, capsys) -> None: df.printSchema() captured = capsys.readouterr() - expected = "root\n" " |-- amount: double (nullable = true)\n\n" + expected = "root\n |-- amount: double (nullable = true)\n\n" assert captured.out == expected @@ -53,7 +53,7 @@ def test_print_schema_with_nulls(spark_session, capsys) -> None: df.printSchema() captured = capsys.readouterr() - expected = "root\n" " |-- id: long (nullable = true)\n" " |-- value: string (nullable = true)\n\n" + expected = "root\n |-- id: long (nullable = true)\n |-- value: string (nullable = true)\n\n" assert captured.out == expected diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py deleted file mode 100644 index b277d38481..0000000000 --- a/tests/connect/test_range_simple.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - - -def test_range_operation(spark_session): - # Create a range using Spark - # For example, creating a range from 0 to 9 - spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9 - - # Convert to Pandas DataFrame - pandas_df = spark_range.toPandas() - - # Verify the DataFrame has expected values - assert len(pandas_df) == 10, "DataFrame should have 10 rows" - assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9" diff --git a/tests/connect/test_schema.py b/tests/connect/test_schema.py deleted file mode 100644 index 1f4e1182fa..0000000000 --- a/tests/connect/test_schema.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.types import LongType, StructField, StructType - - -def test_schema(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Define the expected schema - # in reality should be nullable=False, but daft has all our structs as nullable=True - expected_schema = StructType([StructField("id", LongType(), nullable=True)]) - - # Verify the schema is as expected - assert df.schema == expected_schema, "Schema should match the expected schema" diff --git a/tests/connect/test_select.py b/tests/connect/test_select.py deleted file mode 100644 index 9cee95e1ef..0000000000 --- a/tests/connect/test_select.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col - - -def test_select(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Select just the 'id' column - df_selected = df.select(col("id")) - - # Verify the schema is unchanged since we selected same column - assert df_selected.schema == df.schema, "Schema should be unchanged after selecting same column" - assert len(df_selected.collect()) == 10, "Row count should be unchanged after select" - - # Verify the data is unchanged - df_data = [row["id"] for row in df.collect()] - df_selected_data = [row["id"] for row in df_selected.collect()] - assert df_data 
== df_selected_data, "Data should be unchanged after select" diff --git a/tests/connect/test_show.py b/tests/connect/test_show.py deleted file mode 100644 index 3d4234d70b..0000000000 --- a/tests/connect/test_show.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - - -def test_show(spark_session, capsys): - df = spark_session.range(10) - df.show() - captured = capsys.readouterr() - expected = ( - "╭───────╮\n" - "│ id │\n" - "│ --- │\n" - "│ Int64 │\n" - "╞═══════╡\n" - "│ 0 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 1 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 2 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 3 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 4 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 5 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 6 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 7 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 8 │\n" - "├╌╌╌╌╌╌╌┤\n" - "│ 9 │\n" - "╰───────╯\n" - ) - assert captured.out == expected diff --git a/tests/connect/test_spark_sql.py b/tests/connect/test_spark_sql.py new file mode 100644 index 0000000000..636e92b263 --- /dev/null +++ b/tests/connect/test_spark_sql.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import pytest + + +def test_create_or_replace_temp_view(spark_session): + df = spark_session.createDataFrame([(1, "foo")], ["id", "name"]) + try: + df.createOrReplaceTempView("test_view") + except Exception as e: + pytest.fail(f"createOrReplaceTempView failed: {e}") + + +def test_sql(spark_session): + df = spark_session.createDataFrame([(1, "foo")], ["id", "name"]) + df.createOrReplaceTempView("test_view") + try: + result = spark_session.sql("SELECT * FROM test_view") + except Exception as e: + pytest.fail(f"sql failed: {e}") + assert result.collect() == [(1, "foo")] diff --git a/tests/connect/test_take.py b/tests/connect/test_take.py deleted file mode 100644 index 2e7809f232..0000000000 --- a/tests/connect/test_take.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - - -def test_take(spark_session): - # Create DataFrame with 10 rows - df = spark_session.range(10) - - # Take first 5 rows and collect - result = df.take(5) - - # Verify the expected values - expected = df.limit(5).collect() - - assert result == expected - - # Test take with more rows than exist - result_large = df.take(20) - expected_large = df.collect() - assert result_large == expected_large # Should return all existing rows diff --git a/tests/connect/test_unresolved.py b/tests/connect/test_unresolved.py deleted file mode 100644 index 272c4f48d2..0000000000 --- a/tests/connect/test_unresolved.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from pyspark.sql import functions as F - - -def test_numeric_equals(spark_session): - """Test numeric equality comparison with NULL handling.""" - data = [(1, 10), (2, None)] - df = spark_session.createDataFrame(data, ["id", "value"]) - - result = df.withColumn("equals_20", F.col("value") == F.lit(20)).collect() - - assert result[0].equals_20 is False # 10 == 20 - assert result[1].equals_20 is None # NULL == 20 - - -def test_string_equals(spark_session): - """Test string equality comparison with NULL handling.""" - data = [(1, "apple"), (2, None)] - df = spark_session.createDataFrame(data, ["id", "text"]) - - result = df.withColumn("equals_banana", F.col("text") == F.lit("banana")).collect() - - assert result[0].equals_banana is False # apple == banana - assert result[1].equals_banana is None # NULL == banana - - -@pytest.mark.skip(reason="We believe null-safe equals are not yet implemented") -def test_null_safe_equals(spark_session): - """Test null-safe equality comparison.""" - data = [(1, 10), (2, None)] - df = 
spark_session.createDataFrame(data, ["id", "value"]) - - result = df.withColumn("null_safe_equals", F.col("value").eqNullSafe(F.lit(10))).collect() - - assert result[0].null_safe_equals is True # 10 <=> 10 - assert result[1].null_safe_equals is False # NULL <=> 10 - - -def test_not(spark_session): - """Test logical NOT operation with NULL handling.""" - data = [(True,), (False,), (None,)] - df = spark_session.createDataFrame(data, ["value"]) - - result = df.withColumn("not_value", ~F.col("value")).collect() - - assert result[0].not_value is False # NOT True - assert result[1].not_value is True # NOT False - assert result[2].not_value is None # NOT NULL diff --git a/tests/connect/test_with_column.py b/tests/connect/test_with_column.py deleted file mode 100644 index ad237339b2..0000000000 --- a/tests/connect/test_with_column.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from pyspark.sql.functions import col - - -def test_with_column(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) - - # Add a new column that's a boolean indicating if id > 2 - df_with_col = df.withColumn("double_id", col("id") > 2) - - # Verify the schema has both columns - assert "id" in df_with_col.schema.names, "Original column should still exist" - assert "double_id" in df_with_col.schema.names, "New column should be added" - - # Verify the data is correct - df_pandas = df_with_col.toPandas() - assert (df_pandas["double_id"] == (df_pandas["id"] > 2)).all(), "New column should be greater than 2 comparison" diff --git a/tests/connect/test_with_columns_renamed.py b/tests/connect/test_with_columns_renamed.py deleted file mode 100644 index 124f142ca2..0000000000 --- a/tests/connect/test_with_columns_renamed.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - - -def test_with_columns_renamed(spark_session): - # Test withColumnRenamed - df = spark_session.range(5) - renamed_df = df.withColumnRenamed("id", "number") - - collected = renamed_df.collect() - assert len(collected) == 5 - assert "number" in renamed_df.columns - assert "id" not in renamed_df.columns - assert [row["number"] for row in collected] == list(range(5)) - - # todo: this edge case is a spark connect bug; it will only send rename of id -> character over protobuf - # # Test withColumnsRenamed - # df = spark_session.range(2) - # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"}) - # - # collected = renamed_df.collect() - # assert len(collected) == 2 - # assert set(renamed_df.columns) == {"number", "character"} - # assert "id" not in renamed_df.columns - # assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)] diff --git a/tests/expressions/test_null_safe_equals.py b/tests/expressions/test_null_safe_equals.py new file mode 100644 index 0000000000..3f01f748d4 --- /dev/null +++ b/tests/expressions/test_null_safe_equals.py @@ -0,0 +1,402 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "data,value,expected_values", + [ + ([(1, 10), (2, None)], 10, [True, False]), # 10 <=> 10, NULL <=> 10 + ([(1, None), (2, None), (3, 10)], None, [True, True, False]), # NULL <=> NULL, NULL <=> NULL, 10 <=> NULL + ([(1, 10), (2, 20)], 10, [True, False]), # 10 <=> 10, 20 <=> 10 + ], +) +def test_null_safe_equals_basic(data, value, expected_values): + """Test basic null-safe equality comparison.""" + # Create a table with 
the test data + table = MicroPartition.from_pydict( + { + "id": [x[0] for x in data], + "value": [x[1] for x in data], + } + ) + + # Apply the null-safe equals operation + result = table.eval_expression_list([col("value").eq_null_safe(lit(value))]) + result_values = result.get_column("value").to_pylist() + + # Verify results + assert result_values == expected_values + + +@pytest.mark.parametrize( + "type_name,test_value,test_data,expected_values", + [ + ("int", 10, [(1, 10), (2, None), (3, 20)], [True, False, False]), + ("string", "hello", [(1, "hello"), (2, None), (3, "world")], [True, False, False]), + ("boolean", True, [(1, True), (2, None), (3, False)], [True, False, False]), + ("float", 1.5, [(1, 1.5), (2, None), (3, 2.5)], [True, False, False]), + ("binary", b"hello", [(1, b"hello"), (2, None), (3, b"world")], [True, False, False]), + ("fixed_size_binary", b"aaa", [(1, b"aaa"), (2, None), (3, b"bbb")], [True, False, False]), + ("null", None, [(1, 10), (2, None), (3, 20)], [False, True, False]), + ], +) +def test_null_safe_equals_types(type_name, test_value, test_data, expected_values): + """Test null-safe equality with different data types.""" + # Create a table with the test data + if type_name == "fixed_size_binary": + # Convert to PyArrow array for fixed size binary + value_array = pa.array([x[1] for x in test_data], type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "id": [x[0] for x in test_data], + "value": value_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "id": [x[0] for x in test_data], + "value": [x[1] for x in test_data], + } + ) + + # Apply the null-safe equals operation + result = table.eval_expression_list([col("value").eq_null_safe(lit(test_value))]) + result_values = result.get_column("value").to_pylist() + + # Verify results + assert result_values == expected_values, f"Failed for {type_name} comparison" + + +@pytest.mark.parametrize( + "data,expected_values", + [ + ([(10, 10), (None, None), (10, None), (None, 10), (10, 20)], [True, True, False, False, False]), + ( + [("hello", "hello"), (None, None), ("hello", None), (None, "hello"), ("hello", "world")], + [True, True, False, False, False], + ), + ([(True, True), (None, None), (True, None), (None, True), (True, False)], [True, True, False, False, False]), + ([(1.5, 1.5), (None, None), (1.5, None), (None, 1.5), (1.5, 2.5)], [True, True, False, False, False]), + ( + [(b"hello", b"hello"), (None, None), (b"hello", None), (None, b"hello"), (b"hello", b"world")], + [True, True, False, False, False], + ), + ( + [(b"aaa", b"aaa"), (None, None), (b"aaa", None), (None, b"aaa"), (b"aaa", b"bbb")], + [True, True, False, False, False], + ), + ], +) +def test_null_safe_equals_column_comparison(data, expected_values): + """Test null-safe equality between two columns.""" + # Check if this is a fixed-size binary test case + is_fixed_binary = isinstance(data[0][0], bytes) and len(data[0][0]) == 3 + + # Create a table with the test data + if is_fixed_binary: + # Convert to PyArrow arrays for fixed size binary + left_array = pa.array([x[0] for x in data], type=pa.binary(3)) + right_array = pa.array([x[1] for x in data], type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "left": left_array.to_pylist(), + "right": right_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "left": [x[0] for x in data], + "right": [x[1] for x in data], + } + ) + + # Apply the null-safe equals operation + result = 
table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + + # Verify results + assert result_values == expected_values + + +@pytest.mark.parametrize( + "filter_value,data,expected_ids", + [ + (10, [(1, 10), (2, None), (3, 20), (4, None)], {1}), # Only id=1 has value=10 + (None, [(1, 10), (2, None), (3, 20), (4, None)], {2, 4}), # id=2 and id=4 have NULL values + (20, [(1, 10), (2, None), (3, 20), (4, None)], {3}), # Only id=3 has value=20 + ], +) +def test_null_safe_equals_in_filter(filter_value, data, expected_ids): + """Test using null-safe equality in filter.""" + # Create a table with the test data + table = MicroPartition.from_pydict( + { + "id": [x[0] for x in data], + "value": [x[1] for x in data], + } + ) + + # Apply the filter with null-safe equals + result = table.filter([col("value").eq_null_safe(lit(filter_value))]) + result_ids = set(result.get_column("id").to_pylist()) + + # Verify results + assert result_ids == expected_ids + + +@pytest.mark.parametrize( + "operation,test_value,data,expected_values", + [ + ( + "NOT", + 10, + [(1, 10), (2, None), (3, 20)], + [False, True, True], # NOT (10 <=> 10), NOT (NULL <=> 10), NOT (20 <=> 10) + ), + ], +) +def test_null_safe_equals_chained_operations(operation, test_value, data, expected_values): + """Test chaining null-safe equality with other operations.""" + # Create a table with the test data + table = MicroPartition.from_pydict( + { + "id": [x[0] for x in data], + "value": [x[1] for x in data], + } + ) + + # Apply the operation + if operation == "NOT": + result = table.eval_expression_list([~col("value").eq_null_safe(lit(test_value))]) + + result_values = result.get_column("value").to_pylist() + + # Verify results + assert result_values == expected_values + + +def test_null_safe_equals_fixed_size_binary(): + """Test null-safe equality specifically for fixed-size binary arrays.""" + # Create arrays with fixed size binary data + l_arrow = pa.array([b"11111", b"22222", b"33333", None, b"12345", None], type=pa.binary(5)) + r_arrow = pa.array([b"11111", b"33333", b"11111", b"12345", None, None], type=pa.binary(5)) + + # Create table with these arrays + table = MicroPartition.from_pydict( + { + "left": l_arrow.to_pylist(), + "right": r_arrow.to_pylist(), + } + ) + + # Test column to column comparison + result = table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + assert result_values == [True, False, False, False, False, True] # True for equal values and both-null cases + + # Test column to scalar comparisons + test_cases = [ + (b"11111", [True, False, False, False, False, False]), # Matches first value + (b"22222", [False, True, False, False, False, False]), # Matches second value + (None, [False, False, False, True, False, True]), # True for all null values + ] + + for test_value, expected in test_cases: + result = table.eval_expression_list([col("left").eq_null_safe(lit(test_value))]) + result_values = result.get_column("left").to_pylist() + assert result_values == expected, f"Failed for test value: {test_value}" + + # Test the reverse comparison as well + result = table.eval_expression_list([col("right").eq_null_safe(lit(test_value))]) + result_values = result.get_column("right").to_pylist() + expected_reverse = [ + True if test_value == b"11111" else False, # First value is "11111" + False, # Second value is "33333" + True if test_value == b"11111" else False, # Third value is "11111" + True 
if test_value == b"12345" else False, # Fourth value is "12345" + True if test_value is None else False, # Fifth value is None + True if test_value is None else False, # Sixth value is None + ] + assert result_values == expected_reverse, f"Failed for reverse test value: {test_value}" + + +@pytest.mark.parametrize( + "type_name,left_data,right_data,expected_values", + [ + ("int", [1, 2, 3], [1, 2, 4], [True, True, False]), + ("float", [1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [True, True, False]), + ("boolean", [True, False, True], [True, False, False], [True, True, False]), + ("string", ["a", "b", "c"], ["a", "b", "d"], [True, True, False]), + ("binary", [b"a", b"b", b"c"], [b"a", b"b", b"d"], [True, True, False]), + ("fixed_size_binary", [b"aaa", b"bbb", b"ccc"], [b"aaa", b"bbb", b"ddd"], [True, True, False]), + ], +) +def test_no_nulls_all_types(type_name, left_data, right_data, expected_values): + """Test null-safe equality with no nulls in either array for all data types.""" + if type_name == "fixed_size_binary": + # Convert to PyArrow arrays for fixed size binary + left_array = pa.array(left_data, type=pa.binary(3)) + right_array = pa.array(right_data, type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "left": left_array.to_pylist(), + "right": right_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "left": left_data, + "right": right_data, + } + ) + + result = table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + + assert result_values == expected_values, f"Failed for {type_name} comparison" + + +@pytest.mark.parametrize( + "type_name,left_data,right_data,expected_values", + [ + ("int", [1, 2, 3], [1, None, 3], [True, False, True]), + ("float", [1.0, 2.0, 3.0], [1.0, None, 3.0], [True, False, True]), + ("boolean", [True, False, True], [True, None, True], [True, False, True]), + ("string", ["a", "b", "c"], ["a", None, "c"], [True, False, True]), + ("binary", [b"a", b"b", b"c"], [b"a", None, b"c"], [True, False, True]), + ("fixed_size_binary", [b"aaa", b"bbb", b"ccc"], [b"aaa", None, b"ccc"], [True, False, True]), + ], +) +def test_right_nulls_all_types(type_name, left_data, right_data, expected_values): + """Test null-safe equality where left array has no nulls and right array has some nulls.""" + if type_name == "fixed_size_binary": + # Convert to PyArrow arrays for fixed size binary + left_array = pa.array(left_data, type=pa.binary(3)) + right_array = pa.array(right_data, type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "left": left_array.to_pylist(), + "right": right_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "left": left_data, + "right": right_data, + } + ) + + result = table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + + assert result_values == expected_values, f"Failed for {type_name} comparison" + + +@pytest.mark.parametrize( + "type_name,left_data,right_data,expected_values", + [ + ("int", [1, None, 3], [1, 2, 3], [True, False, True]), + ("float", [1.0, None, 3.0], [1.0, 2.0, 3.0], [True, False, True]), + ("boolean", [True, None, True], [True, False, True], [True, False, True]), + ("string", ["a", None, "c"], ["a", "b", "c"], [True, False, True]), + ("binary", [b"a", None, b"c"], [b"a", b"b", b"c"], [True, False, True]), + ("fixed_size_binary", [b"aaa", None, b"ccc"], [b"aaa", b"bbb", b"ccc"], [True, False, True]), + ], +) +def 
test_left_nulls_all_types(type_name, left_data, right_data, expected_values): + """Test null-safe equality where left array has some nulls and right array has no nulls.""" + if type_name == "fixed_size_binary": + # Convert to PyArrow arrays for fixed size binary + left_array = pa.array(left_data, type=pa.binary(3)) + right_array = pa.array(right_data, type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "left": left_array.to_pylist(), + "right": right_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "left": left_data, + "right": right_data, + } + ) + + result = table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + + assert result_values == expected_values, f"Failed for {type_name} comparison" + + +@pytest.mark.parametrize( + "type_name,left_data,right_data", + [ + ("int", [1, 2, 3], [1, 2]), + ("float", [1.0, 2.0, 3.0], [1.0, 2.0]), + ("boolean", [True, False, True], [True, False]), + ("string", ["a", "b", "c"], ["a", "b"]), + ("binary", [b"a", b"b", b"c"], [b"a", b"b"]), + ("fixed_size_binary", [b"aaa", b"bbb", b"ccc"], [b"aaa", b"bbb"]), + ], +) +def test_length_mismatch_all_types(type_name, left_data, right_data): + """Test that length mismatches raise appropriate error for all data types.""" + # Create two separate tables + left_table = MicroPartition.from_pydict({"value": left_data}) + right_table = MicroPartition.from_pydict({"value": right_data}) + + with pytest.raises(ValueError) as exc_info: + result = left_table.eval_expression_list([col("value").eq_null_safe(right_table.get_column("value"))]) + # Force evaluation by accessing the result + result.get_column("value").to_pylist() + + # Verify error message format + error_msg = str(exc_info.value) + assert "trying to compare different length arrays" in error_msg + + +@pytest.mark.parametrize( + "type_name,left_data,right_data,expected_values", + [ + ("int", [None, None, 1], [None, None, 2], [True, True, False]), + ("float", [None, None, 1.0], [None, None, 2.0], [True, True, False]), + ("boolean", [None, None, True], [None, None, False], [True, True, False]), + ("string", [None, None, "a"], [None, None, "b"], [True, True, False]), + ("binary", [None, None, b"a"], [None, None, b"b"], [True, True, False]), + ("fixed_size_binary", [None, None, b"aaa"], [None, None, b"bbb"], [True, True, False]), + ], +) +def test_null_equals_null_all_types(type_name, left_data, right_data, expected_values): + """Test that NULL <=> NULL returns True for all data types.""" + if type_name == "fixed_size_binary": + # Convert to PyArrow arrays for fixed size binary + left_array = pa.array(left_data, type=pa.binary(3)) + right_array = pa.array(right_data, type=pa.binary(3)) + table = MicroPartition.from_pydict( + { + "left": left_array.to_pylist(), + "right": right_array.to_pylist(), + } + ) + else: + table = MicroPartition.from_pydict( + { + "left": left_data, + "right": right_data, + } + ) + + result = table.eval_expression_list([col("left").eq_null_safe(col("right"))]) + result_values = result.get_column("left").to_pylist() + + assert result_values == expected_values, f"Failed for {type_name} comparison" diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py index 1b9aeccc8e..ec9a1113e0 100644 --- a/tests/io/test_s3_credentials_refresh.py +++ b/tests/io/test_s3_credentials_refresh.py @@ -10,6 +10,8 @@ import pytest import daft +import daft.context +from tests.conftest import get_tests_daft_runner_name from 
tests.io.mock_aws_server import start_service, stop_process @@ -72,7 +74,7 @@ def get_credentials(): key_id=aws_credentials["AWS_ACCESS_KEY_ID"], access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"], session_token=aws_credentials["AWS_SESSION_TOKEN"], - expiry=(datetime.datetime.now() + datetime.timedelta(seconds=1)), + expiry=(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=1)), ) static_config = daft.io.IOConfig( @@ -113,7 +115,26 @@ def get_credentials(): assert count_get_credentials == 2 df.write_parquet(output_file_path, io_config=dynamic_config) - assert count_get_credentials == 2 + + is_ray_runner = ( + get_tests_daft_runner_name() == "ray" + ) # hack because ray runner will not increment `count_get_credentials` + assert count_get_credentials == 2 or is_ray_runner + + df2 = daft.read_parquet(output_file_path, io_config=static_config) + + assert df.to_arrow() == df2.to_arrow() + + df.write_parquet(output_file_path, io_config=dynamic_config, write_mode="overwrite") + assert count_get_credentials == 2 or is_ray_runner + + df2 = daft.read_parquet(output_file_path, io_config=static_config) + + assert df.to_arrow() == df2.to_arrow() + + time.sleep(1) + df.write_parquet(output_file_path, io_config=dynamic_config, write_mode="overwrite") + assert count_get_credentials == 3 or is_ray_runner df2 = daft.read_parquet(output_file_path, io_config=static_config) diff --git a/tests/io/test_write_modes.py b/tests/io/test_write_modes.py index 40a224f2fa..4fb188ff5c 100644 --- a/tests/io/test_write_modes.py +++ b/tests/io/test_write_modes.py @@ -7,6 +7,20 @@ import daft +@pytest.fixture(scope="function") +def bucket(minio_io_config): + BUCKET = "write-modes-bucket" + + fs = s3fs.S3FileSystem( + key=minio_io_config.s3.key_id, + password=minio_io_config.s3.access_key, + client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url}, + ) + if not fs.exists(BUCKET): + fs.mkdir(BUCKET) + yield BUCKET + + def write( df: daft.DataFrame, path: str, @@ -56,12 +70,14 @@ def arrange_write_mode_test(existing_data, new_data, path, format, write_mode, p return read_back -@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) -@pytest.mark.parametrize("format", ["csv", "parquet"]) -@pytest.mark.parametrize("num_partitions", [1, 2]) -@pytest.mark.parametrize("partition_cols", [None, ["a"]]) -def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partition_cols): - path = str(tmp_path) +def _run_append_overwrite_test( + path, + write_mode, + format, + num_partitions, + partition_cols, + io_config, +): existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]} new_data = { "a": ["a", "a", "b", "b"], @@ -75,7 +91,7 @@ def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partiti format, write_mode, partition_cols, - None, + io_config, ) # Check the data @@ -91,8 +107,48 @@ def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partiti @pytest.mark.parametrize("write_mode", ["append", "overwrite"]) @pytest.mark.parametrize("format", ["csv", "parquet"]) -def test_write_modes_local_empty_data(tmp_path, write_mode, format): - path = str(tmp_path) +@pytest.mark.parametrize("num_partitions", [1, 2]) +@pytest.mark.parametrize("partition_cols", [None, ["a"]]) +def test_append_and_overwrite_local(tmp_path, write_mode, format, num_partitions, partition_cols): + _run_append_overwrite_test( + path=str(tmp_path), + write_mode=write_mode, + format=format, + num_partitions=num_partitions, + partition_cols=partition_cols, + io_config=None, + ) 
+ + +@pytest.mark.integration() +@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) +@pytest.mark.parametrize("format", ["csv", "parquet"]) +@pytest.mark.parametrize("num_partitions", [1, 2]) +@pytest.mark.parametrize("partition_cols", [None, ["a"]]) +def test_append_and_overwrite_s3_minio( + minio_io_config, + bucket, + write_mode, + format, + num_partitions, + partition_cols, +): + _run_append_overwrite_test( + path=f"s3://{bucket}/{uuid.uuid4()!s}", + write_mode=write_mode, + format=format, + num_partitions=num_partitions, + partition_cols=partition_cols, + io_config=minio_io_config, + ) + + +def _run_write_modes_empty_test( + path, + write_mode, + format, + io_config, +): existing_data = {"a": ["a", "a", "b", "b"], "b": ["c", "d", "e", "f"]} new_data = { "a": ["a", "a", "b", "b"], @@ -106,7 +162,7 @@ def test_write_modes_local_empty_data(tmp_path, write_mode, format): format, write_mode, None, - None, + io_config, ) # Check the data @@ -122,95 +178,121 @@ def test_write_modes_local_empty_data(tmp_path, write_mode, format): raise ValueError(f"Unsupported write_mode: {write_mode}") -@pytest.fixture(scope="function") -def bucket(minio_io_config): - BUCKET = "write-modes-bucket" - - fs = s3fs.S3FileSystem( - key=minio_io_config.s3.key_id, - password=minio_io_config.s3.access_key, - client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url}, +@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) +@pytest.mark.parametrize("format", ["csv", "parquet"]) +def test_write_modes_local_empty_data(tmp_path, write_mode, format): + _run_write_modes_empty_test( + path=str(tmp_path), + write_mode=write_mode, + format=format, + io_config=None, ) - if not fs.exists(BUCKET): - fs.mkdir(BUCKET) - yield BUCKET @pytest.mark.integration() @pytest.mark.parametrize("write_mode", ["append", "overwrite"]) @pytest.mark.parametrize("format", ["csv", "parquet"]) -@pytest.mark.parametrize("num_partitions", [1, 2]) -@pytest.mark.parametrize("partition_cols", [None, ["a"]]) -def test_write_modes_s3_minio( +def test_write_modes_s3_minio_empty_data( minio_io_config, bucket, write_mode, format, - num_partitions, - partition_cols, ): - path = f"s3://{bucket}/{uuid.uuid4()!s}" + _run_write_modes_empty_test( + path=f"s3://{bucket}/{uuid.uuid4()!s}", + write_mode=write_mode, + format=format, + io_config=minio_io_config, + ) + + +OVERWRITE_PARTITION_TEST_CASES = [ + pytest.param( + { + "a": ["a", "a", "b", "b"], + "b": [5, 6, 7, 8], + }, + { + "a": ["a", "a", "b", "b"], + "b": [5, 6, 7, 8], + }, + id="overwrite-all", + ), + pytest.param( + { + "a": ["a", "a"], + "b": [5, 6], + }, + { + "a": ["a", "a", "b", "b"], + "b": [5, 6, 3, 4], + }, + id="overwrite-some", + ), + pytest.param( + { + "a": ["b", "b", "c", "c"], + "b": [9, 10, 11, 12], + }, + { + "a": ["a", "a", "b", "b", "c", "c"], + "b": [1, 2, 9, 10, 11, 12], + }, + id="overwrite-and-append", + ), +] + + +def _run_overwrite_partitions_test( + path, + format, + new_data, + expected_read_back, + io_config, +): existing_data = {"a": ["a", "a", "b", "b"], "b": [1, 2, 3, 4]} - new_data = { - "a": ["a", "a", "b", "b"], - "b": [5, 6, 7, 8], - } read_back = arrange_write_mode_test( - daft.from_pydict(existing_data).into_partitions(num_partitions), - daft.from_pydict(new_data).into_partitions(num_partitions), + daft.from_pydict(existing_data), + daft.from_pydict(new_data), path, format, - write_mode, - partition_cols, - minio_io_config, + "overwrite-partitions", + ["a"], + io_config, ) # Check the data - if write_mode == "append": - assert read_back["a"] 
== ["a"] * 4 + ["b"] * 4 - assert read_back["b"] == [1, 2, 5, 6, 3, 4, 7, 8] - elif write_mode == "overwrite": - assert read_back["a"] == ["a", "a", "b", "b"] - assert read_back["b"] == [5, 6, 7, 8] - else: - raise ValueError(f"Unsupported write_mode: {write_mode}") + for col in expected_read_back: + assert read_back[col] == expected_read_back[col] + + +@pytest.mark.parametrize("format", ["csv", "parquet"]) +@pytest.mark.parametrize("new_data, expected_read_back", OVERWRITE_PARTITION_TEST_CASES) +def test_overwrite_partitions_local(tmp_path, format, new_data, expected_read_back): + _run_overwrite_partitions_test( + path=str(tmp_path), + format=format, + new_data=new_data, + expected_read_back=expected_read_back, + io_config=None, + ) @pytest.mark.integration() -@pytest.mark.parametrize("write_mode", ["append", "overwrite"]) @pytest.mark.parametrize("format", ["csv", "parquet"]) -def test_write_modes_s3_minio_empty_data( +@pytest.mark.parametrize("new_data, expected_read_back", OVERWRITE_PARTITION_TEST_CASES) +def test_overwrite_partitions_s3_minio( minio_io_config, bucket, - write_mode, format, + new_data, + expected_read_back, ): - path = f"s3://{bucket}/{uuid.uuid4()!s}" - existing_data = {"a": ["a", "a", "b", "b"], "b": ["c", "d", "e", "f"]} - new_data = { - "a": ["a", "a", "b", "b"], - "b": ["g", "h", "i", "j"], - } - - read_back = arrange_write_mode_test( - daft.from_pydict(existing_data), - daft.from_pydict(new_data).where(daft.lit(False)), # Empty data - path, - format, - write_mode, - None, - minio_io_config, + _run_overwrite_partitions_test( + path=f"s3://{bucket}/{uuid.uuid4()!s}", + format=format, + new_data=new_data, + expected_read_back=expected_read_back, + io_config=minio_io_config, ) - - # Check the data - if write_mode == "append": - # The data should be the same as the existing data - assert read_back["a"] == ["a", "a", "b", "b"] - assert read_back["b"] == ["c", "d", "e", "f"] - elif write_mode == "overwrite": - # The data should be empty because we are overwriting the existing data - assert read_back["a"] == [] - assert read_back["b"] == [] - else: - raise ValueError(f"Unsupported write_mode: {write_mode}") diff --git a/tests/sql/test_binary_op_exprs.py b/tests/sql/test_binary_op_exprs.py index cfc47efb44..c4a16507c3 100644 --- a/tests/sql/test_binary_op_exprs.py +++ b/tests/sql/test_binary_op_exprs.py @@ -75,20 +75,20 @@ def test_unsupported_div_floor(): _assert_df_op_raise( lambda: df.select(daft.col("A") // daft.col("C")).collect(), - "TypeError Cannot perform floor divide on types: Int64, Boolean", + "Cannot perform floor divide on types: Int64, Boolean", ) _assert_df_op_raise( lambda: df.select(daft.col("C") // daft.col("A")).collect(), - "TypeError Cannot perform floor divide on types: Boolean, Int64", + "Cannot perform floor divide on types: Boolean, Int64", ) _assert_df_op_raise( lambda: df.select(daft.col("B") // daft.col("C")).collect(), - "TypeError Cannot perform floor divide on types: Float64, Boolean", + "Cannot perform floor divide on types: Float64, Boolean", ) _assert_df_op_raise( lambda: df.select(daft.col("C") // daft.col("B")).collect(), - "TypeError Cannot perform floor divide on types: Boolean, Float64", + "Cannot perform floor divide on types: Boolean, Float64", ) diff --git a/tests/sql/test_sql_null_safe_equals.py b/tests/sql/test_sql_null_safe_equals.py new file mode 100644 index 0000000000..973e5c5d48 --- /dev/null +++ b/tests/sql/test_sql_null_safe_equals.py @@ -0,0 +1,106 @@ +import pytest + +import daft +from daft.sql import SQLCatalog + + 
+@pytest.mark.parametrize( + "query,expected", + [ + ("SELECT * FROM df1 WHERE val <=> 20", {"id": [2], "val": [20]}), + ("SELECT * FROM df1 WHERE val <=> NULL", {"id": [3], "val": [None]}), + ( + "SELECT df1.id, df1.val, df2.score FROM df1 JOIN df2 ON df1.id <=> df2.id", + {"id": [1, 2, None], "val": [10, 20, 40], "score": [0.1, 0.2, 0.3]}, + ), + ( + "SELECT * FROM df1 WHERE val <=> 10 OR val <=> NULL", + {"id": [1, 3], "val": [10, None]}, # Matches both 10 and NULL values + ), + ], +) +def test_null_safe_equals_basic(query, expected): + """Test basic null-safe equality operator (<=>).""" + df1 = daft.from_pydict({"id": [1, 2, 3, None], "val": [10, 20, None, 40]}) + df2 = daft.from_pydict({"id": [1, 2, None, 4], "score": [0.1, 0.2, 0.3, 0.4]}) + + catalog = SQLCatalog({"df1": df1, "df2": df2}) + result = daft.sql(query, catalog).to_pydict() + assert result == expected + + +@pytest.mark.parametrize( + "query,expected", + [ + ("SELECT * FROM df WHERE NOT (val <=> NULL)", {"id": [1, 3, None], "val": [10, 30, 40]}), + ("SELECT * FROM df WHERE val <=> 10 OR id > 2", {"id": [1, 3], "val": [10, 30]}), + ( + "SELECT *, CASE WHEN val <=> NULL THEN 'is_null' ELSE 'not_null' END as val_status FROM df", + { + "id": [1, 2, 3, None, None], + "val": [10, None, 30, 40, None], + "val_status": ["not_null", "is_null", "not_null", "not_null", "is_null"], + }, + ), + ], +) +def test_null_safe_equals_complex(query, expected): + """Test complex expressions using null-safe equality.""" + df = daft.from_pydict({"id": [1, 2, 3, None, None], "val": [10, None, 30, 40, None]}) + + catalog = SQLCatalog({"df": df}) + result = daft.sql(query, catalog).to_pydict() + assert result == expected + + +@pytest.mark.parametrize( + "query,expected", + [ + ( + "SELECT * FROM df WHERE int_val <=> 1", + {"int_val": [1], "str_val": ["a"], "bool_val": [True], "float_val": [1.1]}, + ), + ( + "SELECT * FROM df WHERE str_val <=> 'c'", + {"int_val": [3], "str_val": ["c"], "bool_val": [False], "float_val": [3.3]}, + ), + ( + "SELECT * FROM df WHERE bool_val <=> false", + {"int_val": [3], "str_val": ["c"], "bool_val": [False], "float_val": [3.3]}, + ), + ( + "SELECT * FROM df WHERE float_val <=> 1.1", + {"int_val": [1], "str_val": ["a"], "bool_val": [True], "float_val": [1.1]}, + ), + ( + "SELECT * FROM df WHERE int_val <=> NULL", + {"int_val": [None], "str_val": [None], "bool_val": [None], "float_val": [None]}, + ), + ( + "SELECT * FROM df WHERE str_val <=> NULL", + {"int_val": [None], "str_val": [None], "bool_val": [None], "float_val": [None]}, + ), + ( + "SELECT * FROM df WHERE bool_val <=> NULL", + {"int_val": [None], "str_val": [None], "bool_val": [None], "float_val": [None]}, + ), + ( + "SELECT * FROM df WHERE float_val <=> NULL", + {"int_val": [None], "str_val": [None], "bool_val": [None], "float_val": [None]}, + ), + ], +) +def test_null_safe_equals_types(query, expected): + """Test null-safe equality with different data types.""" + df = daft.from_pydict( + { + "int_val": [1, None, 3], + "str_val": ["a", None, "c"], + "bool_val": [True, None, False], + "float_val": [1.1, None, 3.3], + } + ) + + catalog = SQLCatalog({"df": df}) + result = daft.sql(query, catalog).to_pydict() + assert result == expected diff --git a/tests/sql/test_table_functions/__init__.py b/tests/sql/test_table_functions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sql/test_table_functions/test_read_iceberg.py b/tests/sql/test_table_functions/test_read_iceberg.py new file mode 100644 index 0000000000..94e9ab43fb --- /dev/null 
+++ b/tests/sql/test_table_functions/test_read_iceberg.py @@ -0,0 +1,20 @@ +import pytest + +import daft + + +@pytest.mark.skip( + "invoke manually via `uv run tests/sql/test_table_functions/test_read_iceberg.py <metadata_location>`" +) +def test_read_iceberg(metadata_location): + df = daft.sql(f"SELECT * FROM read_iceberg('{metadata_location}')") + print(df.collect()) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("usage: test_read_iceberg.py <metadata_location>") + sys.exit(1) + test_read_iceberg(metadata_location=sys.argv[1]) diff --git a/tests/sql/test_table_funcs.py b/tests/sql/test_table_functions/test_table_functions.py similarity index 70% rename from tests/sql/test_table_funcs.py rename to tests/sql/test_table_functions/test_table_functions.py index 5b765dc0d1..ad0851fb10 100644 --- a/tests/sql/test_table_funcs.py +++ b/tests/sql/test_table_functions/test_table_functions.py @@ -13,18 +13,42 @@ def sample_schema(): return {"a": daft.DataType.float32(), "b": daft.DataType.string()} +def test_sql_read_json(): + df = daft.sql("SELECT * FROM read_json('tests/assets/json-data/small.jsonl')").collect() + expected = daft.read_json("tests/assets/json-data/small.jsonl").collect() + assert df.to_pydict() == expected.to_pydict() + + +def test_sql_read_json_path(): + df = daft.sql("SELECT * FROM 'tests/assets/json-data/small.jsonl'").collect() + expected = daft.read_json("tests/assets/json-data/small.jsonl").collect() + assert df.to_pydict() == expected.to_pydict() + + def test_sql_read_parquet(): df = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect() expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect() assert df.to_pydict() == expected.to_pydict() +def test_sql_read_parquet_path(): + df = daft.sql("SELECT * FROM 'tests/assets/parquet-data/mvp.parquet'").collect() + expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect() + assert df.to_pydict() == expected.to_pydict() + + def test_sql_read_csv(sample_csv_path): df = daft.sql(f"SELECT * FROM read_csv('{sample_csv_path}')").collect() expected = daft.read_csv(sample_csv_path).collect() assert df.to_pydict() == expected.to_pydict() +def test_sql_read_csv_path(sample_csv_path): + df = daft.sql(f"SELECT * FROM '{sample_csv_path}'").collect() + expected = daft.read_csv(sample_csv_path).collect() + assert df.to_pydict() == expected.to_pydict() + + @pytest.mark.parametrize("has_headers", [True, False]) def test_read_csv_headers(sample_csv_path, has_headers): df1 = daft.read_csv(sample_csv_path, has_headers=has_headers) diff --git a/tests/sql/test_uri_exprs.py b/tests/sql/test_uri_exprs.py new file mode 100644 index 0000000000..089635ab0b --- /dev/null +++ b/tests/sql/test_uri_exprs.py @@ -0,0 +1,96 @@ +import os +import tempfile + +import daft +from daft import col, lit + + +def test_url_download(): + df = daft.from_pydict({"one": [1]}) # just have a single row, doesn't matter what it is + url = "https://raw.githubusercontent.com/Eventual-Inc/Daft/refs/heads/main/LICENSE" + + # download one + df_actual = daft.sql(f"SELECT url_download('{url}') as downloaded FROM df").collect().to_pydict() + df_expect = df.select(lit(url).url.download().alias("downloaded")).collect().to_pydict() + + assert df_actual == df_expect + + +def test_url_download_multi(): + df = daft.from_pydict( + { + "urls": [ + "https://raw.githubusercontent.com/Eventual-Inc/Daft/refs/heads/main/README.rst", + "https://raw.githubusercontent.com/Eventual-Inc/Daft/refs/heads/main/LICENSE", + ] + } + ) + 
actual = ( + daft.sql( + """ + SELECT + url_download(urls) as downloaded, + url_download(urls, max_connections=>1) as downloaded_single_conn, + url_download(urls, on_error=>'null') as downloaded_ignore_errors + FROM df + """ + ) + .collect() + .to_pydict() + ) + + expected = ( + df.select( + col("urls").url.download().alias("downloaded"), + col("urls").url.download(max_connections=1).alias("downloaded_single_conn"), + col("urls").url.download(on_error="null").alias("downloaded_ignore_errors"), + ) + .collect() + .to_pydict() + ) + + assert actual == expected + + +def test_url_upload(): + with tempfile.TemporaryDirectory() as tmp_dir: + df = daft.from_pydict( + { + "data": [b"test1", b"test2"], + "paths": [ + os.path.join(tmp_dir, "test1.txt"), + os.path.join(tmp_dir, "test2.txt"), + ], + } + ) + + actual = ( + daft.sql( + """ + SELECT + url_upload(data, paths) as uploaded, + url_upload(data, paths, max_connections=>1) as uploaded_single_conn, + url_upload(data, paths, on_error=>'null') as uploaded_ignore_errors + FROM df + """ + ) + .collect() + .to_pydict() + ) + + expected = ( + df.select( + col("data").url.upload(daft.col("paths")).alias("uploaded"), + col("data").url.upload(daft.col("paths"), max_connections=1).alias("uploaded_single_conn"), + col("data").url.upload(daft.col("paths"), on_error="null").alias("uploaded_ignore_errors"), + ) + .collect() + .to_pydict() + ) + + assert actual == expected + + # Verify files were created + assert os.path.exists(os.path.join(tmp_dir, "test1.txt")) + assert os.path.exists(os.path.join(tmp_dir, "test2.txt")) diff --git a/tests/table/binary/test_concat.py b/tests/table/binary/test_concat.py new file mode 100644 index 0000000000..cc603d5d39 --- /dev/null +++ b/tests/table/binary/test_concat.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "input_a,input_b,expected_result", + [ + # Basic ASCII concatenation + ( + [b"Hello", b"Test", b"", b"End"], + [b" World", b"ing", b"Empty", b"!"], + [b"Hello World", b"Testing", b"Empty", b"End!"], + ), + # Special binary sequences + ( + [ + b"\x00\x01", # Null and control chars + b"\xff\xfe", # High-value bytes + b"Hello\x00", # String with null + b"\xe2\x98", # Partial UTF-8 + b"\xf0\x9f\x98", # Another partial UTF-8 + ], + [ + b"\x02\x03", # More control chars + b"\xfd\xfc", # More high-value bytes + b"\x00World", # Null and string + b"\x83", # Complete the UTF-8 snowman + b"\x89", # Complete the UTF-8 winking face + ], + [ + b"\x00\x01\x02\x03", # Concatenated control chars + b"\xff\xfe\xfd\xfc", # Concatenated high-value bytes + b"Hello\x00\x00World", # String with multiple nulls + b"\xe2\x98\x83", # Complete UTF-8 snowman (☃) + b"\xf0\x9f\x98\x89", # Complete UTF-8 winking face (😉) + ], + ), + # Nulls and empty strings + ( + [b"Hello", None, b"", b"Test", None, b"End", b""], + [b" World", b"!", None, None, b"ing", b"", b"Empty"], + [b"Hello World", None, None, None, None, b"End", b"Empty"], + ), + # Mixed length concatenation + ( + [b"a", b"ab", b"abc", b"abcd"], + [b"1", b"12", b"123", b"1234"], + [b"a1", b"ab12", b"abc123", b"abcd1234"], + ), + # Empty string combinations + ( + [b"", b"", b"Hello", b"World", b""], + [b"", b"Test", b"", b"", b"!"], + [b"", b"Test", b"Hello", b"World", b"!"], + ), + # Complex UTF-8 sequences + ( + [ + b"\xe2\x98\x83", # Snowman + b"\xf0\x9f\x98\x89", # Winking face + b"\xf0\x9f\x8c\x88", # Rainbow + b"\xe2\x98\x83\xf0\x9f\x98\x89", # 
Snowman + Winking face + ], + [ + b"\xf0\x9f\x98\x89", # Winking face + b"\xe2\x98\x83", # Snowman + b"\xe2\x98\x83", # Snowman + b"\xf0\x9f\x8c\x88", # Rainbow + ], + [ + b"\xe2\x98\x83\xf0\x9f\x98\x89", # Snowman + Winking face + b"\xf0\x9f\x98\x89\xe2\x98\x83", # Winking face + Snowman + b"\xf0\x9f\x8c\x88\xe2\x98\x83", # Rainbow + Snowman + b"\xe2\x98\x83\xf0\x9f\x98\x89\xf0\x9f\x8c\x88", # Snowman + Winking face + Rainbow + ], + ), + # Zero bytes in different positions + ( + [ + b"\x00abc", # Leading zero + b"abc\x00", # Trailing zero + b"ab\x00c", # Middle zero + b"\x00ab\x00c\x00", # Multiple zeros + ], + [ + b"def\x00", # Trailing zero + b"\x00def", # Leading zero + b"d\x00ef", # Middle zero + b"\x00de\x00f\x00", # Multiple zeros + ], + [ + b"\x00abcdef\x00", # Zeros at ends + b"abc\x00\x00def", # Adjacent zeros + b"ab\x00cd\x00ef", # Separated zeros + b"\x00ab\x00c\x00\x00de\x00f\x00", # Many zeros + ], + ), + ], +) +def test_binary_concat( + input_a: list[bytes | None], input_b: list[bytes | None], expected_result: list[bytes | None] +) -> None: + table = MicroPartition.from_pydict({"a": input_a, "b": input_b}) + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == {"a": expected_result} + + +@pytest.mark.parametrize( + "input_data,literal,expected_result", + [ + # Basic broadcasting + ( + [b"Hello", b"Goodbye", b"Test"], + b" World!", + [b"Hello World!", b"Goodbye World!", b"Test World!"], + ), + # Broadcasting with nulls + ( + [b"Hello", None, b"Test"], + b" World!", + [b"Hello World!", None, b"Test World!"], + ), + # Broadcasting with special sequences + ( + [b"\x00\x01", b"\xff\xfe", b"Hello\x00"], + b"\x02\x03", + [b"\x00\x01\x02\x03", b"\xff\xfe\x02\x03", b"Hello\x00\x02\x03"], + ), + # Broadcasting with empty strings + ( + [b"", b"Test", b""], + b"\xff\xfe", + [b"\xff\xfe", b"Test\xff\xfe", b"\xff\xfe"], + ), + # Broadcasting with UTF-8 + ( + [b"Hello", b"Test", b"Goodbye"], + b"\xe2\x98\x83", # Snowman + [b"Hello\xe2\x98\x83", b"Test\xe2\x98\x83", b"Goodbye\xe2\x98\x83"], + ), + # Broadcasting with zero bytes + ( + [b"Hello", b"Test\x00", b"\x00World"], + b"\x00", + [b"Hello\x00", b"Test\x00\x00", b"\x00World\x00"], + ), + # Broadcasting with literal None + ( + [b"Hello", None, b"Test", b""], + None, + [None, None, None, None], # Any concat with None should result in None + ), + ], +) +def test_binary_concat_broadcast( + input_data: list[bytes | None], literal: bytes | None, expected_result: list[bytes | None] +) -> None: + # Test right-side broadcasting + table = MicroPartition.from_pydict({"a": input_data}) + result = table.eval_expression_list([col("a").binary.concat(literal)]) + assert result.to_pydict() == {"a": expected_result} + + # Test left-side broadcasting + table = MicroPartition.from_pydict({"b": input_data}) + result = table.eval_expression_list([lit(literal).binary.concat(col("b"))]) + if literal is None: + # When literal is None, all results should be None + assert result.to_pydict() == {"literal": [None] * len(input_data)} + else: + assert result.to_pydict() == { + "literal": [ + lit + data if data is not None else None for lit, data in zip([literal] * len(input_data), input_data) + ] + } + + +def test_binary_concat_edge_cases() -> None: + # Test various edge cases + table = MicroPartition.from_pydict( + { + "a": [ + b"", # Empty string + b"\x00", # Single null byte + b"\xff", # Single high byte + b"Hello", # Normal string + None, # Null value + b"\xe2\x98\x83", # UTF-8 sequence + b"\xf0\x9f\x98\x89", # 
Another UTF-8 sequence + b"\x80\x81\x82", # Binary sequence + b"\xff\xff\xff", # High bytes + ], + "b": [ + b"", # Empty + Empty + b"\x00", # Null + Null + b"\x00", # High + Null + b"", # Normal + Empty + None, # Null + Null + b"\xf0\x9f\x98\x89", # UTF-8 + UTF-8 + b"\xe2\x98\x83", # UTF-8 + UTF-8 + b"\x83\x84\x85", # Binary + Binary + b"\xfe\xfe\xfe", # High bytes + High bytes + ], + } + ) + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == { + "a": [ + b"", # Empty + Empty = Empty + b"\x00\x00", # Null + Null = Two nulls + b"\xff\x00", # High + Null = High then null + b"Hello", # Normal + Empty = Normal + None, # Null + Null = Null + b"\xe2\x98\x83\xf0\x9f\x98\x89", # Snowman + Winking face + b"\xf0\x9f\x98\x89\xe2\x98\x83", # Winking face + Snowman + b"\x80\x81\x82\x83\x84\x85", # Binary sequence concatenation + b"\xff\xff\xff\xfe\xfe\xfe", # High bytes concatenation + ] + } + + +def test_binary_concat_errors() -> None: + # Test concat with incompatible type (string) + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": ["foo", "bar"]}) + with pytest.raises(Exception, match="Expects inputs to concat to be binary, but received a#Binary and b#Utf8"): + table.eval_expression_list([col("a").binary.concat(col("b"))]) + + # Test concat with incompatible type (integer) + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": [1, 2]}) + with pytest.raises(Exception, match="Expects inputs to concat to be binary, but received a#Binary and b#Int64"): + table.eval_expression_list([col("a").binary.concat(col("b"))]) + + # Test concat with incompatible type (float) + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": [1.0, 2.0]}) + with pytest.raises(Exception, match="Expects inputs to concat to be binary, but received a#Binary and b#Float64"): + table.eval_expression_list([col("a").binary.concat(col("b"))]) + + # Test concat with incompatible type (boolean) + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": [True, False]}) + with pytest.raises(Exception, match="Expects inputs to concat to be binary, but received a#Binary and b#Boolean"): + table.eval_expression_list([col("a").binary.concat(col("b"))]) + + # Test concat with wrong number of arguments + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": [b"foo", b"bar"], "c": [b"test", b"data"]}) + with pytest.raises( + Exception, match="(?:ExpressionBinaryNamespace.)?concat\\(\\) takes 2 positional arguments but 3 were given" + ): + table.eval_expression_list([col("a").binary.concat(col("b"), col("c"))]) + + # Test concat with no arguments + table = MicroPartition.from_pydict({"a": [b"hello", b"world"]}) + with pytest.raises( + Exception, match="(?:ExpressionBinaryNamespace.)?concat\\(\\) missing 1 required positional argument: 'other'" + ): + table.eval_expression_list([col("a").binary.concat()]) diff --git a/tests/table/binary/test_fixed_size_binary_concat.py b/tests/table/binary/test_fixed_size_binary_concat.py new file mode 100644 index 0000000000..9d87a5de4e --- /dev/null +++ b/tests/table/binary/test_fixed_size_binary_concat.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +from daft import DataType +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "input_data1,input_data2,expected_result,size1,size2", + [ + # Basic concatenation + ([b"abc", b"def", b"ghi"], [b"xyz", b"uvw", b"rst"], [b"abcxyz", 
b"defuvw", b"ghirst"], 3, 3), + # With nulls + ([b"ab", None, b"cd"], [b"xy", b"uv", None], [b"abxy", None, None], 2, 2), + # Special sequences + ( + [b"\x00\x01", b"\xff\xfe", b"\xe2\x98"], + [b"\x99\x00", b"\x01\xff", b"\xfe\xe2"], + [b"\x00\x01\x99\x00", b"\xff\xfe\x01\xff", b"\xe2\x98\xfe\xe2"], + 2, + 2, + ), + # Complex UTF-8 sequences + ( + [b"\xe2\x98", b"\xf0\x9f", b"\xf0\x9f"], # Partial UTF-8 sequences + [b"\x83\x00", b"\x98\x89", b"\x8c\x88"], # Complete the sequences + [b"\xe2\x98\x83\x00", b"\xf0\x9f\x98\x89", b"\xf0\x9f\x8c\x88"], # Complete UTF-8 characters + 2, + 2, + ), + # Mixed length concatenation + ( + [b"a", b"b", b"c", b"d"], # Single bytes + [b"12", b"34", b"56", b"78"], # Two bytes + [b"a12", b"b34", b"c56", b"d78"], # Three bytes result + 1, + 2, + ), + # Zero bytes in different positions + ( + [b"\x00a", b"a\x00", b"\x00\x00"], # Zeros in different positions + [b"b\x00", b"\x00b", b"c\x00"], # More zeros + [b"\x00ab\x00", b"a\x00\x00b", b"\x00\x00c\x00"], # Combined zeros + 2, + 2, + ), + ], +) +def test_fixed_size_binary_concat( + input_data1: list[bytes | None], + input_data2: list[bytes | None], + expected_result: list[bytes | None], + size1: int, + size2: int, +) -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data1, type=pa.binary(size1)), + "b": pa.array(input_data2, type=pa.binary(size2)), + } + ) + # Verify inputs are FixedSizeBinary before concatenating + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size1) + assert table.schema()["b"].dtype == DataType.fixed_size_binary(size2) + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == {"a": expected_result} + # Result should be FixedSizeBinary with combined size when both inputs are FixedSizeBinary + assert result.schema()["a"].dtype == DataType.fixed_size_binary(size1 + size2) + + +def test_fixed_size_binary_concat_large() -> None: + # Test concatenating large fixed size binary strings + size1, size2 = 100, 50 + large_binary1 = b"x" * size1 + large_binary2 = b"y" * size2 + + table = MicroPartition.from_pydict( + { + "a": pa.array([large_binary1, b"a" * size1, large_binary1], type=pa.binary(size1)), + "b": pa.array([large_binary2, large_binary2, b"b" * size2], type=pa.binary(size2)), + } + ) + + # Verify inputs are FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size1) + assert table.schema()["b"].dtype == DataType.fixed_size_binary(size2) + + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == { + "a": [ + b"x" * size1 + b"y" * size2, # Large + Large + b"a" * size1 + b"y" * size2, # Small repeated + Large + b"x" * size1 + b"b" * size2, # Large + Small repeated + ] + } + # Result should be FixedSizeBinary with combined size + assert result.schema()["a"].dtype == DataType.fixed_size_binary(size1 + size2) + + +@pytest.mark.parametrize( + "input_data,literal,expected_result,size", + [ + # Basic broadcasting with fixed size + ([b"abc", b"def", b"ghi"], b"xyz", [b"abcxyz", b"defxyz", b"ghixyz"], 3), + # Broadcasting with nulls + ([b"ab", None, b"cd"], b"xy", [b"abxy", None, b"cdxy"], 2), + # Broadcasting with special sequences + ( + [b"\x00\x01", b"\xff\xfe", b"\xe2\x98"], + b"\x99\x00", + [b"\x00\x01\x99\x00", b"\xff\xfe\x99\x00", b"\xe2\x98\x99\x00"], + 2, + ), + # Broadcasting with UTF-8 + ( + [b"\xe2\x98", b"\xf0\x9f", b"\xf0\x9f"], + b"\x83\x00", + [b"\xe2\x98\x83\x00", b"\xf0\x9f\x83\x00", b"\xf0\x9f\x83\x00"], + 2, + ), + ], +) 
+def test_fixed_size_binary_concat_broadcast( + input_data: list[bytes | None], + literal: bytes | None, + expected_result: list[bytes | None], + size: int, +) -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + } + ) + # Verify input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + + # Test right-side broadcasting + result = table.eval_expression_list([col("a").binary.concat(lit(literal))]) + assert result.to_pydict() == {"a": expected_result} + # Result should be Binary when using literals + assert result.schema()["a"].dtype == DataType.binary() + + # Test left-side broadcasting + result = table.eval_expression_list([lit(literal).binary.concat(col("a"))]) + assert result.to_pydict() == {"literal": [literal + data if data is not None else None for data in input_data]} + # Result should be Binary when using literals + assert result.schema()["literal"].dtype == DataType.binary() + + +def test_fixed_size_binary_concat_edge_cases() -> None: + # Test various edge cases with different fixed sizes + cases = [ + # Single byte values + (1, [b"\x00", b"\xff", b"a", None], 1, [b"\x01", b"\x02", b"b", None], [b"\x00\x01", b"\xff\x02", b"ab", None]), + # Two byte values + ( + 2, + [b"\x00\x00", b"\xff\xff", b"ab", None], + 2, + [b"\x01\x01", b"\x02\x02", b"cd", None], + [b"\x00\x00\x01\x01", b"\xff\xff\x02\x02", b"abcd", None], + ), + # Four byte values (common for integers, floats) + ( + 4, + [b"\x00" * 4, b"\xff" * 4, b"abcd", None], + 2, + [b"\x01\x01", b"\x02\x02", b"ef", None], + [b"\x00" * 4 + b"\x01\x01", b"\xff" * 4 + b"\x02\x02", b"abcdef", None], + ), + ] + + for size1, input_data1, size2, input_data2, expected_result in cases: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data1, type=pa.binary(size1)), + "b": pa.array(input_data2, type=pa.binary(size2)), + } + ) + # Verify inputs are FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size1) + assert table.schema()["b"].dtype == DataType.fixed_size_binary(size2) + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == {"a": expected_result} + # Result should be FixedSizeBinary with combined size + assert result.schema()["a"].dtype == DataType.fixed_size_binary(size1 + size2) + + +def test_fixed_size_binary_concat_with_binary() -> None: + # Test concatenating FixedSizeBinary with regular Binary + table = MicroPartition.from_pydict( + { + "a": pa.array([b"abc", b"def", None], type=pa.binary(3)), + "b": pa.array([b"x", b"yz", b"uvw"]), # Regular Binary + } + ) + # Verify first input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(3) + assert table.schema()["b"].dtype == DataType.binary() + + # Test FixedSizeBinary + Binary + result = table.eval_expression_list([col("a").binary.concat(col("b"))]) + assert result.to_pydict() == {"a": [b"abcx", b"defyz", None]} + # Result should be Binary when mixing types + assert result.schema()["a"].dtype == DataType.binary() + + # Test Binary + FixedSizeBinary + result = table.eval_expression_list([col("b").binary.concat(col("a"))]) + assert result.to_pydict() == {"b": [b"xabc", b"yzdef", None]} + # Result should be Binary when mixing types + assert result.schema()["b"].dtype == DataType.binary() + + +def test_fixed_size_binary_concat_with_literals() -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array([b"abc", b"def", None], type=pa.binary(3)), + } + ) + # Verify input is 
FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(3) + + # Test with literal + result = table.eval_expression_list([col("a").binary.concat(lit(b"xyz"))]) + assert result.to_pydict() == {"a": [b"abcxyz", b"defxyz", None]} + # Result should be Binary when using literals + assert result.schema()["a"].dtype == DataType.binary() + + # Test with null literal + result = table.eval_expression_list([col("a").binary.concat(lit(None))]) + assert result.to_pydict() == {"a": [None, None, None]} + # Result should be Binary when using literals + assert result.schema()["a"].dtype == DataType.binary() + + +def test_fixed_size_binary_concat_errors() -> None: + # Test error cases + table = MicroPartition.from_pydict( + { + "a": pa.array([b"abc", b"def"], type=pa.binary(3)), + "b": [1, 2], # Wrong type + } + ) + + # Test concat with wrong type + with pytest.raises(Exception, match="Expects inputs to concat to be binary, but received a#Binary and b#Int64"): + table.eval_expression_list([col("a").binary.concat(col("b"))]) diff --git a/tests/table/binary/test_fixed_size_binary_length.py b/tests/table/binary/test_fixed_size_binary_length.py new file mode 100644 index 0000000000..3ea737c7fc --- /dev/null +++ b/tests/table/binary/test_fixed_size_binary_length.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +from daft import DataType +from daft.expressions import col +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "input_data,expected_result,size", + [ + # Basic binary data + ([b"abc", b"def", b"ghi"], [3, 3, 3], 3), + # With nulls + ([b"ab", None, b"cd"], [2, None, 2], 2), + # Special sequences + ([b"\x00\x01", b"\xff\xfe", b"\xe2\x98"], [2, 2, 2], 2), + # Complex UTF-8 sequences + ( + [b"\xe2\x98\x83", b"\xf0\x9f\x98", b"\xf0\x9f\x8c"], # Snowman, partial faces + [3, 3, 3], + 3, + ), + # Zero bytes in different positions + ( + [b"\x00ab", b"a\x00b", b"ab\x00"], # Leading, middle, trailing zeros + [3, 3, 3], + 3, + ), + # High value bytes + ([b"\xff\xff\xff", b"\xfe\xfe\xfe", b"\xfd\xfd\xfd"], [3, 3, 3], 3), + # Mixed binary content + ( + [b"a\xff\x00", b"\x00\xff\x83", b"\xe2\x98\x83"], # Mix of ASCII, high bytes, nulls, and UTF-8 + [3, 3, 3], + 3, + ), + ], +) +def test_fixed_size_binary_length( + input_data: list[bytes | None], + expected_result: list[int | None], + size: int, +) -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + } + ) + # Verify input is FixedSizeBinary before getting length + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.length()]) + assert result.to_pydict() == {"a": expected_result} + # Result should be UInt64 since length can't be negative + assert result.schema()["a"].dtype == DataType.uint64() + + +def test_fixed_size_binary_length_large() -> None: + # Test with larger fixed size binary values + size = 100 + input_data = [ + b"x" * size, # Repeated ASCII + b"\x00" * size, # All nulls + b"\xff" * size, # All high bytes + None, # Null value + (b"Hello\x00World!" 
* 9)[:size], # Pattern with null byte (12 bytes * 9 = 108 bytes, truncated to 100) + (b"\xe2\x98\x83" * 34)[:size], # Repeated UTF-8 sequence (3 bytes * 34 = 102 bytes, truncated to 100) + ] + expected_result = [size, size, size, None, size, size] + + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + } + ) + # Verify input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.length()]) + assert result.to_pydict() == {"a": expected_result} + assert result.schema()["a"].dtype == DataType.uint64() + + +def test_fixed_size_binary_length_edge_cases() -> None: + # Test various edge cases with different sizes + cases = [ + # Single byte values + (1, [b"\x00", b"\xff", b"a", None], [1, 1, 1, None]), + # Two byte values + (2, [b"\x00\x00", b"\xff\xff", b"ab", None], [2, 2, 2, None]), + # Four byte values (common for integers, floats) + (4, [b"\x00\x00\x00\x00", b"\xff\xff\xff\xff", b"abcd", None], [4, 4, 4, None]), + # Eight byte values (common for timestamps, large integers) + (8, [b"\x00" * 8, b"\xff" * 8, b"abcdefgh", None], [8, 8, 8, None]), + ] + + for size, input_data, expected_result in cases: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + } + ) + # Verify input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.length()]) + assert result.to_pydict() == {"a": expected_result} + assert result.schema()["a"].dtype == DataType.uint64() + + +def test_fixed_size_binary_length_errors() -> None: + # Test error cases + table = MicroPartition.from_pydict( + { + "a": pa.array([1, 2], type=pa.int64()), # Wrong type + } + ) + + # Test length on wrong type + with pytest.raises(Exception, match="Expects input to length to be binary, but received a#Int64"): + table.eval_expression_list([col("a").binary.length()]) diff --git a/tests/table/binary/test_fixed_size_binary_slice.py b/tests/table/binary/test_fixed_size_binary_slice.py new file mode 100644 index 0000000000..81ce4e122e --- /dev/null +++ b/tests/table/binary/test_fixed_size_binary_slice.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +from daft import DataType +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "input_data,start,length,expected_result,size", + [ + # Basic slicing + ([b"abc", b"def", b"ghi"], [0, 1, 2], [2, 1, 1], [b"ab", b"e", b"i"], 3), + # With nulls + ([b"ab", None, b"cd"], [0, 1, 0], [1, 1, 2], [b"a", None, b"cd"], 2), + # Special sequences + ([b"\x00\x01", b"\xff\xfe", b"\xe2\x98"], [1, 0, 1], [1, 1, 1], [b"\x01", b"\xff", b"\x98"], 2), + # Edge cases + ( + [b"abc", b"def", b"ghi"], + [3, 2, 1], # Start at or beyond length + [1, 2, 3], + [b"", b"f", b"hi"], + 3, + ), + # UTF-8 sequences + ( + [b"\xe2\x98\x83", b"\xf0\x9f\x98", b"\xf0\x9f\x8c"], # UTF-8 characters and partials + [0, 1, 2], + [2, 2, 1], + [b"\xe2\x98", b"\x9f\x98", b"\x8c"], + 3, + ), + # Zero bytes in different positions + ([b"\x00ab", b"a\x00b", b"ab\x00"], [0, 1, 2], [2, 1, 1], [b"\x00a", b"\x00", b"\x00"], 3), + # High value bytes + ([b"\xff\xff\xff", b"\xfe\xfe\xfe", b"\xfd\xfd\xfd"], [0, 1, 2], [2, 1, 1], [b"\xff\xff", b"\xfe", b"\xfd"], 3), + # Mixed content + ( + [b"a\xff\x00", b"\x00\xff\x83", b"\xe2\x98\x83"], + [1, 0, 0], + [2, 2, 2], + [b"\xff\x00", b"\x00\xff", 
b"\xe2\x98"], + 3, + ), + ], +) +def test_fixed_size_binary_slice( + input_data: list[bytes | None], + start: list[int], + length: list[int], + expected_result: list[bytes | None], + size: int, +) -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + "start": start, + "length": length, + } + ) + # Verify input is FixedSizeBinary before slicing + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.slice(col("start"), col("length"))]) + assert result.to_pydict() == {"a": expected_result} + # Result should be regular Binary since slice might be smaller + assert result.schema()["a"].dtype == DataType.binary() + + +@pytest.mark.parametrize( + "input_data,start,expected_result,size", + [ + # Without length parameter + ([b"abc", b"def", b"ghi"], [1, 0, 2], [b"bc", b"def", b"i"], 3), + # With nulls + ([b"ab", None, b"cd"], [1, 0, 1], [b"b", None, b"d"], 2), + # UTF-8 sequences + ([b"\xe2\x98\x83", b"\xf0\x9f\x98", b"\xf0\x9f\x8c"], [1, 0, 2], [b"\x98\x83", b"\xf0\x9f\x98", b"\x8c"], 3), + # Special bytes + ([b"\x00\xff\x7f", b"\xff\x00\xff", b"\x7f\xff\x00"], [1, 2, 0], [b"\xff\x7f", b"\xff", b"\x7f\xff\x00"], 3), + ], +) +def test_fixed_size_binary_slice_no_length( + input_data: list[bytes | None], + start: list[int], + expected_result: list[bytes | None], + size: int, +) -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + "start": start, + } + ) + # Verify input is FixedSizeBinary before slicing + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.slice(col("start"))]) + assert result.to_pydict() == {"a": expected_result} + # Result should be regular Binary since slice might be smaller + assert result.schema()["a"].dtype == DataType.binary() + + +def test_fixed_size_binary_slice_computed() -> None: + # Test with computed start index (length - 2) + size = 4 + table = MicroPartition.from_pydict( + { + "a": pa.array( + [ + b"abcd", # Start at 2, take 2 + b"\xff\xfe\xfd\xfc", # Start at 2, take 2 + b"\x00\x01\x02\x03", # Start at 2, take 2 + b"\xe2\x98\x83\x00", # Start at 2, take 2 + ], + type=pa.binary(size), + ), + } + ) + # Verify input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + + # Test with computed start (size - 2) and fixed length + result = table.eval_expression_list( + [ + col("a").binary.slice( + (lit(size) - 2).cast(DataType.int32()), # Start 2 chars from end + 2, # Take 2 chars + ) + ] + ) + assert result.to_pydict() == {"a": [b"cd", b"\xfd\xfc", b"\x02\x03", b"\x83\x00"]} + assert result.schema()["a"].dtype == DataType.binary() + + # Test with fixed start and computed length (size - start) + result = table.eval_expression_list( + [ + col("a").binary.slice( + 1, # Start at second char + (lit(size) - 1).cast(DataType.int32()), # Take remaining chars + ) + ] + ) + assert result.to_pydict() == {"a": [b"bcd", b"\xfe\xfd\xfc", b"\x01\x02\x03", b"\x98\x83\x00"]} + assert result.schema()["a"].dtype == DataType.binary() + + +def test_fixed_size_binary_slice_edge_cases() -> None: + # Test various edge cases with different fixed sizes + cases = [ + # Single byte values + ( + 1, # size + [b"\x00", b"\xff", b"a", None], # input + [0, 0, 0, 0], # start + [1, 1, 1, 1], # length + [b"\x00", b"\xff", b"a", None], # expected + ), + # Two byte values with boundary cases + ( + 2, # size + [b"\x00\x01", b"\xff\xfe", 
b"ab", None], # input + [1, 2, 0, 1], # start + [1, 0, 2, 1], # length + [b"\x01", b"", b"ab", None], # expected + ), + # Four byte values with various slices + ( + 4, # size + [b"\x00\x01\x02\x03", b"abcd", b"\xff\xfe\xfd\xfc", None], # input + [0, 1, 2, 0], # start + [2, 2, 2, 4], # length + [b"\x00\x01", b"bc", b"\xfd\xfc", None], # expected + ), + ] + + for size, input_data, start, length, expected in cases: + table = MicroPartition.from_pydict( + { + "a": pa.array(input_data, type=pa.binary(size)), + "start": start, + "length": length, + } + ) + # Verify input is FixedSizeBinary + assert table.schema()["a"].dtype == DataType.fixed_size_binary(size) + result = table.eval_expression_list([col("a").binary.slice(col("start"), col("length"))]) + assert result.to_pydict() == {"a": expected} + assert result.schema()["a"].dtype == DataType.binary() + + +def test_fixed_size_binary_slice_with_literals() -> None: + table = MicroPartition.from_pydict( + { + "a": pa.array([b"abc", b"def", None], type=pa.binary(3)), + } + ) + # Verify input is FixedSizeBinary before slicing + assert table.schema()["a"].dtype == DataType.fixed_size_binary(3) + + # Test with literal start and length + result = table.eval_expression_list([col("a").binary.slice(lit(1), lit(1))]) + assert result.to_pydict() == {"a": [b"b", b"e", None]} + assert result.schema()["a"].dtype == DataType.binary() + + # Test with only start + result = table.eval_expression_list([col("a").binary.slice(lit(0))]) + assert result.to_pydict() == {"a": [b"abc", b"def", None]} + assert result.schema()["a"].dtype == DataType.binary() + + # Test with start beyond length + result = table.eval_expression_list([col("a").binary.slice(lit(3), lit(1))]) + assert result.to_pydict() == {"a": [b"", b"", None]} + assert result.schema()["a"].dtype == DataType.binary() + + # Test with zero length + result = table.eval_expression_list([col("a").binary.slice(lit(0), lit(0))]) + assert result.to_pydict() == {"a": [b"", b"", None]} + assert result.schema()["a"].dtype == DataType.binary() + + +def test_fixed_size_binary_slice_errors() -> None: + # Test error cases + table = MicroPartition.from_pydict( + { + "a": pa.array([b"abc", b"def"], type=pa.binary(3)), + "b": [1, 2], # Wrong type + "start": [-1, 0], # Negative start + "length": [0, -1], # Negative length + } + ) + + # Test slice on wrong type + with pytest.raises( + Exception, match="Expects inputs to binary_slice to be binary, integer and integer or null but received Int64" + ): + table.eval_expression_list([col("b").binary.slice(lit(0))]) + + # Test negative start + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("a").binary.slice(col("start"))]) + + # Test negative length + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("a").binary.slice(lit(0), col("length"))]) + + # Test with wrong number of arguments (too many) + with pytest.raises( + Exception, + match="(?:ExpressionBinaryNamespace.)?slice\\(\\) takes from 2 to 3 positional arguments but 4 were given", + ): + table.eval_expression_list([col("a").binary.slice(lit(0), lit(1), lit(2))]) + + # Test with wrong number of arguments (too few) + with pytest.raises( + Exception, match="(?:ExpressionBinaryNamespace.)?slice\\(\\) missing 1 required positional argument: 'start'" + ): + table.eval_expression_list([col("a").binary.slice()]) diff --git 
a/tests/table/binary/test_length.py b/tests/table/binary/test_length.py new file mode 100644 index 0000000000..6c497dc015 --- /dev/null +++ b/tests/table/binary/test_length.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_binary_length() -> None: + table = MicroPartition.from_pydict( + { + "col": [ + b"foo", # Basic ASCII + None, # Null value + b"", # Empty string + b"Hello\xe2\x98\x83World", # UTF-8 character in middle + b"\xf0\x9f\x98\x89test", # UTF-8 bytes at start + b"test\xf0\x9f\x8c\x88", # UTF-8 bytes at end + b"\xe2\x98\x83\xf0\x9f\x98\x89\xf0\x9f\x8c\x88", # Multiple UTF-8 sequences + b"Hello\x00World", # Null character + b"\xff\xfe\xfd", # High bytes + b"\x00\x01\x02", # Control characters + b"a" * 1000, # Long ASCII string + b"\xff" * 1000, # Long binary string + ] + } + ) + result = table.eval_expression_list([col("col").binary.length()]) + assert result.to_pydict() == { + "col": [ + 3, # "foo" + None, # None + 0, # "" + 13, # "Hello☃World" (5 + 3 + 5 = 13 bytes) + 8, # "😉test" (4 + 4 = 8 bytes) + 8, # "test🌈" (4 + 4 = 8 bytes) + 11, # "☃😉🌈" (3 + 4 + 4 = 11 bytes) + 11, # "Hello\x00World" + 3, # "\xff\xfe\xfd" + 3, # "\x00\x01\x02" + 1000, # Long ASCII string + 1000, # Long binary string + ] + } + + +@pytest.mark.parametrize( + "input_data,expected_lengths", + [ + # Basic ASCII strings + ([b"hello", b"world", b"test"], [5, 5, 4]), + # Empty strings and nulls + ([b"", None, b"", None], [0, None, 0, None]), + # Special binary sequences + ( + [ + b"\x00\x01\x02", # Control characters + b"\xff\xfe\xfd", # High bytes + b"Hello\x00World", # String with null + b"\xe2\x98\x83", # UTF-8 snowman + b"\xf0\x9f\x98\x89", # UTF-8 winking face + ], + [3, 3, 11, 3, 4], + ), + # Mixed content + ( + [ + b"Hello\xe2\x98\x83World", # String with UTF-8 + b"\xf0\x9f\x98\x89test", # UTF-8 at start + b"test\xf0\x9f\x8c\x88", # UTF-8 at end + b"\xe2\x98\x83\xf0\x9f\x98\x89\xf0\x9f\x8c\x88", # Multiple UTF-8 + ], + [13, 8, 8, 11], # Fixed lengths for UTF-8 sequences + ), + # Large strings + ([b"a" * 1000, b"\xff" * 1000, b"x" * 500], [1000, 1000, 500]), + ], +) +def test_binary_length_parameterized(input_data: list[bytes | None], expected_lengths: list[int | None]) -> None: + table = MicroPartition.from_pydict({"col": input_data}) + result = table.eval_expression_list([col("col").binary.length()]) + assert result.to_pydict() == {"col": expected_lengths} + + +def test_binary_length_errors() -> None: + # Test length with wrong number of arguments + table = MicroPartition.from_pydict({"a": [b"hello", b"world"], "b": [b"foo", b"bar"]}) + with pytest.raises( + Exception, match="(?:ExpressionBinaryNamespace.)?length\\(\\) takes 1 positional argument but 2 were given" + ): + table.eval_expression_list([col("a").binary.length(col("b"))]) diff --git a/tests/table/binary/test_slice.py b/tests/table/binary/test_slice.py new file mode 100644 index 0000000000..7bf63d2ca8 --- /dev/null +++ b/tests/table/binary/test_slice.py @@ -0,0 +1,642 @@ +from __future__ import annotations + +import pytest + +from daft import DataType +from daft.expressions import col, lit +from daft.table import MicroPartition + + +def test_binary_slice() -> None: + table = MicroPartition.from_pydict( + { + "col": [ + b"foo", + None, + b"barbarbar", + b"quux", + b"1", + b"", + b"Hello\xe2\x98\x83World", # UTF-8 character in middle + b"\xf0\x9f\x98\x89test", # UTF-8 bytes at start + b"test\xf0\x9f\x8c\x88", # UTF-8 bytes at end + 
b"\xe2\x98\x83\xf0\x9f\x98\x89\xf0\x9f\x8c\x88", # Multiple UTF-8 sequences + b"Hello\x00World", # Null character + b"\xff\xfe\xfd", # High bytes + b"\x00\x01\x02", # Control characters + ] + } + ) + result = table.eval_expression_list([col("col").binary.slice(0, 5)]) + assert result.to_pydict() == { + "col": [ + b"foo", + None, + b"barba", + b"quux", + b"1", + b"", + b"Hello", # Should handle UTF-8 correctly + b"\xf0\x9f\x98\x89t", # Should include full UTF-8 sequence + b"test\xf0", # Should split UTF-8 sequence + b"\xe2\x98\x83\xf0\x9f", # Should split between sequences + b"Hello", # Should handle null character + b"\xff\xfe\xfd", # Should handle high bytes + b"\x00\x01\x02", # Should handle control characters + ] + } + + +@pytest.mark.parametrize( + "input_data,start_data,length_data,expected_result", + [ + # Test with column for start position + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, 0, 2, 5, 1, 4, 1], + 3, + [b"ell", b"wor", b"st", b"\xe2\x98\x83", b"\x9f\x98\x89", b"\xf0\x9f\x8c", b"\xfe\xfd"], + ), + # Test with column for length + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + 1, + [2, 3, 4, 5, 2, 1, 2], + [b"el", b"orl", b"est", b"ello\xe2", b"\x9f\x98", b"e", b"\xfe\xfd"], + ), + # Test with both start and length as columns + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, 0, 2, 5, 1, 4, 0], + [2, 3, 1, 2, 3, 1, 3], + [b"el", b"wor", b"s", b"\xe2\x98", b"\x9f\x98\x89", b"\xf0", b"\xff\xfe\xfd"], + ), + # Test with nulls in start column + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, None, 2, None, 1, None, 1], + 3, + [b"ell", None, b"st", None, b"\x9f\x98\x89", None, b"\xfe\xfd"], + ), + # Test with nulls in length column + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + 1, + [2, None, 4, None, 2, None, None], + [b"el", None, b"est", None, b"\x9f\x98", None, None], + ), + # Test with nulls in both columns + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, None, 2, 5, None, 4, None], + [2, 3, None, None, 2, None, 2], + [b"el", None, None, None, None, None, None], + ), + # Test with all nulls in start column + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, None, None, None, None, None, None], + [2, 3, 1, 2, 3, 1, 3], + [b"el", None, None, None, None, None, None], + ), + # Test with all nulls in length column + ( + [ + b"hello", + b"world", + b"test", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd", + ], + [1, 0, 2, 5, 1, 4, 0], + [2, None, None, None, None, None, None], + [b"el", None, None, None, None, None, None], + ), + ], +) +def test_binary_slice_with_columns( + input_data: list[bytes | None], + start_data: list[int | None] | int, + length_data: list[int | None] | int, + expected_result: list[bytes | None], +) -> None: + table_data = {"col": input_data} + if isinstance(start_data, list): + 
table_data["start"] = start_data + start = col("start") + else: + start = start_data + + if isinstance(length_data, list): + table_data["length"] = length_data + length = col("length") + else: + length = length_data + + table = MicroPartition.from_pydict(table_data) + result = table.eval_expression_list([col("col").binary.slice(start, length)]) + assert result.to_pydict() == {"col": expected_result} + + +@pytest.mark.parametrize( + "input_data,start,length,expected_result", + [ + # Test start beyond string length + ( + [ + b"hello", + b"world", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd\xfc", + ], + [10, 20, 15, 10, 10, 5], + 2, + [b"", b"", b"", b"", b"", b""], + ), + # Test start way beyond string length + ( + [ + b"hello", # len 5 + b"world", # len 5 + b"test", # len 4 + b"\xff\xfe\xfd", # len 3 + ], + [100, 1000, 50, 25], + 5, + [b"", b"", b"", b""], + ), + # Test start beyond length with None length + ( + [ + b"hello", + b"world", + b"test", + b"\xff\xfe\xfd", + ], + [10, 20, 15, 8], + None, + [b"", b"", b"", b""], + ), + # Test zero length + ( + [ + b"hello", + b"world", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd\xfc", + ], + [1, 0, 5, 0, 4, 2], + 0, + [b"", b"", b"", b"", b"", b""], + ), + # Test very large length + ( + [ + b"hello", + b"world", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd\xfc", + ], + [0, 1, 5, 0, 4, 1], + 100, + [b"hello", b"orld", b"\xe2\x98\x83World", b"\xf0\x9f\x98\x89test", b"\xf0\x9f\x8c\x88", b"\xfe\xfd\xfc"], + ), + # Test empty strings + ( + [b"", b"", b"", b""], + [0, 1, 2, 3], + 3, + [b"", b"", b"", b""], + ), + # Test start + length overflow + ( + [ + b"hello", + b"world", + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd\xfc", + ], + [2, 3, 5, 0, 4, 2], + 9999999999, + [b"llo", b"ld", b"\xe2\x98\x83World", b"\xf0\x9f\x98\x89test", b"\xf0\x9f\x8c\x88", b"\xfd\xfc"], + ), + # Test UTF-8 and binary sequence boundaries + ( + [ + b"Hello\xe2\x98\x83World", + b"\xf0\x9f\x98\x89test", + b"test\xf0\x9f\x8c\x88", + b"\xff\xfe\xfd\xfc", + b"\x00\x01\x02\x03", + ], + [4, 0, 3, 1, 2], + 2, + [b"o\xe2", b"\xf0\x9f", b"t\xf0", b"\xfe\xfd", b"\x02\x03"], + ), + ], +) +def test_binary_slice_edge_cases( + input_data: list[bytes], + start: list[int], + length: int, + expected_result: list[bytes | None], +) -> None: + table = MicroPartition.from_pydict({"col": input_data, "start": start}) + result = table.eval_expression_list([col("col").binary.slice(col("start"), length)]) + assert result.to_pydict() == {"col": expected_result} + + +def test_binary_slice_errors() -> None: + # Test negative start + table = MicroPartition.from_pydict( + {"col": [b"hello", b"world", b"Hello\xe2\x98\x83World", b"\xff\xfe\xfd"], "start": [-1, -2, -3, -1]} + ) + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + # Test negative length + table = MicroPartition.from_pydict({"col": [b"hello", b"world", b"Hello\xe2\x98\x83World", b"\xff\xfe\xfd"]}) + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("col").binary.slice(0, -3)]) + + # Test both negative + table = MicroPartition.from_pydict( + {"col": [b"hello", b"world", b"Hello\xe2\x98\x83World", 
b"\xff\xfe\xfd"], "start": [-2, -1, -3, -2]} + ) + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("col").binary.slice(col("start"), -2)]) + + # Test negative length in column + table = MicroPartition.from_pydict( + {"col": [b"hello", b"world", b"Hello\xe2\x98\x83World", b"\xff\xfe\xfd"], "length": [-2, -3, -4, -2]} + ) + with pytest.raises(Exception, match="DaftError::ComputeError Failed to cast numeric value to target type"): + table.eval_expression_list([col("col").binary.slice(0, col("length"))]) + + # Test slice with wrong number of arguments (too many) + table = MicroPartition.from_pydict( + {"col": [b"hello", b"world"], "start": [1, 2], "length": [2, 3], "extra": [4, 5]} + ) + with pytest.raises( + Exception, + match="(?:ExpressionBinaryNamespace.)?slice\\(\\) takes from 2 to 3 positional arguments but 4 were given", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), col("length"), col("extra"))]) + + # Test slice with wrong number of arguments (too few) + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": [1, 2]}) + with pytest.raises( + Exception, match="(?:ExpressionBinaryNamespace.)?slice\\(\\) missing 1 required positional argument: 'start'" + ): + table.eval_expression_list([col("col").binary.slice()]) + + +def test_binary_slice_computed() -> None: + # Test with computed start index (length - 5) + table = MicroPartition.from_pydict( + { + "col": [ + b"hello world", # len=11, start=6, expect "world" + b"python programming", # len=17, start=12, expect "mming" + b"data science", # len=12, start=7, expect "ience" + b"artificial", # len=10, start=5, expect "icial" + b"intelligence", # len=12, start=7, expect "gence" + b"Hello\xe2\x98\x83World", # len=12, start=7, expect "World" + b"test\xf0\x9f\x98\x89test", # len=12, start=7, expect "test" + b"test\xf0\x9f\x8c\x88test", # len=12, start=7, expect "test" + b"\xff\xfe\xfd\xfc\xfb", # len=5, start=0, expect "\xff\xfe\xfd" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").binary.slice( + (col("col").binary.length() - 5).cast(DataType.int32()), # start 5 chars from end + 3, # take 3 chars + ) + ] + ) + assert result.to_pydict() == { + "col": [b"wor", b"mmi", b"ien", b"ici", b"gen", b"Wor", b"\x89te", b"\x88te", b"\xff\xfe\xfd"] + } + + # Test with computed length (half of string length) + table = MicroPartition.from_pydict( + { + "col": [ + b"hello world", # len=11, len/2=5, expect "hello" + b"python programming", # len=17, len/2=8, expect "python pr" + b"data science", # len=12, len/2=6, expect "data s" + b"artificial", # len=10, len/2=5, expect "artif" + b"intelligence", # len=12, len/2=6, expect "intell" + b"Hello\xe2\x98\x83World", # len=12, len/2=6, expect "Hello\xe2" + b"test\xf0\x9f\x98\x89test", # len=12, len/2=6, expect "test\xf0\x9f" + b"test\xf0\x9f\x8c\x88test", # len=12, len/2=6, expect "test\xf0\x9f" + b"\xff\xfe\xfd\xfc\xfb", # len=5, len/2=2, expect "\xff\xfe" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").binary.slice( + 0, # start from beginning + (col("col").binary.length() / 2).cast(DataType.int32()), # take half of string + ) + ] + ) + assert result.to_pydict() == { + "col": [ + b"hello", + b"python pr", + b"data s", + b"artif", + b"intell", + b"Hello\xe2", + b"test\xf0\x9f", + b"test\xf0\x9f", + b"\xff\xfe", + ] + } + + # Test with both computed start and length + table = MicroPartition.from_pydict( + { + "col": [ + b"hello world", # 
len=11, start=2, len=3, expect "llo" + b"python programming", # len=18, start=3, len=6, expect "hon pr" + b"data science", # len=12, start=2, len=4, expect "ta s" + b"artificial", # len=10, start=2, len=3, expect "tif" + b"intelligence", # len=12, start=2, len=4, expect "tell" + b"Hello\xe2\x98\x83World", # len=13, start=2, len=4, expect "llo\xe2" + b"test\xf0\x9f\x98\x89test", # len=12, start=2, len=4, expect "st\xf0\x9f" + b"test\xf0\x9f\x8c\x88test", # len=12, start=2, len=4, expect "st\xf0\x9f" + b"\xff\xfe\xfd\xfc\xfb", # len=5, start=1, len=1, expect "\xfe" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").binary.slice( + (col("col").binary.length() / 5).cast(DataType.int32()), # start at 1/5 of string + (col("col").binary.length() / 3).cast(DataType.int32()), # take 1/3 of string + ) + ] + ) + assert result.to_pydict() == { + "col": [b"llo", b"hon pr", b"ta s", b"tif", b"tell", b"llo\xe2", b"st\xf0\x9f", b"st\xf0\x9f", b"\xfe"] + } + + +def test_binary_slice_type_errors() -> None: + # Test slice with string start type + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": ["1", "2"]}) + with pytest.raises( + Exception, + match="Expects inputs to binary_slice to be binary, integer and integer or null but received Binary, Utf8 and Int32", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + # Test slice with float start type + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": [1.5, 2.5]}) + with pytest.raises( + Exception, + match="Expects inputs to binary_slice to be binary, integer and integer or null but received Binary, Float64 and Int32", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + # Test slice with boolean start type + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": [True, False]}) + with pytest.raises( + Exception, + match="Expects inputs to binary_slice to be binary, integer and integer or null but received Binary, Boolean and Int32", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + # Test slice with binary start type + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": [b"1", b"2"]}) + with pytest.raises( + Exception, + match="Expects inputs to binary_slice to be binary, integer and integer or null but received Binary, Binary and Int32", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + # Test slice with null start type + table = MicroPartition.from_pydict({"col": [b"hello", b"world"], "start": [None, None]}) + with pytest.raises( + Exception, + match="Expects inputs to binary_slice to be binary, integer and integer or null but received Binary, Null and Int32", + ): + table.eval_expression_list([col("col").binary.slice(col("start"), 2)]) + + +def test_binary_slice_multiple_slices() -> None: + # Test taking multiple different slices from the same binary data + table = MicroPartition.from_pydict( + { + "col": [ + b"hello", # Simple ASCII + b"Hello\xe2\x98\x83World", # With UTF-8 character + b"\xf0\x9f\x98\x89test\xf0\x9f\x8c\x88", # Multiple UTF-8 sequences + b"\xff\xfe\xfd\xfc\xfb", # Raw bytes + ] + } + ) + + # Get multiple slices + result = table.eval_expression_list( + [ + col("col").binary.slice(1, 3).alias("slice1"), # Middle slice + col("col").binary.slice(0, 1).alias("slice2"), # First byte + col("col").binary.slice(2, 2).alias("slice3"), # Another middle slice + col("col") + 
.binary.slice((col("col").binary.length().cast(DataType.int64()) - 1), 1) + .alias("slice4"), # Last byte + ] + ) + + assert result.to_pydict() == { + "slice1": [b"ell", b"ell", b"\x9f\x98\x89", b"\xfe\xfd\xfc"], + "slice2": [b"h", b"H", b"\xf0", b"\xff"], + "slice3": [b"ll", b"ll", b"\x98\x89", b"\xfd\xfc"], + "slice4": [b"o", b"d", b"\x88", b"\xfb"], + } + + # Test with computed indices + result = table.eval_expression_list( + [ + # First half + col("col").binary.slice(0, (col("col").binary.length() / 2).cast(DataType.int32())).alias("first_half"), + # Second half + col("col") + .binary.slice( + (col("col").binary.length() / 2).cast(DataType.int32()), + (col("col").binary.length() / 2).cast(DataType.int32()), + ) + .alias("second_half"), + # Middle third + col("col") + .binary.slice( + (col("col").binary.length() / 3).cast(DataType.int32()), + (col("col").binary.length() / 3).cast(DataType.int32()), + ) + .alias("middle_third"), + ] + ) + + assert result.to_pydict() == { + "first_half": [b"he", b"Hello\xe2", b"\xf0\x9f\x98\x89te", b"\xff\xfe"], + "second_half": [b"ll", b"\x98\x83Worl", b"st\xf0\x9f\x8c\x88", b"\xfd\xfc"], + "middle_third": [b"e", b"o\xe2\x98\x83", b"test", b"\xfe"], + } + + +@pytest.mark.parametrize( + "input_data,start,length,expected_result", + [ + # Test single binary value with multiple start positions + ( + b"hello", # Single binary value + [0, 1, 2, 3, 4, 5], # Multiple start positions + 2, # Fixed length + [b"he", b"el", b"ll", b"lo", b"o", b""], # Expected slices + ), + # Test single binary value with multiple lengths + ( + b"hello", + 1, # Fixed start + [1, 2, 3, 4], # Multiple lengths + [b"e", b"el", b"ell", b"ello"], # Expected slices + ), + # Test single binary value with both start and length as arrays + ( + b"hello", + [0, 1, 2, 3], # Multiple starts + [1, 2, 3, 2], # Multiple lengths + [b"h", b"el", b"llo", b"lo"], # Expected slices + ), + # Test with UTF-8 sequences + ( + # Single UTF-8 string with snowman character + b"Hello\xe2\x98\x83World", + [0, 5, 8], # Multiple starts + [5, 3, 5], # Multiple lengths + # Expected: "Hello", snowman char, "World" + [b"Hello", b"\xe2\x98\x83", b"World"], + ), + # Test with binary data + ( + b"\xff\xfe\xfd\xfc\xfb", # Single binary sequence + [0, 2, 4], # Multiple starts + [2, 2, 1], # Multiple lengths + [b"\xff\xfe", b"\xfd\xfc", b"\xfb"], # Expected slices + ), + # Test edge cases with single binary value + ( + b"test", + [0, 4, 2], # Start at beginning, end, and middle + [4, 0, 5], # Full length, zero length, overflow length + [b"test", b"", b"st"], # Expected results + ), + # Test with nulls in start/length + ( + b"hello", + [0, None, 2, None], # Mix of valid starts and nulls + [2, 3, None, None], # Mix of valid lengths and nulls + [b"he", None, None, None], # Expected results with nulls + ), + ], +) +def test_binary_slice_broadcasting( + input_data: bytes, + start: list[int] | int, + length: list[int] | int, + expected_result: list[bytes | None], +) -> None: + # Create table with slice parameters + table_data = {} + + # Add start/length columns if they are arrays + if isinstance(start, list): + table_data["start"] = start + start_expr = col("start") + else: + start_expr = start + + if isinstance(length, list): + table_data["length"] = length + length_expr = col("length") + else: + length_expr = length + + # Create table with just the slice parameters + table = MicroPartition.from_pydict(table_data) + + # Perform slice operation on the raw binary value + result = 
table.eval_expression_list([lit(input_data).binary.slice(start_expr, length_expr)]) + assert result.to_pydict() == {"literal": expected_result} diff --git a/tools/git_utils.py b/tools/git_utils.py index a149363bdc..0e75301646 100644 --- a/tools/git_utils.py +++ b/tools/git_utils.py @@ -77,7 +77,7 @@ def get_name_and_commit_hash(branch_name: Optional[str]) -> tuple[str, str]: def parse_questions(questions: Optional[str], total_number_of_questions: int) -> list[int]: if questions is None: - return list(range(total_number_of_questions)) + return list(range(1, total_number_of_questions + 1)) else: def to_int(q: str) -> int: diff --git a/tools/tpcds.py b/tools/tpcds.py index d6ec36cc21..cf8aed8998 100644 --- a/tools/tpcds.py +++ b/tools/tpcds.py @@ -63,7 +63,7 @@ def run( ) parser.add_argument( "--cluster-profile", - choices=["debug_xs-x86", "medium-x86"], + choices=["debug_xs-x86", "medium-x86", "benchmarking-arm"], type=str, required=False, help="The ray cluster configuration to run on",
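
For context on the tools/git_utils.py hunk above: parse_questions now defaults to a 1-based range, matching TPC-DS query numbering (q1, q2, ...). A minimal sketch of the expected behavior, assuming the helper is imported from tools/git_utils.py and that the benchmark has 99 questions (both illustrative assumptions, not part of this diff):

# Illustrative sketch only; parse_questions is the helper patched above.
from tools.git_utils import parse_questions

# No explicit selection -> all questions by their 1-based numbers, i.e. [1, 2, ..., 99]
# rather than the old 0-based [0, 1, ..., 98].
assert parse_questions(None, total_number_of_questions=99) == list(range(1, 100))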