Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix view_attr not being respected by __getitem__ and subarray #2139

Merged
merged 5 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tiledb/dense_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def __getitem__(self, selection):

"""
if self.view_attr:
result = self.subarray(selection, attrs=(self.view_attr,))
return result[self.view_attr]
return self.subarray(selection)

result = self.subarray(selection)
for i in range(self.schema.nattr):
Expand Down Expand Up @@ -269,6 +268,8 @@ def subarray(self, selection, attrs=None, cond=None, coords=False, order=None):
attr = self.schema.attr(0)
if attr.isanon:
return out[attr._internal_name]
if self.view_attr is not None:
return out[self.view_attr]
return out

def _read_dense_subarray(
Expand Down
10 changes: 9 additions & 1 deletion tiledb/sparse_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections import OrderedDict

import numpy as np
Expand Down Expand Up @@ -292,6 +293,9 @@ def __getitem__(self, selection):
>>> # A[5.0:579.9]

"""
if self.view_attr is not None:
return self.subarray(selection)

result = self.subarray(selection)
for i in range(self.schema.nattr):
attr = self.schema.attr(i)
Expand Down Expand Up @@ -506,7 +510,11 @@ def subarray(self, selection, coords=True, attrs=None, cond=None, order=None):

attr_names = list()

if attrs is None:
if self.view_attr is not None:
if attrs is not None:
warnings.warn("view_attr is set, ignoring attrs parameter", UserWarning)
attr_names.extend(self.view_attr)
elif attrs is None:
attr_names.extend(
self.schema.attr(i)._internal_name for i in range(self.schema.nattr)
)
Expand Down
47 changes: 45 additions & 2 deletions tiledb/tests/test_libtiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .common import (
DiskTestCase,
assert_captured,
assert_dict_arrays_equal,
assert_subarrays_equal,
assert_unordered_equal,
fx_sparse_cell_order, # noqa: F401
Expand Down Expand Up @@ -923,8 +924,8 @@ def assert_ts(timestamp, result):
assert_ts((timestamps[2], None), A * 3)
assert_ts((timestamps[2], None), A * 3)

def test_open_attr(self):
uri = self.path("test_open_attr")
def test_open_attr_dense(self):
uri = self.path("test_open_attr_dense")
schema = tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(name="dim0", dtype=np.uint32, domain=(1, 4))
Expand All @@ -949,6 +950,48 @@ def test_open_attr(self):
assert_array_equal(A[:], np.array((1, 2, 3, 4)))
assert list(A.multi_index[:].keys()) == ["x"]

with tiledb.open(uri, attr="x") as A:
q = A.query(cond="x <= 3")
expected = np.array([1, 2, 3, schema.attr("x").fill[0]])
assert_array_equal(q[:], expected)

def test_open_attr_sparse(self):
uri = self.path("test_open_attr_sparse")
schema = tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(name="dim0", dtype=np.uint32, domain=(1, 4))
),
attrs=(
tiledb.Attr(name="x", dtype=np.int32),
tiledb.Attr(name="y", dtype=np.int32),
),
sparse=True,
)
tiledb.Array.create(uri, schema)

with tiledb.open(uri, mode="w") as A:
A[[1, 2, 3, 4]] = {"x": np.array((1, 2, 3, 4)), "y": np.array((5, 6, 7, 8))}

with self.assertRaises(KeyError):
tiledb.open(uri, attr="z")

with self.assertRaises(KeyError):
tiledb.open(uri, attr="dim0")

with tiledb.open(uri, attr="x") as A:
expected = OrderedDict(
[("dim0", np.array([1, 2, 3, 4])), ("x", np.array([1, 2, 3, 4]))]
)
assert_dict_arrays_equal(A[:], expected)
assert list(A.multi_index[:].keys()) == ["dim0", "x"]

with tiledb.open(uri, attr="x") as A:
q = A.query(cond="x <= 3")
expected = OrderedDict(
[("dim0", np.array([1, 2, 3])), ("x", np.array([1, 2, 3]))]
)
assert_dict_arrays_equal(q[:], expected)

def test_ncell_attributes(self):
dom = tiledb.Domain(tiledb.Dim(domain=(0, 9), tile=10, dtype=int))
attr = tiledb.Attr(dtype=[("", np.int32), ("", np.int32), ("", np.int32)])
Expand Down
Loading