diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 7174ef73f4..b213972d1d 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -686,7 +686,9 @@ def write_csv(
         )
 
     @DataframePublicAPI
-    def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
+    def write_iceberg(
+        self, table: "pyiceberg.table.Table", mode: str = "append", io_config: Optional[IOConfig] = None
+    ) -> "DataFrame":
         """Writes the DataFrame to an `Iceberg `__ table, returning a new DataFrame with the operations that occurred.
 
         Can be run in either `append` or `overwrite` mode which will either appends the rows in the DataFrame or will delete the existing rows and then append the DataFrame rows respectively.
@@ -697,6 +699,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
         Args:
             table (pyiceberg.table.Table): Destination `PyIceberg Table `__ to write dataframe to.
             mode (str, optional): Operation mode of the write. `append` or `overwrite` Iceberg Table. Defaults to "append".
+            io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
 
         Returns:
             DataFrame: The operations that occurred with this write.
@@ -705,6 +708,8 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
         import pyiceberg
         from packaging.version import parse
 
+        from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config
+
         if len(table.spec().fields) > 0 and parse(pyiceberg.__version__) < parse("0.7.0"):
             raise ValueError("pyiceberg>=0.7.0 is required to write to a partitioned table")
 
@@ -719,12 +724,17 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
         if mode not in ["append", "overwrite"]:
             raise ValueError(f"Only support `append` or `overwrite` mode. {mode} is unsupported")
 
+        io_config = (
+            _convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
+        )
+        io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config
+
         operations = []
         path = []
         rows = []
         size = []
 
-        builder = self._builder.write_iceberg(table)
+        builder = self._builder.write_iceberg(table, io_config)
         write_df = DataFrame(builder)
         write_df.collect()
diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py
index dbf94dd76d..c3ea30aaa9 100644
--- a/daft/io/_iceberg.py
+++ b/daft/io/_iceberg.py
@@ -9,82 +9,51 @@
 from daft.logical.builder import LogicalPlanBuilder
 
 if TYPE_CHECKING:
-    from pyiceberg.table import Table as PyIcebergTable
+    import pyiceberg
 
 
 def _convert_iceberg_file_io_properties_to_io_config(props: Dict[str, Any]) -> Optional["IOConfig"]:
-    import pyiceberg
-    from packaging.version import parse
-    from pyiceberg.io import (
-        S3_ACCESS_KEY_ID,
-        S3_ENDPOINT,
-        S3_REGION,
-        S3_SECRET_ACCESS_KEY,
-        S3_SESSION_TOKEN,
-    )
-
+    """Property keys defined here: https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/__init__.py."""
     from daft.io import AzureConfig, GCSConfig, IOConfig, S3Config
 
-    s3_mapping = {
-        S3_REGION: "region_name",
-        S3_ENDPOINT: "endpoint_url",
-        S3_ACCESS_KEY_ID: "key_id",
-        S3_SECRET_ACCESS_KEY: "access_key",
-        S3_SESSION_TOKEN: "session_token",
-    }
-    s3_args = dict()  # type: ignore
-    for pyiceberg_key, daft_key in s3_mapping.items():
-        value = props.get(pyiceberg_key, None)
-        if value is not None:
-            s3_args[daft_key] = value
-
-    if len(s3_args) > 0:
-        s3_config = S3Config(**s3_args)
-    else:
-        s3_config = None
-
-    gcs_config = None
-    azure_config = None
-    if parse(pyiceberg.__version__) >= parse("0.5.0"):
-        from pyiceberg.io import GCS_PROJECT_ID, GCS_TOKEN
-
-        gcs_mapping = {GCS_PROJECT_ID: "project_id", GCS_TOKEN: "token"}
-        gcs_args = dict()  # type: ignore
-        for pyiceberg_key, daft_key in gcs_mapping.items():
-            value = props.get(pyiceberg_key, None)
-            if value is not None:
-                gcs_args[daft_key] = value
-
-        if len(gcs_args) > 0:
-            gcs_config = GCSConfig(**gcs_args)
-
-        azure_mapping = {
-            "adlfs.account-name": "storage_account",
-            "adlfs.account-key": "access_key",
-            "adlfs.sas-token": "sas_token",
-            "adlfs.tenant-id": "tenant_id",
-            "adlfs.client-id": "client_id",
-            "adlfs.client-secret": "client_secret",
-        }
-
-        azure_args = dict()  # type: ignore
-        for pyiceberg_key, daft_key in azure_mapping.items():
-            value = props.get(pyiceberg_key, None)
-            if value is not None:
-                azure_args[daft_key] = value
-
-        if len(azure_args) > 0:
-            azure_config = AzureConfig(**azure_args)
-
-    if any([s3_config, gcs_config, azure_config]):
-        return IOConfig(s3=s3_config, gcs=gcs_config, azure=azure_config)
-    else:
+    any_props_set = False
+
+    def get_first_property_value(*property_names: str) -> Optional[Any]:
+        for property_name in property_names:
+            if property_value := props.get(property_name):
+                nonlocal any_props_set
+                any_props_set = True
+                return property_value
         return None
 
+    io_config = IOConfig(
+        s3=S3Config(
+            endpoint_url=get_first_property_value("s3.endpoint"),
+            region_name=get_first_property_value("s3.region", "client.region"),
+            key_id=get_first_property_value("s3.access-key-id", "client.access-key-id"),
+            access_key=get_first_property_value("s3.secret-access-key", "client.secret-access-key"),
+            session_token=get_first_property_value("s3.session-token", "client.session-token"),
+        ),
+        azure=AzureConfig(
+            storage_account=get_first_property_value("adls.account-name", "adlfs.account-name"),
+            access_key=get_first_property_value("adls.account-key", "adlfs.account-key"),
+            sas_token=get_first_property_value("adls.sas-token", "adlfs.sas-token"),
+            tenant_id=get_first_property_value("adls.tenant-id", "adlfs.tenant-id"),
+            client_id=get_first_property_value("adls.client-id", "adlfs.client-id"),
+            client_secret=get_first_property_value("adls.client-secret", "adlfs.client-secret"),
+        ),
+        gcs=GCSConfig(
+            project_id=get_first_property_value("gcs.project-id"),
+            token=get_first_property_value("gcs.oauth2.token"),
+        ),
+    )
+
+    return io_config if any_props_set else None
+
 
 @PublicAPI
 def read_iceberg(
-    pyiceberg_table: "PyIcebergTable",
+    table: "pyiceberg.table.Table",
     snapshot_id: Optional[int] = None,
     io_config: Optional["IOConfig"] = None,
 ) -> DataFrame:
@@ -93,8 +62,8 @@ def read_iceberg(
     Example:
         >>> import pyiceberg
         >>>
-        >>> pyiceberg_table = pyiceberg.Table(...)
-        >>> df = daft.read_iceberg(pyiceberg_table)
+        >>> table = pyiceberg.Table(...)
+        >>> df = daft.read_iceberg(table)
         >>>
         >>> # Filters on this dataframe can now be pushed into
        >>> # the read operation from Iceberg
@@ -106,9 +75,9 @@ def read_iceberg(
        official project for Python.
 
    Args:
-        pyiceberg_table: Iceberg table created using the PyIceberg library
-        snapshot_id: Snapshot ID of the table to query
-        io_config: A custom IOConfig to use when accessing Iceberg object storage data. Defaults to None.
+        table (pyiceberg.table.Table): `PyIceberg Table `__ created using the PyIceberg library
+        snapshot_id (int, optional): Snapshot ID of the table to query
+        io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
 
    Returns:
        DataFrame: a DataFrame with the schema converted from the specified Iceberg table
@@ -116,16 +85,14 @@ def read_iceberg(
    from daft.iceberg.iceberg_scan import IcebergScanOperator
 
    io_config = (
-        _convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties)
-        if io_config is None
-        else io_config
+        _convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
    )
    io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
 
    multithreaded_io = context.get_context().get_or_create_runner().name != "ray"
    storage_config = StorageConfig(multithreaded_io, io_config)
-    iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config)
+    iceberg_operator = IcebergScanOperator(table, snapshot_id=snapshot_id, storage_config=storage_config)
    handle = ScanOperatorHandle.from_python_scan_operator(iceberg_operator)
    builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
diff --git a/daft/logical/builder.py b/daft/logical/builder.py
index b7316a0a80..e97e07ae4c 100644
--- a/daft/logical/builder.py
+++ b/daft/logical/builder.py
@@ -297,9 +297,8 @@ def write_tabular(
         builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression, io_config)
         return LogicalPlanBuilder(builder)
 
-    def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
+    def write_iceberg(self, table: IcebergTable, io_config: IOConfig) -> LogicalPlanBuilder:
         from daft.iceberg.iceberg_write import get_missing_columns, partition_field_to_expr
-        from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config
 
         name = ".".join(table.name())
         location = f"{table.location()}/data"
@@ -314,7 +313,6 @@ def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
         partition_cols = [partition_field_to_expr(field, schema)._expr for field in partition_spec.fields]
         props = table.properties
         columns = [col.name for col in schema.columns]
-        io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
         builder = builder.iceberg_write(
             name, location, partition_spec.spec_id, partition_cols, schema, props, columns, io_config
         )
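
Reviewer note (not part of the diff): a minimal usage sketch of how the new io_config override on write_iceberg composes with the existing one on read_iceberg. The catalog name, namespace/table name, and credential values below are hypothetical placeholders; it assumes a PyIceberg catalog is already configured.

import daft
from daft.io import IOConfig, S3Config
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table identifiers, for illustration only.
catalog = load_catalog("my_catalog")
table = catalog.load_table("my_namespace.my_table")

# An explicit IOConfig takes precedence over the FileIO properties set on `table`;
# if neither is provided, Daft falls back to daft_planning_config.default_io_config.
io_config = IOConfig(s3=S3Config(region_name="us-west-2", key_id="...", access_key="..."))

df = daft.read_iceberg(table, io_config=io_config)
df = df.where(df["column_name"] > 1)  # filters can still be pushed into the Iceberg scan

# New in this diff: the same override is accepted on the write path.
df.write_iceberg(table, mode="overwrite", io_config=io_config)

Note that _convert_iceberg_file_io_properties_to_io_config now checks both the current PyIceberg property keys ("s3.region", "adls.account-name", ...) and the legacy ones ("client.region", "adlfs.account-name", ...), so tables configured with either style still translate to a Daft IOConfig when no override is passed.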