diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 7174ef73f4..b213972d1d 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -686,7 +686,9 @@ def write_csv(
)
@DataframePublicAPI
- def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
+ def write_iceberg(
+ self, table: "pyiceberg.table.Table", mode: str = "append", io_config: Optional[IOConfig] = None
+ ) -> "DataFrame":
"""Writes the DataFrame to an `Iceberg `__ table, returning a new DataFrame with the operations that occurred.
Can be run in either `append` or `overwrite` mode, which will either append the rows in the DataFrame or delete the existing rows and then append the DataFrame rows, respectively.
@@ -697,6 +699,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
Args:
table (pyiceberg.table.Table): Destination `PyIceberg Table `__ to write dataframe to.
mode (str, optional): Operation mode of the write: `append` to or `overwrite` the Iceberg table. Defaults to "append".
+ io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
Returns:
DataFrame: The operations that occurred with this write.
@@ -705,6 +708,8 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
import pyiceberg
from packaging.version import parse
+ from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config
+
if len(table.spec().fields) > 0 and parse(pyiceberg.__version__) < parse("0.7.0"):
raise ValueError("pyiceberg>=0.7.0 is required to write to a partitioned table")
@@ -719,12 +724,17 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
if mode not in ["append", "overwrite"]:
raise ValueError(f"Only support `append` or `overwrite` mode. {mode} is unsupported")
+ io_config = (
+ _convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
+ )
+ io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config
+
operations = []
path = []
rows = []
size = []
- builder = self._builder.write_iceberg(table)
+ builder = self._builder.write_iceberg(table, io_config)
write_df = DataFrame(builder)
write_df.collect()
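
With this change, callers can pass an explicit `IOConfig` to `write_iceberg` instead of relying on the credentials attached to the PyIceberg table. A minimal usage sketch, assuming an already-configured PyIceberg catalog (the catalog and table names below are placeholders):

```python
import daft
from daft.io import IOConfig, S3Config
from pyiceberg.catalog import load_catalog

# Placeholder catalog/table identifiers -- substitute your own.
catalog = load_catalog("my_catalog")
table = catalog.load_table("db.events")

df = daft.from_pydict({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# An explicit io_config takes precedence over any credentials found in
# table.io.properties; omit it to keep the previous behavior.
result = df.write_iceberg(
    table,
    mode="append",
    io_config=IOConfig(s3=S3Config(region_name="us-west-2")),
)
# The returned DataFrame describes the operations that occurred.
result.show()
```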
diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py
index dbf94dd76d..c3ea30aaa9 100644
--- a/daft/io/_iceberg.py
+++ b/daft/io/_iceberg.py
@@ -9,82 +9,51 @@
from daft.logical.builder import LogicalPlanBuilder
if TYPE_CHECKING:
- from pyiceberg.table import Table as PyIcebergTable
+ import pyiceberg
def _convert_iceberg_file_io_properties_to_io_config(props: Dict[str, Any]) -> Optional["IOConfig"]:
- import pyiceberg
- from packaging.version import parse
- from pyiceberg.io import (
- S3_ACCESS_KEY_ID,
- S3_ENDPOINT,
- S3_REGION,
- S3_SECRET_ACCESS_KEY,
- S3_SESSION_TOKEN,
- )
-
+ """Property keys defined here: https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/__init__.py."""
from daft.io import AzureConfig, GCSConfig, IOConfig, S3Config
- s3_mapping = {
- S3_REGION: "region_name",
- S3_ENDPOINT: "endpoint_url",
- S3_ACCESS_KEY_ID: "key_id",
- S3_SECRET_ACCESS_KEY: "access_key",
- S3_SESSION_TOKEN: "session_token",
- }
- s3_args = dict() # type: ignore
- for pyiceberg_key, daft_key in s3_mapping.items():
- value = props.get(pyiceberg_key, None)
- if value is not None:
- s3_args[daft_key] = value
-
- if len(s3_args) > 0:
- s3_config = S3Config(**s3_args)
- else:
- s3_config = None
-
- gcs_config = None
- azure_config = None
- if parse(pyiceberg.__version__) >= parse("0.5.0"):
- from pyiceberg.io import GCS_PROJECT_ID, GCS_TOKEN
-
- gcs_mapping = {GCS_PROJECT_ID: "project_id", GCS_TOKEN: "token"}
- gcs_args = dict() # type: ignore
- for pyiceberg_key, daft_key in gcs_mapping.items():
- value = props.get(pyiceberg_key, None)
- if value is not None:
- gcs_args[daft_key] = value
-
- if len(gcs_args) > 0:
- gcs_config = GCSConfig(**gcs_args)
-
- azure_mapping = {
- "adlfs.account-name": "storage_account",
- "adlfs.account-key": "access_key",
- "adlfs.sas-token": "sas_token",
- "adlfs.tenant-id": "tenant_id",
- "adlfs.client-id": "client_id",
- "adlfs.client-secret": "client_secret",
- }
-
- azure_args = dict() # type: ignore
- for pyiceberg_key, daft_key in azure_mapping.items():
- value = props.get(pyiceberg_key, None)
- if value is not None:
- azure_args[daft_key] = value
-
- if len(azure_args) > 0:
- azure_config = AzureConfig(**azure_args)
-
- if any([s3_config, gcs_config, azure_config]):
- return IOConfig(s3=s3_config, gcs=gcs_config, azure=azure_config)
- else:
+    any_props_set = False
+
+    def get_first_property_value(*property_names: str) -> Optional[Any]:
+        for property_name in property_names:
+            if property_value := props.get(property_name):
+                nonlocal any_props_set
+                any_props_set = True
+                return property_value
        return None
+
+    io_config = IOConfig(
+        s3=S3Config(
+            endpoint_url=get_first_property_value("s3.endpoint"),
+            region_name=get_first_property_value("s3.region", "client.region"),
+            key_id=get_first_property_value("s3.access-key-id", "client.access-key-id"),
+            access_key=get_first_property_value("s3.secret-access-key", "client.secret-access-key"),
+            session_token=get_first_property_value("s3.session-token", "client.session-token"),
+        ),
+        azure=AzureConfig(
+            storage_account=get_first_property_value("adls.account-name", "adlfs.account-name"),
+            access_key=get_first_property_value("adls.account-key", "adlfs.account-key"),
+            sas_token=get_first_property_value("adls.sas-token", "adlfs.sas-token"),
+            tenant_id=get_first_property_value("adls.tenant-id", "adlfs.tenant-id"),
+            client_id=get_first_property_value("adls.client-id", "adlfs.client-id"),
+            client_secret=get_first_property_value("adls.client-secret", "adlfs.client-secret"),
+        ),
+        gcs=GCSConfig(
+            project_id=get_first_property_value("gcs.project-id"),
+            token=get_first_property_value("gcs.oauth2.token"),
+        ),
+    )
+
+    return io_config if any_props_set else None
+
@PublicAPI
def read_iceberg(
- pyiceberg_table: "PyIcebergTable",
+ table: "pyiceberg.table.Table",
snapshot_id: Optional[int] = None,
io_config: Optional["IOConfig"] = None,
) -> DataFrame:
@@ -93,8 +62,8 @@ def read_iceberg(
Example:
>>> import pyiceberg
>>>
- >>> pyiceberg_table = pyiceberg.Table(...)
- >>> df = daft.read_iceberg(pyiceberg_table)
+ >>> table = pyiceberg.Table(...)
+ >>> df = daft.read_iceberg(table)
>>>
>>> # Filters on this dataframe can now be pushed into
>>> # the read operation from Iceberg
@@ -106,9 +75,9 @@ def read_iceberg(
official project for Python.
Args:
- pyiceberg_table: Iceberg table created using the PyIceberg library
- snapshot_id: Snapshot ID of the table to query
- io_config: A custom IOConfig to use when accessing Iceberg object storage data. Defaults to None.
+ table (pyiceberg.table.Table): `PyIceberg Table `__ created using the PyIceberg library
+ snapshot_id (int, optional): Snapshot ID of the table to query
+ io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
Returns:
DataFrame: a DataFrame with the schema converted from the specified Iceberg table
@@ -116,16 +85,14 @@ def read_iceberg(
from daft.iceberg.iceberg_scan import IcebergScanOperator
io_config = (
- _convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties)
- if io_config is None
- else io_config
+ _convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
)
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
multithreaded_io = context.get_context().get_or_create_runner().name != "ray"
storage_config = StorageConfig(multithreaded_io, io_config)
- iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config)
+ iceberg_operator = IcebergScanOperator(table, snapshot_id=snapshot_id, storage_config=storage_config)
handle = ScanOperatorHandle.from_python_scan_operator(iceberg_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
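
The rewritten helper recognizes the `s3.*`, `adls.*`, and `gcs.*` property keys as well as the `client.*` and legacy `adlfs.*` fallbacks. A sketch of how those FileIO properties end up as a Daft `IOConfig` when reading (the catalog name, catalog properties, and table identifier are hypothetical):

```python
import daft
from daft.io import IOConfig, S3Config
from pyiceberg.catalog import load_catalog

# FileIO properties configured on the catalog are carried on
# table.io.properties and translated into S3Config/AzureConfig/GCSConfig fields.
catalog = load_catalog(
    "my_catalog",
    **{
        "s3.endpoint": "http://localhost:9000",
        "s3.access-key-id": "my-key-id",
        "s3.secret-access-key": "my-secret",
        "s3.region": "us-west-2",
    },
)
table = catalog.load_table("db.events")

# Uses the credentials derived from the table's FileIO properties.
df = daft.read_iceberg(table)

# An explicit io_config overrides whatever is set on the table.
df = daft.read_iceberg(table, io_config=IOConfig(s3=S3Config(region_name="us-east-1")))
```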
diff --git a/daft/logical/builder.py b/daft/logical/builder.py
index b7316a0a80..e97e07ae4c 100644
--- a/daft/logical/builder.py
+++ b/daft/logical/builder.py
@@ -297,9 +297,8 @@ def write_tabular(
builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression, io_config)
return LogicalPlanBuilder(builder)
- def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
+ def write_iceberg(self, table: IcebergTable, io_config: IOConfig) -> LogicalPlanBuilder:
from daft.iceberg.iceberg_write import get_missing_columns, partition_field_to_expr
- from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config
name = ".".join(table.name())
location = f"{table.location()}/data"
@@ -314,7 +313,6 @@ def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
partition_cols = [partition_field_to_expr(field, schema)._expr for field in partition_spec.fields]
props = table.properties
columns = [col.name for col in schema.columns]
- io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
builder = builder.iceberg_write(
name, location, partition_spec.spec_id, partition_cols, schema, props, columns, io_config
)
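
Both `read_iceberg` and `write_iceberg` now resolve storage credentials in the same order: an explicitly passed `IOConfig` first, then the table's FileIO properties, then the planning config default. Roughly equivalent logic as a standalone sketch (`resolve_io_config` is an illustrative helper, not part of this patch):

```python
from daft.context import get_context
from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config


def resolve_io_config(table, io_config=None):
    # 1. An IOConfig passed by the caller always wins.
    if io_config is None:
        # 2. Otherwise, translate the table's PyIceberg FileIO properties.
        io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
    if io_config is None:
        # 3. Finally, fall back to the planning config's default IOConfig.
        io_config = get_context().daft_planning_config.default_io_config
    return io_config
```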