diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 09894cf74f..228e02c3d8 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1428,7 +1428,11 @@ struct npy_format_descriptor< }; template -struct npy_format_descriptor::value>> { +struct npy_format_descriptor< + T, + enable_if_t::value + || ((std::is_same::value || std::is_same::value) + && sizeof(T) == sizeof(PyObject *))>> { static constexpr auto name = const_name("object"); static constexpr int value = npy_api::NPY_OBJECT_; diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index c2f754208b..79ade3ba1a 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -156,6 +156,55 @@ py::handle auxiliaries(T &&r, T2 &&r2) { return l.release(); } +template +PyObjectType convert_to_pyobjecttype(py::object obj); + +template <> +PyObject *convert_to_pyobjecttype(py::object obj) { + return obj.release().ptr(); +} + +template <> +py::handle convert_to_pyobjecttype(py::object obj) { + return obj.release(); +} + +template <> +py::object convert_to_pyobjecttype(py::object obj) { + return obj; +} + +template +std::string pass_array_return_sum_str_values(const py::array_t &objs) { + std::string sum_str_values; + for (const auto &obj : objs) { + sum_str_values += py::str(obj.attr("value")); + } + return sum_str_values; +} + +template +py::list pass_array_return_as_list(const py::array_t &objs) { + return objs; +} + +template +py::array_t return_array_cpp_loop(const py::list &objs) { + py::size_t arr_size = py::len(objs); + py::array_t arr_from_list(static_cast(arr_size)); + PyObjectType *data = arr_from_list.mutable_data(); + for (py::size_t i = 0; i < arr_size; i++) { + assert(!data[i]); + data[i] = convert_to_pyobjecttype(objs[i].attr("value")); + } + return arr_from_list; +} + +template +py::array_t return_array_from_list(const py::list &objs) { + return objs; +} + // note: declaration at local scope would create a dangling reference! static int data_i = 42; @@ -520,28 +569,21 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("round_trip_float", [](double d) { return d; }); sm.def("pass_array_pyobject_ptr_return_sum_str_values", - [](const py::array_t &objs) { - std::string sum_str_values; - for (const auto &obj : objs) { - sum_str_values += py::str(obj.attr("value")); - } - return sum_str_values; - }); - - sm.def("pass_array_pyobject_ptr_return_as_list", - [](const py::array_t &objs) -> py::list { return objs; }); - - sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) { - py::size_t arr_size = py::len(objs); - py::array_t arr_from_list(static_cast(arr_size)); - PyObject **data = arr_from_list.mutable_data(); - for (py::size_t i = 0; i < arr_size; i++) { - assert(data[i] == nullptr); - data[i] = py::cast(objs[i].attr("value")); - } - return arr_from_list; - }); - - sm.def("return_array_pyobject_ptr_from_list", - [](const py::list &objs) -> py::array_t { return objs; }); + pass_array_return_sum_str_values); + sm.def("pass_array_handle_return_sum_str_values", + pass_array_return_sum_str_values); + sm.def("pass_array_object_return_sum_str_values", + pass_array_return_sum_str_values); + + sm.def("pass_array_pyobject_ptr_return_as_list", pass_array_return_as_list); + sm.def("pass_array_handle_return_as_list", pass_array_return_as_list); + sm.def("pass_array_object_return_as_list", pass_array_return_as_list); + + sm.def("return_array_pyobject_ptr_cpp_loop", return_array_cpp_loop); + sm.def("return_array_handle_cpp_loop", return_array_cpp_loop); + sm.def("return_array_object_cpp_loop", return_array_cpp_loop); + + sm.def("return_array_pyobject_ptr_from_list", return_array_from_list); + sm.def("return_array_handle_from_list", return_array_from_list); + sm.def("return_array_object_from_list", return_array_from_list); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 6e8bde826f..b1c6875f9e 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -629,45 +629,61 @@ def UnwrapPyValueHolder(vhs): return [vh.value for vh in vhs] -def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray(): +PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS = [ + m.pass_array_pyobject_ptr_return_sum_str_values, + m.pass_array_handle_return_sum_str_values, + m.pass_array_object_return_sum_str_values, +] + + +@pytest.mark.parametrize( + "pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS +) +def test_pass_array_object_return_sum_str_values_ndarray(pass_array): # Intentionally all temporaries, do not change. assert ( - m.pass_array_pyobject_ptr_return_sum_str_values( - np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object) - ) + pass_array(np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)) == "-3four5.0" ) -def test_pass_array_pyobject_ptr_return_sum_str_values_list(): +@pytest.mark.parametrize( + "pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS +) +def test_pass_array_object_return_sum_str_values_list(pass_array): # Intentionally all temporaries, do not change. - assert ( - m.pass_array_pyobject_ptr_return_sum_str_values( - WrapWithPyValueHolder(2, "three", -4.0) - ) - == "2three-4.0" - ) + assert pass_array(WrapWithPyValueHolder(2, "three", -4.0)) == "2three-4.0" -def test_pass_array_pyobject_ptr_return_as_list(): +@pytest.mark.parametrize( + "pass_array", + [ + m.pass_array_pyobject_ptr_return_as_list, + m.pass_array_handle_return_as_list, + m.pass_array_object_return_as_list, + ], +) +def test_pass_array_object_return_as_list(pass_array): # Intentionally all temporaries, do not change. assert UnwrapPyValueHolder( - m.pass_array_pyobject_ptr_return_as_list( - np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object) - ) + pass_array(np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)) ) == [-1, "two", 3.0] @pytest.mark.parametrize( - ("return_array_pyobject_ptr", "unwrap"), + ("return_array", "unwrap"), [ (m.return_array_pyobject_ptr_cpp_loop, list), + (m.return_array_handle_cpp_loop, list), + (m.return_array_object_cpp_loop, list), (m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder), + (m.return_array_handle_from_list, UnwrapPyValueHolder), + (m.return_array_object_from_list, UnwrapPyValueHolder), ], ) -def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap): +def test_return_array_object_cpp_loop(return_array, unwrap): # Intentionally all temporaries, do not change. - arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0)) + arr_from_list = return_array(WrapWithPyValueHolder(6, "seven", -8.0)) assert isinstance(arr_from_list, np.ndarray) assert arr_from_list.dtype == np.dtype("O") assert unwrap(arr_from_list) == [6, "seven", -8.0]