diff --git a/daft/datatype.py b/daft/datatype.py
index 646dd231ec..b15902c41d 100644
--- a/daft/datatype.py
+++ b/daft/datatype.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 from daft.context import get_context
 from daft.daft import ImageMode, PyDataType, PyTimeUnit
@@ -83,6 +83,59 @@ def __init__(self) -> None:
             "use a creator method like DataType.int32() or use DataType.from_arrow_type(pa_type)"
         )
 
+    @classmethod
+    def _infer_type(cls, user_provided_type: DataTypeLike) -> DataType:
+        """Infers a Daft DataType from a user-provided type-like value.
+
+        Args:
+            user_provided_type: A DataType (returned as-is), a dict of field name to type-like value
+                (struct), a subscripted list/dict generic such as List[str] or Dict[str, int] (list/map),
+                or one of the Python types str, bool, int, float, bytes and object
+
+        Returns:
+            DataType: The inferred Daft DataType
+
+        Raises:
+            ValueError: If no Daft DataType can be inferred from the provided value
+        """
+        from typing import get_args, get_origin
+
+        if isinstance(user_provided_type, DataType):
+            return user_provided_type
+        elif isinstance(user_provided_type, dict):
+            return DataType.struct({k: cls._infer_type(v) for k, v in user_provided_type.items()})
+
+        origin_type = get_origin(user_provided_type)
+        if origin_type is not None:
+            # Unsubscripted aliases such as typing.List have an origin but no type arguments
+            type_args = get_args(user_provided_type)
+            if origin_type is list and len(type_args) == 1:
+                return DataType.list(cls._infer_type(type_args[0]))
+            elif origin_type is dict and len(type_args) == 2:
+                key_type, val_type = type_args
+                return DataType.map(cls._infer_type(key_type), cls._infer_type(val_type))
+            elif origin_type in (list, dict):
+                raise ValueError(f"Missing type parameters, cannot convert to Daft type: {user_provided_type}")
+            else:
+                raise ValueError(f"Unrecognized Python origin type, cannot convert to Daft type: {origin_type}")
+        elif isinstance(user_provided_type, type):
+            if user_provided_type is str:
+                return DataType.string()
+            elif user_provided_type is bool:
+                return DataType.bool()
+            elif user_provided_type is int:
+                return DataType.int64()
+            elif user_provided_type is float:
+                return DataType.float64()
+            elif user_provided_type is bytes:
+                return DataType.binary()
+            elif user_provided_type is object:
+                return DataType.python()
+            else:
+                raise ValueError(f"Unrecognized Python type, cannot convert to Daft type: {user_provided_type}")
+        else:
+            raise ValueError(f"Unable to infer Daft DataType for provided value: {user_provided_type}")
+
     @staticmethod
     def _from_pydatatype(pydt: PyDataType) -> DataType:
         dt = DataType.__new__(DataType)
@@ -538,6 +591,10 @@ def __hash__(self) -> int:
         return self._dtype.__hash__()
 
 
+# Type alias for values DataType._infer_type can convert; a dict of field name -> type-like means a struct
+DataTypeLike = Union[DataType, type, dict]
+
+
 _EXT_TYPE_REGISTERED = False
 _STATIC_DAFT_EXTENSION = None
 
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py
index 1dfd4730a7..5e865ff936 100644
--- a/daft/expressions/expressions.py
+++ b/daft/expressions/expressions.py
@@ -36,7 +36,7 @@
 from daft.daft import udf as _udf
 from daft.daft import url_download as _url_download
 from daft.daft import utf8_count_matches as _utf8_count_matches
-from daft.datatype import DataType, TimeUnit
+from daft.datatype import DataType, DataTypeLike, TimeUnit
 from daft.dependencies import pa
 from daft.expressions.testing import expr_structurally_equal
 from daft.logical.schema import Field, Schema
@@ -542,7 +542,7 @@ def alias(self, name: builtins.str) -> Expression:
         expr = self._expr.alias(name)
         return Expression._from_pyexpr(expr)
 
-    def cast(self, dtype: DataType) -> Expression:
+    def cast(self, dtype: DataTypeLike) -> Expression:
         """Casts an expression to the given datatype if possible.
 
         The following combinations of datatype casting is valid:
@@ -622,8 +622,10 @@ def cast(self, dtype: DataType) -> Expression:
         Returns:
             Expression: Expression with the specified new datatype
         """
-        assert isinstance(dtype, DataType)
-        expr = self._expr.cast(dtype._dtype)
+        # _infer_type validates the input and raises ValueError when it cannot be converted
+        inferred_dtype = DataType._infer_type(dtype)
+
+        expr = self._expr.cast(inferred_dtype._dtype)
         return Expression._from_pyexpr(expr)
 
     def ceil(self) -> Expression:
@@ -999,7 +1001,7 @@ def if_else(self, if_true: Expression, if_false: Expression) -> Expression:
         if_false = Expression._to_expression(if_false)
         return Expression._from_pyexpr(self._expr.if_else(if_true._expr, if_false._expr))
 
-    def apply(self, func: Callable, return_dtype: DataType) -> Expression:
+    def apply(self, func: Callable, return_dtype: DataTypeLike) -> Expression:
         """Apply a function on each value in a given expression.
 
         .. NOTE::
@@ -1039,6 +1041,8 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression:
         """
         from daft.udf import UDF
 
+        inferred_return_dtype = DataType._infer_type(return_dtype)
+
         def batch_func(self_series):
             return [func(x) for x in self_series.to_pylist()]
 
@@ -1050,7 +1054,7 @@ def batch_func(self_series):
         return UDF(
             inner=batch_func,
             name=name,
-            return_dtype=return_dtype,
+            return_dtype=inferred_return_dtype,
         )(self)
 
     def is_null(self) -> Expression:
diff --git a/daft/udf.py b/daft/udf.py
index ccd13ec808..36e841c683 100644
--- a/daft/udf.py
+++ b/daft/udf.py
@@ -6,7 +6,7 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from daft.daft import PyDataType, ResourceRequest
-from daft.datatype import DataType
+from daft.datatype import DataType, DataTypeLike
 from daft.dependencies import np, pa
 from daft.expressions import Expression
 from daft.series import PySeries, Series
@@ -394,7 +394,7 @@ def __hash__(self) -> int:
 
 def udf(
     *,
-    return_dtype: DataType,
+    return_dtype: DataTypeLike,
     num_cpus: float | None = None,
     num_gpus: float | None = None,
     memory_bytes: int | None = None,
@@ -511,6 +511,7 @@ def udf(
     Returns:
         Callable[[UserDefinedPyFuncLike], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions
     """
+    inferred_return_dtype = DataType._infer_type(return_dtype)
 
     def _udf(f: UserDefinedPyFuncLike) -> UDF:
         # Grab a name for the UDF. It **should** be unique.
@@ -534,7 +535,7 @@ def _udf(f: UserDefinedPyFuncLike) -> UDF:
         return UDF(
             inner=f,
             name=name,
-            return_dtype=return_dtype,
+            return_dtype=inferred_return_dtype,
             resource_request=resource_request,
             batch_size=batch_size,
         )
diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py
index 5df20df90f..27270be427 100644
--- a/tests/test_datatypes.py
+++ b/tests/test_datatypes.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import copy
+import sys
+from typing import Dict, List
 
 import pytest
 
@@ -30,3 +32,66 @@ def test_datatype_pickling(dtype) -> None:
     copy_dtype = copy.deepcopy(dtype)
 
     assert copy_dtype == dtype
+
+
+@pytest.mark.parametrize(
+    ["source", "expected"],
+    [
+        (str, DataType.string()),
+        (bool, DataType.bool()),
+        (int, DataType.int64()),
+        (float, DataType.float64()),
+        (bytes, DataType.binary()),
+        (object, DataType.python()),
+        (
+            {"foo": str, "bar": int},
+            DataType.struct({"foo": DataType.string(), "bar": DataType.int64()}),
+        ),
+    ],
+)
+def test_datatype_parsing(source, expected):
+    assert DataType._infer_type(source) == expected
+
+
+# These tests are only valid for more modern versions of Python, but can't be skipped in the conventional
+# way either because we cannot even run the subscripting during import-time
+if sys.version_info >= (3, 9):
+
+    @pytest.mark.parametrize(
+        ["source", "expected"],
+        [
+            # These tests must be run in later version of Python that allow for subscripting of types
+            (list[str], DataType.list(DataType.string())),
+            (dict[str, int], DataType.map(DataType.string(), DataType.int64())),
+            (
+                {"foo": list[str], "bar": int},
+                DataType.struct({"foo": DataType.list(DataType.string()), "bar": DataType.int64()}),
+            ),
+            (list[list[str]], DataType.list(DataType.list(DataType.string()))),
+        ],
+    )
+    def test_subscripted_datatype_parsing(source, expected):
+        assert DataType._infer_type(source) == expected
+
+
+@pytest.mark.parametrize(
+    ["source", "expected"],
+    [
+        # typing.List / typing.Dict can be subscripted on every supported version of Python
+        (List[str], DataType.list(DataType.string())),
+        (Dict[str, int], DataType.map(DataType.string(), DataType.int64())),
+        (
+            {"foo": List[str], "bar": int},
+            DataType.struct({"foo": DataType.list(DataType.string()), "bar": DataType.int64()}),
+        ),
+        (List[List[str]], DataType.list(DataType.list(DataType.string()))),
+    ],
+)
+def test_legacy_subscripted_datatype_parsing(source, expected):
+    assert DataType._infer_type(source) == expected
+
+
+@pytest.mark.parametrize("source", [List, Dict, complex, "str"])
+def test_datatype_parsing_invalid(source):
+    with pytest.raises(ValueError):
+        DataType._infer_type(source)