Skip to content

Commit

Permalink
Merge pull request numpy#9947 from eric-wieser/tidy-dtype-indexing
Browse files Browse the repository at this point in the history
MAINT/TST: Tidy dtype indexing
  • Loading branch information
mhvk authored Nov 2, 2017
2 parents c733359 + 9271d81 commit 103f23c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 51 deletions.
117 changes: 66 additions & 51 deletions numpy/core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -3752,11 +3752,9 @@ descr_repeat(PyObject *self, Py_ssize_t length)
return (PyObject *)new;
}

static PyObject *
descr_subscript(PyArray_Descr *self, PyObject *op)
static int
_check_has_fields(PyArray_Descr *self)
{
PyObject *retval;

if (!PyDataType_HASFIELDS(self)) {
PyObject *astr = arraydescr_str(self);
#if defined(NPY_PY3K)
Expand All @@ -3767,74 +3765,91 @@ descr_subscript(PyArray_Descr *self, PyObject *op)
PyErr_Format(PyExc_KeyError,
"There are no fields in dtype %s.", PyBytes_AsString(astr));
Py_DECREF(astr);
return -1;
}
else {
return 0;
}
}

static PyObject *
_subscript_by_name(PyArray_Descr *self, PyObject *op)
{
PyObject *obj = PyDict_GetItem(self->fields, op);
PyObject *descr;
PyObject *s;

if (obj == NULL) {
if (PyUnicode_Check(op)) {
s = PyUnicode_AsUnicodeEscapeString(op);
}
else {
s = op;
}

PyErr_Format(PyExc_KeyError,
"Field named \'%s\' not found.", PyBytes_AsString(s));
if (s != op) {
Py_DECREF(s);
}
return NULL;
}
descr = PyTuple_GET_ITEM(obj, 0);
Py_INCREF(descr);
return descr;
}

static PyObject *
_subscript_by_index(PyArray_Descr *self, Py_ssize_t i)
{
PyObject *name = PySequence_GetItem(self->names, i);
if (name == NULL) {
PyErr_Format(PyExc_IndexError,
"Field index %zd out of range.", i);
return NULL;
}
return _subscript_by_name(self, name);
}

static PyObject *
descr_subscript(PyArray_Descr *self, PyObject *op)
{
if (_check_has_fields(self) < 0) {
return NULL;
}

#if defined(NPY_PY3K)
if (PyUString_Check(op)) {
#else
if (PyUString_Check(op) || PyUnicode_Check(op)) {
#endif
PyObject *obj = PyDict_GetItem(self->fields, op);
PyObject *descr;
PyObject *s;

if (obj == NULL) {
if (PyUnicode_Check(op)) {
s = PyUnicode_AsUnicodeEscapeString(op);
}
else {
s = op;
}

PyErr_Format(PyExc_KeyError,
"Field named \'%s\' not found.", PyBytes_AsString(s));
if (s != op) {
Py_DECREF(s);
}
return NULL;
}
descr = PyTuple_GET_ITEM(obj, 0);
Py_INCREF(descr);
retval = descr;
return _subscript_by_name(self, op);
}
else if (PyInt_Check(op)) {
PyObject *name;
int size = PyTuple_GET_SIZE(self->names);
int value = PyArray_PyIntAsInt(op);
int orig_value = value;

Py_ssize_t i = PyArray_PyIntAsIntp(op);
if (PyErr_Occurred()) {
return NULL;
}
if (value < 0) {
value += size;
}
if (value < 0 || value >= size) {
PyErr_Format(PyExc_IndexError,
"Field index %d out of range.", orig_value);
return NULL;
}
name = PyTuple_GET_ITEM(self->names, value);
retval = descr_subscript(self, name);
return _subscript_by_index(self, i);
}
else {
PyErr_SetString(PyExc_ValueError,
"Field key must be an integer, string, or unicode.");
return NULL;
}
return retval;
}

static PySequenceMethods descr_as_sequence = {
descr_length,
(binaryfunc)NULL,
descr_repeat,
NULL, NULL,
NULL, /* sq_ass_item */
NULL, /* ssizessizeobjargproc sq_ass_slice */
0, /* sq_contains */
0, /* sq_inplace_concat */
0, /* sq_inplace_repeat */
(lenfunc) descr_length, /* sq_length */
(binaryfunc) NULL, /* sq_concat */
(ssizeargfunc) descr_repeat, /* sq_repeat */
(ssizeargfunc) NULL, /* sq_item */
(ssizessizeargfunc) NULL, /* sq_slice */
(ssizeobjargproc) NULL, /* sq_ass_item */
(ssizessizeobjargproc) NULL, /* sq_ass_slice */
(objobjproc) NULL, /* sq_contains */
(binaryfunc) NULL, /* sq_inplace_concat */
(ssizeargfunc) NULL, /* sq_inplace_repeat */
};

static PyMappingMethods descr_as_mapping = {
Expand Down
8 changes: 8 additions & 0 deletions numpy/core/tests/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,14 @@ def make_dtype(off):
dt = make_dtype(np.uint32(0))
np.zeros(1, dtype=dt)[0].item()

def test_fields_by_index(self):
dt = np.dtype([('a', np.int8), ('b', np.float32, 3)])
assert_dtype_equal(dt[0], np.dtype(np.int8))
assert_dtype_equal(dt[1], np.dtype((np.float32, 3)))
assert_dtype_equal(dt[-1], dt[1])
assert_dtype_equal(dt[-2], dt[0])
assert_raises(IndexError, lambda: dt[-3])


class TestSubarray(object):
def test_single_subarray(self):
Expand Down

0 comments on commit 103f23c

Please sign in to comment.