Skip to content

Commit

Permalink
[BUG] Implement deserialize for Python objects serialized as sequences (
Browse files Browse the repository at this point in the history
#3339)

`visit_seq` is called when serde_json is used to serialize/deserialize Rust
objects, because JSON has no native bytes type and the byte buffer is
therefore stored as a list of numbers.
  • Loading branch information
kevinzwang authored Nov 20, 2024
1 parent b89ee3d commit ec24c80
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
19 changes: 13 additions & 6 deletions src/common/py-serde/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,19 @@ impl<'de> Visitor<'de> for PyObjectVisitor {
where
E: DeError,
{
Python::with_gil(|py| {
py.import_bound(pyo3::intern!(py, "daft.pickle"))
.and_then(|m| m.getattr(pyo3::intern!(py, "loads")))
.and_then(|f| Ok(f.call1((v,))?.into()))
.map_err(|e| DeError::custom(e.to_string()))
})
self.visit_bytes(&v)
}

/// Deserializes a value whose byte buffer was encoded as a sequence of numbers.
///
/// Formats without a native bytes representation (for example JSON via
/// `serde_json`) store the buffer as a list of integers, so the deserializer
/// drives `visit_seq` instead of `visit_bytes`. The elements are collected
/// into a `Vec<u8>` and then handed to `visit_bytes`, which performs the
/// actual decoding.
///
/// # Errors
///
/// Propagates any error returned by `seq.next_element` (each element is
/// requested as a `u8`) and any error returned by `visit_bytes`.
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
    A: serde::de::SeqAccess<'de>,
{
    // The size hint may be derived from the serialized input (e.g. a length
    // prefix), so it must not be trusted for an unbounded up-front allocation.
    // Cap the initial reservation; the vector still grows to the real length.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOC_BYTES);
    let mut bytes: Vec<u8> = Vec::with_capacity(capacity);
    while let Some(byte) = seq.next_element::<u8>()? {
        bytes.push(byte);
    }

    self.visit_bytes(&bytes)
}
}

Expand Down
18 changes: 12 additions & 6 deletions tests/io/test_s3_credentials_refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ def test_s3_credentials_refresh(aws_log_file: io.IOBase):
server_url = f"http://{host}:{port}"

bucket_name = "mybucket"
file_name = "test.parquet"

s3_file_path = f"s3://{bucket_name}/{file_name}"
input_file_path = f"s3://{bucket_name}/input.parquet"
output_file_path = f"s3://{bucket_name}/output.parquet"

old_env = os.environ.copy()
# Set required AWS environment variables before starting server.
Expand Down Expand Up @@ -98,21 +97,28 @@ def get_credentials():
)

df = daft.from_pydict({"a": [1, 2, 3]})
df.write_parquet(s3_file_path, io_config=static_config)
df.write_parquet(input_file_path, io_config=static_config)

df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
df = daft.read_parquet(input_file_path, io_config=dynamic_config)
assert count_get_credentials == 1

df.collect()
assert count_get_credentials == 1

df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
df = daft.read_parquet(input_file_path, io_config=dynamic_config)
assert count_get_credentials == 1

time.sleep(1)
df.collect()
assert count_get_credentials == 2

df.write_parquet(output_file_path, io_config=dynamic_config)
assert count_get_credentials == 2

df2 = daft.read_parquet(output_file_path, io_config=static_config)

assert df.to_arrow() == df2.to_arrow()

# Shutdown moto server.
stop_process(process)
# Restore old set of environment variables.
Expand Down

0 comments on commit ec24c80

Please sign in to comment.