Skip to content

Commit

Permalink
support storage of array type
Browse files Browse the repository at this point in the history
  • Loading branch information
yangzq50 committed Jan 16, 2025
1 parent 23aa288 commit a539604
Show file tree
Hide file tree
Showing 10 changed files with 499 additions and 189 deletions.
52 changes: 34 additions & 18 deletions src/executor/operator/physical_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import stl;
import logical_type;
import embedding_info;
import sparse_info;
import array_info;
import status;
import buffer_manager;
import default_values;
Expand Down Expand Up @@ -208,6 +209,7 @@ SizeT PhysicalExport::ExportToCSV(QueryContext *query_context, ExportOperatorSta
for (SizeT select_column_idx = 0; select_column_idx < select_column_count; ++select_column_idx) {
Value v = column_vectors[select_column_idx].GetValue(row_idx);
switch (v.type().type()) {
case LogicalType::kArray:
case LogicalType::kEmbedding:
case LogicalType::kMultiVector:
case LogicalType::kTensor:
Expand Down Expand Up @@ -497,6 +499,10 @@ SizeT PhysicalExport::ExportToFVECS(QueryContext *query_context, ExportOperatorS
return row_count;
}

SharedPtr<arrow::DataType> GetArrowType(const DataType &column_data_type);

SharedPtr<arrow::Array> BuildArrowArray(const ColumnDef *column_def, const ColumnVector &column_vector, const Vector<u32> &block_rows_for_output);

SizeT PhysicalExport::ExportToPARQUET(QueryContext *query_context, ExportOperatorState *export_op_state) {
const Vector<SharedPtr<ColumnDef>> &column_defs = table_entry_->column_defs();
Vector<ColumnID> select_columns;
Expand All @@ -516,7 +522,7 @@ SizeT PhysicalExport::ExportToPARQUET(QueryContext *query_context, ExportOperato
Vector<SharedPtr<arrow::Field>> fields;
for (auto &column_id : select_columns) {
ColumnDef *column_def = column_defs[column_id].get();
auto arrow_type = GetArrowType(column_def);
auto arrow_type = GetArrowType(*(column_def->type()));
fields.emplace_back(::arrow::field(column_def->name(), std::move(arrow_type)));
}

Expand Down Expand Up @@ -664,9 +670,8 @@ SizeT PhysicalExport::ExportToPARQUET(QueryContext *query_context, ExportOperato
return row_count;
}

SharedPtr<arrow::DataType> PhysicalExport::GetArrowType(ColumnDef *column_def) {
auto &column_type = column_def->type();
switch (const auto column_logical_type = column_type->type(); column_logical_type) {
SharedPtr<arrow::DataType> GetArrowType(const DataType &column_data_type) {
switch (const auto column_logical_type = column_data_type.type(); column_logical_type) {
case LogicalType::kBoolean:
return arrow::boolean();
case LogicalType::kTinyInt:
Expand Down Expand Up @@ -695,8 +700,7 @@ SharedPtr<arrow::DataType> PhysicalExport::GetArrowType(ColumnDef *column_def) {
case LogicalType::kVarchar:
return arrow::utf8();
case LogicalType::kSparse: {
const auto *sparse_info = static_cast<const SparseInfo *>(column_def->type()->type_info().get());

const auto *sparse_info = static_cast<const SparseInfo *>(column_data_type.type_info().get());
SharedPtr<arrow::DataType> index_type;
Optional<SharedPtr<arrow::DataType>> value_type = None;
switch (sparse_info->IndexType()) {
Expand Down Expand Up @@ -775,7 +779,7 @@ SharedPtr<arrow::DataType> PhysicalExport::GetArrowType(ColumnDef *column_def) {
case LogicalType::kMultiVector:
case LogicalType::kTensor:
case LogicalType::kTensorArray: {
const auto *embedding_info = static_cast<const EmbeddingInfo *>(column_type->type_info().get());
const auto *embedding_info = static_cast<const EmbeddingInfo *>(column_data_type.type_info().get());
const SizeT dimension = embedding_info->Dimension();
SharedPtr<arrow::DataType> arrow_embedding_elem_type;
switch (embedding_info->Type()) {
Expand Down Expand Up @@ -839,11 +843,15 @@ SharedPtr<arrow::DataType> PhysicalExport::GetArrowType(ColumnDef *column_def) {
UnrecoverableError("Unreachable code!");
return {};
}
case LogicalType::kArray: {
const auto *array_info = static_cast<const ArrayInfo *>(column_data_type.type_info().get());
auto element_arrow_type = GetArrowType(array_info->ElemType());
return ::arrow::list(std::move(element_arrow_type));
}
case LogicalType::kRowID:
case LogicalType::kInterval:
case LogicalType::kHugeInt:
case LogicalType::kDecimal:
case LogicalType::kArray:
case LogicalType::kTuple:
case LogicalType::kPoint:
case LogicalType::kLine:
Expand All @@ -863,12 +871,9 @@ SharedPtr<arrow::DataType> PhysicalExport::GetArrowType(ColumnDef *column_def) {
return nullptr;
}

SharedPtr<arrow::Array>
PhysicalExport::BuildArrowArray(ColumnDef *column_def, const ColumnVector &column_vector, const Vector<u32> &block_rows_for_output) {
SharedPtr<arrow::ArrayBuilder> array_builder = nullptr;
auto &column_type = column_def->type();

switch (const auto column_logical_type = column_type->type(); column_logical_type) {
SharedPtr<arrow::ArrayBuilder> GetArrowBuilder(const DataType &column_type) {
SharedPtr<arrow::ArrayBuilder> array_builder{};
switch (const auto column_logical_type = column_type.type(); column_logical_type) {
case LogicalType::kBoolean: {
array_builder = MakeShared<arrow::BooleanBuilder>();
break;
Expand Down Expand Up @@ -923,7 +928,7 @@ PhysicalExport::BuildArrowArray(ColumnDef *column_def, const ColumnVector &colum
break;
}
case LogicalType::kSparse: {
const auto *sparse_info = static_cast<const SparseInfo *>(column_def->type()->type_info().get());
const auto *sparse_info = static_cast<const SparseInfo *>(column_type.type_info().get());
SharedPtr<arrow::ArrayBuilder> index_builder = nullptr;
SharedPtr<arrow::ArrayBuilder> value_builder = nullptr;
switch (sparse_info->IndexType()) {
Expand Down Expand Up @@ -1016,7 +1021,7 @@ PhysicalExport::BuildArrowArray(ColumnDef *column_def, const ColumnVector &colum
case LogicalType::kMultiVector:
case LogicalType::kTensor:
case LogicalType::kTensorArray: {
const auto *embedding_info = static_cast<const EmbeddingInfo *>(column_type->type_info().get());
const auto *embedding_info = static_cast<const EmbeddingInfo *>(column_type.type_info().get());
SharedPtr<::arrow::ArrayBuilder> embedding_element_builder;
switch (embedding_info->Type()) {
case EmbeddingDataType::kElemBit: {
Expand Down Expand Up @@ -1085,11 +1090,16 @@ PhysicalExport::BuildArrowArray(ColumnDef *column_def, const ColumnVector &colum
UnrecoverableError("Unreachable code!");
break;
}
case LogicalType::kArray: {
const auto *array_info = static_cast<const ArrayInfo *>(column_type.type_info().get());
auto element_arrow_builder = GetArrowBuilder(array_info->ElemType());
array_builder = MakeShared<::arrow::ListBuilder>(arrow::DefaultMemoryPool(), element_arrow_builder);
break;
}
case LogicalType::kRowID:
case LogicalType::kInterval:
case LogicalType::kHugeInt:
case LogicalType::kDecimal:
case LogicalType::kArray:
case LogicalType::kTuple:
case LogicalType::kPoint:
case LogicalType::kLine:
Expand All @@ -1106,10 +1116,16 @@ PhysicalExport::BuildArrowArray(ColumnDef *column_def, const ColumnVector &colum
UnrecoverableError(error_message);
}
}
return array_builder;
}

SharedPtr<arrow::Array> BuildArrowArray(const ColumnDef *column_def, const ColumnVector &column_vector, const Vector<u32> &block_rows_for_output) {
auto &column_type = column_def->type();
SharedPtr<arrow::ArrayBuilder> array_builder = GetArrowBuilder(*column_type);

for (const auto idx : block_rows_for_output) {
auto value = column_vector.GetValue(idx);
value.AppendToArrowArray(column_type, array_builder);
value.AppendToArrowArray(*column_type, array_builder.get());
}

SharedPtr<arrow::Array> array;
Expand Down
5 changes: 0 additions & 5 deletions src/executor/operator/physical_export.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,6 @@ public:

inline char delimiter() const { return delimiter_; }

private:
SharedPtr<arrow::DataType> GetArrowType(ColumnDef *column_def);

SharedPtr<arrow::Array> BuildArrowArray(ColumnDef *column_def, const ColumnVector &column_vectors, const Vector<u32> &block_rows_for_output);

private:
SharedPtr<Vector<String>> output_names_{};
SharedPtr<Vector<SharedPtr<DataType>>> output_types_{};
Expand Down
2 changes: 2 additions & 0 deletions src/parser/type/info/array_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class ArrayInfo : public TypeInfo {
// Array always costs 8 bytes. Real data is stored in heap memory.
[[nodiscard]] size_t Size() const override { return sizeof(ArrayType); }

[[nodiscard]] size_t ElemSize() const { return elem_type_.Size(); }

[[nodiscard]] nlohmann::json Serialize() const override;

[[nodiscard]] std::string ToString() const override;
Expand Down
3 changes: 2 additions & 1 deletion src/parser/type/internal_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "type/complex/array_type.h"
#include "type/complex/embedding_type.h"
#include "type/complex/multi_vector_type.h"
#include "type/complex/row_id.h"
Expand Down Expand Up @@ -71,7 +72,7 @@ using TimestampT = TimestampType;
using IntervalT = IntervalType;

// Nest types
using ArrayT = std::vector<Value>;
using ArrayT = ArrayType;
using TupleT = std::vector<Value>;
// using ArrayT = ArrayType;
// using TupleT = TupleType;
Expand Down
Loading

0 comments on commit a539604

Please sign in to comment.