Skip to content

Commit

Permalink
better test
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Jan 17, 2025
1 parent e60049a commit 8236cdd
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions tests/dataframe/test_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,43 @@ def test_iter_rows_column_formats(make_df, format, data, expected):

rows = list(df.iter_rows(column_format=format))

def compare_values(v1, v2):
if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray):
return np.array_equal(v1, v2)
if isinstance(v1, dict) and isinstance(v2, dict):
return all(compare_values(v1[k], v2[k]) for k in v1)
return v1 == v2
# Compare each row
assert len(rows) == len(expected)
for actual_row, expected_row in zip(rows, [{"a": e} for e in expected]):
assert actual_row == expected_row


@pytest.mark.parametrize(
"data, expected",
[
pytest.param(
[[1, 2], [3, 4]],
[
np.array([1, 2], dtype=np.int64),
np.array([3, 4], dtype=np.int64),
],
id="list_of_ints",
),
pytest.param(
[[1.0, 2.0], [3.0, 4.0]],
[
np.array([1.0, 2.0], dtype=np.float64),
np.array([3.0, 4.0], dtype=np.float64),
],
id="list_of_floats",
),
],
)
def test_iter_rows_lists_to_numpy(make_df, data, expected):
df = make_df({"a": data})

rows = list(df.iter_rows(column_format="arrow"))

# Compare each row
assert len(rows) == len(expected)
for actual_row, expected_row in zip(rows, [{"a": e} for e in expected]):
assert compare_values(actual_row, expected_row)
np_data = actual_row["a"].values.to_numpy()
np.testing.assert_array_equal(np_data, expected_row["a"])


def test_iter_rows_arrow_column_format_not_compatible():
Expand Down

0 comments on commit 8236cdd

Please sign in to comment.