Skip to content

Commit

Permalink
Ensure compatibility with numpy 2.0.0 (#6976)
Browse files Browse the repository at this point in the history
* Ensure compatibility with numpy 2.0.0

Following the conversion guide copy=False is no longer required and will result in an error: https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword.

* Update src/datasets/formatting/formatting.py

Co-authored-by: Quentin Lhoest <[email protected]>

* Update src/datasets/formatting/formatting.py

Co-authored-by: Quentin Lhoest <[email protected]>

* make style

---------

Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
KennethEnevoldsen and lhoestq authored Jun 19, 2024
1 parent e47a746 commit 84d9dea
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,20 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()

if len(array) > 0:
if any(
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
or (isinstance(x, float) and np.isnan(x))
for x in array
):
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array, dtype=object)
return np.array(array, copy=False, dtype=object)
return np.array(array, copy=False)
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array)
else:
return np.array(array, copy=False)


class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):
Expand Down

0 comments on commit 84d9dea

Please sign in to comment.