Skip to content

Commit

Permalink
Add test for comparison on ordered enums
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Dec 15, 2023
1 parent c2fb610 commit 6650a01
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
14 changes: 10 additions & 4 deletions tiledb/query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def aux_visit_Compare(

variable = self.get_variable_from_node(variable)
value = self.get_value_from_node(value)

invalid_enum_value = False

if self.array.schema.has_attr(variable):
Expand All @@ -261,8 +261,14 @@ def aux_visit_Compare(
enum_values = self.array.enum(enum_label).values()
if value in enum_values:
dt = self.array.enum(enum_label).dtype
else:
invalid_enum_value = True
else:
# This is a workaround for when the user applies a query
# condition onto an invalid enumeration value. Instead of
# using the enum value, toggle `set_use_enumeration` off
# below, and use the index value instead. To force a False
# result every time, use a value that is +1 the number of
# enumeration values.
invalid_enum_value = True
value = len(enum_values) + 1
dt = self.array.attr(variable).dtype
else:
Expand All @@ -275,7 +281,7 @@ def aux_visit_Compare(

pyqc = qc.PyQueryCondition(self.ctx)
self.init_pyqc(pyqc, dtype)(variable, value, op)

if invalid_enum_value:
pyqc.set_use_enumeration(False)

Expand Down
10 changes: 8 additions & 2 deletions tiledb/tests/test_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def test_qc_enumeration(self):
uri = self.path("test_qc_enumeration")
dom = tiledb.Domain(tiledb.Dim(domain=(1, 8), tile=1))
enum1 = tiledb.Enumeration("enmr1", True, [0, 1, 2])
enum2 = tiledb.Enumeration("enmr2", False, ["a", "bb", "ccc"])
enum2 = tiledb.Enumeration("enmr2", True, ["a", "bb", "ccc"])
attr1 = tiledb.Attr("attr1", dtype=np.int32, enum_label="enmr1")
attr2 = tiledb.Attr("attr2", dtype=np.int32, enum_label="enmr2")
schema = tiledb.ArraySchema(
Expand All @@ -864,7 +864,13 @@ def test_qc_enumeration(self):
self.filter_dense(result["attr2"], mask)
== list(enum2.values()).index("bb")
)


mask = A.attr("attr2").fill
result = A.query(cond="attr2 < 'ccc'", attrs=["attr2"])[:]
assert list(enum2.values()).index("ccc") not in self.filter_dense(
result["attr2"], mask
)

result = A.query(cond="attr2 == 'b'", attrs=["attr2"])[:]
assert all(self.filter_dense(result["attr2"], mask) == [])

Expand Down

0 comments on commit 6650a01

Please sign in to comment.