From 8188ccf2a08026979177fc3ab8e6b7e13f6420cf Mon Sep 17 00:00:00 2001 From: Joachim Moeyens Date: Thu, 2 Nov 2023 20:29:30 -0700 Subject: [PATCH 1/5] Only check non-empty tables's attributes for equality when concatenating tables If all tables are empty then the first table's class and attributes are used to check for equality. --- quivr/concat.py | 40 ++++++++++++++++++++++++++-------------- test/test_concat.py | 18 ++++++++++++++++-- test/test_tables.py | 6 ++++++ 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/quivr/concat.py b/quivr/concat.py index 735a30f..5e1ea9f 100644 --- a/quivr/concat.py +++ b/quivr/concat.py @@ -22,24 +22,36 @@ def concatenate(values: Iterator[tables.AnyTable], defrag: bool = True) -> table memory. Defaults to True. """ + if len(values) == 0: + raise ValueError("No values to concatenate") + batches = [] - first = True + first_full = False + + # Find the first non-empty table to get the class for v in values: - batches += v.table.to_batches() - if first: + if not first_full and len(v) > 0: first_cls = v.__class__ first_val = v - first = False - else: - if v.__class__ != first_cls: - raise errors.TablesNotCompatibleError("All tables must be the same class to concatenate") - if not first_val._attr_equal(v): - raise errors.TablesNotCompatibleError( - "All tables must have the same attribute values to concatenate" - ) - - if first: - raise ValueError("No values to concatenate") + first_full = True + break + + # No non-empty tables found so lets pick the first table + # to get the class and attributes + if not first_full: + first_cls = values[0].__class__ + first_val = values[0] + + # Scan the values and now make sure they are all the same class + # as the first non-empty table + for v in values: + batches += v.table.to_batches() + if v.__class__ != first_cls: + raise errors.TablesNotCompatibleError("All tables must be the same class to concatenate") + if not first_val._attr_equal(v) and len(v) > 0: + raise errors.TablesNotCompatibleError( + "All non-empty tables must have the same attribute values to concatenate" + ) if len(batches) == 0: return first_cls.empty() diff --git a/test/test_concat.py b/test/test_concat.py index 7f696d1..98ee19b 100644 --- a/test/test_concat.py +++ b/test/test_concat.py @@ -4,7 +4,7 @@ import quivr as qv -from .test_tables import Pair, TableWithAttributes, Wrapper +from .test_tables import Pair, TableWithAttributes, TableWithDefaultAttributes, Wrapper def test_concatenate(): @@ -96,13 +96,27 @@ def test_concatenate_different_attrs(): t2 = TableWithAttributes.from_kwargs(x=[3], y=[4], attrib="bar") with pytest.raises( - qv.TablesNotCompatibleError, match="All tables must have the same attribute values to concatenate" + qv.TablesNotCompatibleError, + match="All non-empty tables must have the same attribute values to concatenate", ): qv.concatenate([t1, t2]) +def test_concatenate_default_attrs_empty(): + t1 = TableWithDefaultAttributes.empty() # This will default to attrib="foo" + t2 = TableWithDefaultAttributes.from_kwargs(x=[3], y=[4], attrib="bar") + t3 = TableWithDefaultAttributes.from_kwargs(x=[3], y=[4], attrib="bar") + have = qv.concatenate([t1, t2, t3]) + assert have.attrib == "bar" + + def test_concatenate_same_attrs(): t1 = TableWithAttributes.from_kwargs(x=[1], y=[2], attrib="foo") t2 = TableWithAttributes.from_kwargs(x=[3], y=[4], attrib="foo") have = qv.concatenate([t1, t2]) assert have.attrib == "foo" + + +def test_concatenate_no_values(): + with pytest.raises(ValueError, match="No values to concatenate"): + qv.concatenate([]) diff --git a/test/test_tables.py b/test/test_tables.py index e42e40b..1d775ce 100644 --- a/test/test_tables.py +++ b/test/test_tables.py @@ -890,6 +890,12 @@ class TableWithAttributes(qv.Table): attrib = qv.StringAttribute() +class TableWithDefaultAttributes(qv.Table): + x = qv.Int64Column() + y = qv.Int64Column() + attrib = qv.StringAttribute(default="foo") + + class TestTableAttributes: def test_from_dataframe(self): have = TableWithAttributes.from_dataframe( From d3c6d58a740e2266cdf3e47b5ae3fa4ac486863a Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 20 May 2024 10:16:59 -0400 Subject: [PATCH 2/5] Formatting --- pyproject.toml | 2 +- quivr/columns.py | 204 ++++++++++++------------------------ quivr/experimental/shmem.py | 3 +- quivr/tables.py | 3 +- 4 files changed, 71 insertions(+), 141 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b6864bd..3f4eae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ fix = [ "ruff ./quivr ./test --fix" ] lint = [ - "ruff ./quivr ./test", + "ruff check ./quivr ./test", "black --check ./quivr ./test", "isort --check-only ./quivr ./test" ] diff --git a/quivr/columns.py b/quivr/columns.py index 9e0d3cd..25b36ef 100644 --- a/quivr/columns.py +++ b/quivr/columns.py @@ -72,12 +72,10 @@ def __init__( raise errors.InvalidColumnDefault(self.default, self.dtype) from e @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Array: ... def __get__(self, obj: Union[tables.Table, None], objtype: type) -> Union[Self, pa.Array]: """ @@ -226,12 +224,10 @@ def _set_on_pyarrow_table(self, table: pa.Table, value: T) -> pa.Table: return table @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> T: - ... + def __get__(self, obj: tables.Table, objtype: type) -> T: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, T]: if obj is None: @@ -278,12 +274,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Int8Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Int8Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int8Array]: if obj is None: @@ -310,12 +304,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Int16Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Int16Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int16Array]: if obj is None: @@ -342,12 +334,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Int32Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Int32Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int32Array]: if obj is None: @@ -374,12 +364,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Int64Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Int64Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int64Array]: if obj is None: @@ -406,12 +394,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt8Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt8Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt8Array]: if obj is None: @@ -438,12 +424,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt16Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt16Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt16Array]: if obj is None: @@ -470,12 +454,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt32Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt32Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt32Array]: if obj is None: @@ -502,12 +484,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt64Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt64Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt64Array]: if obj is None: @@ -542,12 +522,10 @@ def __init__( super().__init__(pa.float16(), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.HalfFloatArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.HalfFloatArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.HalfFloatArray]: if obj is None: @@ -574,12 +552,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.FloatArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.FloatArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.FloatArray]: if obj is None: @@ -606,12 +582,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.DoubleArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.DoubleArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.DoubleArray]: if obj is None: @@ -636,12 +610,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.BooleanArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.BooleanArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[pa.BooleanArray, Self]: if obj is None: @@ -672,12 +644,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.StringArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.StringArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[pa.StringArray, Self]: if obj is None: @@ -703,12 +673,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeBinaryArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeBinaryArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.LargeBinaryArray]: if obj is None: @@ -734,12 +702,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeStringArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeStringArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.LargeStringArray]: if obj is None: @@ -764,12 +730,10 @@ def __init__( super().__init__(pa.date32(), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Date32Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Date32Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Date32Array]: if obj is None: @@ -794,12 +758,10 @@ def __init__( super().__init__(pa.date64(), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Date64Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Date64Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Date64Array]: if obj is None: @@ -847,12 +809,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.TimestampArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.TimestampArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.TimestampArray]: if obj is None: @@ -891,12 +851,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Time32Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Time32Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Time32Array]: if obj is None: @@ -935,12 +893,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Time64Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Time64Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Time64Array]: if obj is None: @@ -969,12 +925,10 @@ def __init__( super().__init__(pa.duration(unit), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.DurationArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.DurationArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.DurationArray]: if obj is None: @@ -1005,12 +959,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.MonthDayNanoIntervalArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.MonthDayNanoIntervalArray: ... def __get__( self, obj: Optional[tables.Table], objtype: type @@ -1035,12 +987,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.BinaryArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.BinaryArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.BinaryArray]: if obj is None: @@ -1073,12 +1023,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.FixedSizeBinaryArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.FixedSizeBinaryArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.FixedSizeBinaryArray]: if obj is None: @@ -1127,12 +1075,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Decimal128Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Decimal128Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Decimal128Array]: if obj is None: @@ -1171,12 +1117,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.Decimal256Array: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.Decimal256Array: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Decimal256Array]: if obj is None: @@ -1200,12 +1144,10 @@ def __init__( super().__init__(pa.null(), nullable=True, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.NullArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.NullArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.NullArray]: if obj is None: @@ -1243,12 +1185,10 @@ def __init__( super().__init__(pa.list_(value_type, -1), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.ListArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.ListArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.ListArray]: if obj is None: @@ -1288,12 +1228,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.FixedSizeListArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.FixedSizeListArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.FixedSizeListArray]: if obj is None: @@ -1330,12 +1268,10 @@ def __init__( super().__init__(pa.large_list(value_type), nullable=nullable, metadata=metadata, validator=validator) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeListArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeListArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.LargeListArray]: if obj is None: @@ -1373,12 +1309,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.MapArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.MapArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.MapArray]: if obj is None: @@ -1421,12 +1355,10 @@ def __init__( ) @overload - def __get__(self, obj: None, objtype: type) -> Self: - ... + def __get__(self, obj: None, objtype: type) -> Self: ... @overload - def __get__(self, obj: tables.Table, objtype: type) -> pa.DictionaryArray: - ... + def __get__(self, obj: tables.Table, objtype: type) -> pa.DictionaryArray: ... def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.DictionaryArray]: if obj is None: diff --git a/quivr/experimental/shmem.py b/quivr/experimental/shmem.py index 07ef920..3862080 100644 --- a/quivr/experimental/shmem.py +++ b/quivr/experimental/shmem.py @@ -94,8 +94,7 @@ class Partitioning(abc.ABC): """ @abc.abstractmethod - def partition(self, table: T) -> Iterator[T]: - ... + def partition(self, table: T) -> Iterator[T]: ... def partition_func(f: Callable[[T], Iterator[T]]) -> Partitioning: diff --git a/quivr/tables.py b/quivr/tables.py index e5c99ce..b0e0a06 100644 --- a/quivr/tables.py +++ b/quivr/tables.py @@ -44,8 +44,7 @@ class ArrowArrayProvider(Protocol): A Protocol which describes objects that support the Arrow custom array extension protocol. """ - def __arrow_array__(self, type: Optional[pa.DataType] = None) -> pa.Array: - ... + def __arrow_array__(self, type: Optional[pa.DataType] = None) -> pa.Array: ... AttributeValueType: TypeAlias = Union[int, float, str] From ce326de004bfea6cbf2346b85dbf582e408a0ae3 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 20 May 2024 10:32:25 -0400 Subject: [PATCH 3/5] typing --- quivr/concat.py | 14 ++++++++------ quivr/tables.py | 3 ++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/quivr/concat.py b/quivr/concat.py index 5e1ea9f..a014d51 100644 --- a/quivr/concat.py +++ b/quivr/concat.py @@ -1,4 +1,4 @@ -from typing import Iterator +from typing import Iterator, List import pyarrow as pa @@ -22,14 +22,16 @@ def concatenate(values: Iterator[tables.AnyTable], defrag: bool = True) -> table memory. Defaults to True. """ - if len(values) == 0: + values_list: List[tables.AnyTable] = list(values) + + if len(values_list) == 0: raise ValueError("No values to concatenate") batches = [] first_full = False # Find the first non-empty table to get the class - for v in values: + for v in values_list: if not first_full and len(v) > 0: first_cls = v.__class__ first_val = v @@ -39,12 +41,12 @@ def concatenate(values: Iterator[tables.AnyTable], defrag: bool = True) -> table # No non-empty tables found so lets pick the first table # to get the class and attributes if not first_full: - first_cls = values[0].__class__ - first_val = values[0] + first_cls = values_list[0].__class__ + first_val = values_list[0] # Scan the values and now make sure they are all the same class # as the first non-empty table - for v in values: + for v in values_list: batches += v.table.to_batches() if v.__class__ != first_cls: raise errors.TablesNotCompatibleError("All tables must be the same class to concatenate") diff --git a/quivr/tables.py b/quivr/tables.py index b0e0a06..ceba67a 100644 --- a/quivr/tables.py +++ b/quivr/tables.py @@ -49,10 +49,11 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None) -> pa.Array: ... AttributeValueType: TypeAlias = Union[int, float, str] DataSourceType: TypeAlias = Union[ - pa.Array, list[Any], "Table", pd.Series, npt.NDArray[Any], ArrowArrayProvider + pa.Array, list[Any], "Table", pd.Series[Any], npt.NDArray[Any], ArrowArrayProvider ] AnyTable = TypeVar("AnyTable", bound="Table") + # If a table uses any of the following names, it will break quivr # internals entirely, so they must be rejected. _FORBIDDEN_COLUMN_NAMES = { From c61619696e1d5eefbd4ee9ebd742f2e83b330796 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 20 May 2024 11:00:54 -0400 Subject: [PATCH 4/5] works now --- quivr/concat.py | 4 ++-- quivr/linkage.py | 4 ++-- quivr/tables.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/quivr/concat.py b/quivr/concat.py index a014d51..54e0b28 100644 --- a/quivr/concat.py +++ b/quivr/concat.py @@ -1,11 +1,11 @@ -from typing import Iterator, List +from typing import Iterable, List import pyarrow as pa from . import defragment, errors, tables -def concatenate(values: Iterator[tables.AnyTable], defrag: bool = True) -> tables.AnyTable: +def concatenate(values: Iterable[tables.AnyTable], defrag: bool = True) -> tables.AnyTable: """Concatenate a collection of Tables into a single Table. All input Tables be of the same class, and have the same attribute diff --git a/quivr/linkage.py b/quivr/linkage.py index 9fbcbbb..b5acba7 100644 --- a/quivr/linkage.py +++ b/quivr/linkage.py @@ -422,12 +422,12 @@ def _concatenate_linkage_components( right_keys: List[pa.Array], ) -> Tuple[LeftTable, RightTable, pa.Array, pa.Array]: try: - left_table: LeftTable = concat.concatenate(left_tables) # type: ignore + left_table: LeftTable = concat.concatenate(left_tables) except errors.TablesNotCompatibleError as e: raise errors.LinkageCombinationError("Left tables are not compatible") from e try: - right_table: RightTable = concat.concatenate(right_tables) # type: ignore + right_table: RightTable = concat.concatenate(right_tables) except errors.TablesNotCompatibleError as e: raise errors.LinkageCombinationError("Right tables are not compatible") from e diff --git a/quivr/tables.py b/quivr/tables.py index ceba67a..1bac858 100644 --- a/quivr/tables.py +++ b/quivr/tables.py @@ -49,7 +49,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None) -> pa.Array: ... AttributeValueType: TypeAlias = Union[int, float, str] DataSourceType: TypeAlias = Union[ - pa.Array, list[Any], "Table", pd.Series[Any], npt.NDArray[Any], ArrowArrayProvider + pa.Array, list[Any], "Table", pd.Series[Any], npt.NDArray[Any], "ArrowArrayProvider" ] AnyTable = TypeVar("AnyTable", bound="Table") From c488882646e4162a67576d33199925a47b3dc249 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 20 May 2024 11:11:53 -0400 Subject: [PATCH 5/5] passing locally --- quivr/tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quivr/tables.py b/quivr/tables.py index 1bac858..9eeada5 100644 --- a/quivr/tables.py +++ b/quivr/tables.py @@ -49,7 +49,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None) -> pa.Array: ... AttributeValueType: TypeAlias = Union[int, float, str] DataSourceType: TypeAlias = Union[ - pa.Array, list[Any], "Table", pd.Series[Any], npt.NDArray[Any], "ArrowArrayProvider" + pa.Array, list[Any], "Table", "pd.Series[Any]", npt.NDArray[Any], "ArrowArrayProvider" ] AnyTable = TypeVar("AnyTable", bound="Table")