diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 78ed7a1de315..49f086d21459 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -3752,11 +3752,9 @@ descr_repeat(PyObject *self, Py_ssize_t length) return (PyObject *)new; } -static PyObject * -descr_subscript(PyArray_Descr *self, PyObject *op) +static int +_check_has_fields(PyArray_Descr *self) { - PyObject *retval; - if (!PyDataType_HASFIELDS(self)) { PyObject *astr = arraydescr_str(self); #if defined(NPY_PY3K) @@ -3767,74 +3765,91 @@ descr_subscript(PyArray_Descr *self, PyObject *op) PyErr_Format(PyExc_KeyError, "There are no fields in dtype %s.", PyBytes_AsString(astr)); Py_DECREF(astr); + return -1; + } + else { + return 0; + } +} + +static PyObject * +_subscript_by_name(PyArray_Descr *self, PyObject *op) +{ + PyObject *obj = PyDict_GetItem(self->fields, op); + PyObject *descr; + PyObject *s; + + if (obj == NULL) { + if (PyUnicode_Check(op)) { + s = PyUnicode_AsUnicodeEscapeString(op); + } + else { + s = op; + } + + PyErr_Format(PyExc_KeyError, + "Field named \'%s\' not found.", PyBytes_AsString(s)); + if (s != op) { + Py_DECREF(s); + } return NULL; } + descr = PyTuple_GET_ITEM(obj, 0); + Py_INCREF(descr); + return descr; +} + +static PyObject * +_subscript_by_index(PyArray_Descr *self, Py_ssize_t i) +{ + PyObject *name = PySequence_GetItem(self->names, i); + if (name == NULL) { + PyErr_Format(PyExc_IndexError, + "Field index %zd out of range.", i); + return NULL; + } + return _subscript_by_name(self, name); +} + +static PyObject * +descr_subscript(PyArray_Descr *self, PyObject *op) +{ + if (_check_has_fields(self) < 0) { + return NULL; + } + #if defined(NPY_PY3K) if (PyUString_Check(op)) { #else if (PyUString_Check(op) || PyUnicode_Check(op)) { #endif - PyObject *obj = PyDict_GetItem(self->fields, op); - PyObject *descr; - PyObject *s; - - if (obj == NULL) { - if (PyUnicode_Check(op)) { - s = PyUnicode_AsUnicodeEscapeString(op); - } - else { - s = op; - } - - PyErr_Format(PyExc_KeyError, - "Field named \'%s\' not found.", PyBytes_AsString(s)); - if (s != op) { - Py_DECREF(s); - } - return NULL; - } - descr = PyTuple_GET_ITEM(obj, 0); - Py_INCREF(descr); - retval = descr; + return _subscript_by_name(self, op); } else if (PyInt_Check(op)) { - PyObject *name; - int size = PyTuple_GET_SIZE(self->names); - int value = PyArray_PyIntAsInt(op); - int orig_value = value; - + Py_ssize_t i = PyArray_PyIntAsIntp(op); if (PyErr_Occurred()) { return NULL; } - if (value < 0) { - value += size; - } - if (value < 0 || value >= size) { - PyErr_Format(PyExc_IndexError, - "Field index %d out of range.", orig_value); - return NULL; - } - name = PyTuple_GET_ITEM(self->names, value); - retval = descr_subscript(self, name); + return _subscript_by_index(self, i); } else { PyErr_SetString(PyExc_ValueError, "Field key must be an integer, string, or unicode."); return NULL; } - return retval; } static PySequenceMethods descr_as_sequence = { - descr_length, - (binaryfunc)NULL, - descr_repeat, - NULL, NULL, - NULL, /* sq_ass_item */ - NULL, /* ssizessizeobjargproc sq_ass_slice */ - 0, /* sq_contains */ - 0, /* sq_inplace_concat */ - 0, /* sq_inplace_repeat */ + (lenfunc) descr_length, /* sq_length */ + (binaryfunc) NULL, /* sq_concat */ + (ssizeargfunc) descr_repeat, /* sq_repeat */ + (ssizeargfunc) NULL, /* sq_item */ + (ssizessizeargfunc) NULL, /* sq_slice */ + (ssizeobjargproc) NULL, /* sq_ass_item */ + (ssizessizeobjargproc) NULL, /* sq_ass_slice */ + (objobjproc) NULL, /* sq_contains */ + (binaryfunc) NULL, /* sq_inplace_concat */ + (ssizeargfunc) NULL, /* sq_inplace_repeat */ }; static PyMappingMethods descr_as_mapping = { diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 9cefb2ad1d05..7f5ab2c9dd8e 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -298,6 +298,14 @@ def make_dtype(off): dt = make_dtype(np.uint32(0)) np.zeros(1, dtype=dt)[0].item() + def test_fields_by_index(self): + dt = np.dtype([('a', np.int8), ('b', np.float32, 3)]) + assert_dtype_equal(dt[0], np.dtype(np.int8)) + assert_dtype_equal(dt[1], np.dtype((np.float32, 3))) + assert_dtype_equal(dt[-1], dt[1]) + assert_dtype_equal(dt[-2], dt[0]) + assert_raises(IndexError, lambda: dt[-3]) + class TestSubarray(object): def test_single_subarray(self):