Skip to content

Commit

Permalink
RowVector integration for PyVelox (facebookincubator#6533)
Browse files Browse the repository at this point in the history
Summary:
This PR includes `RowVector` creation in PyVelox. Also this is an important feature for the continuation of facebookincubator#5023.

Pull Request resolved: facebookincubator#6533

Reviewed By: pedroerp

Differential Revision: D51029745

Pulled By: kgpai

fbshipit-source-id: 4fa27ae3c3cf079f046d4bae8076d9e0f9ad6bc3
  • Loading branch information
vibhatha authored and facebook-github-bot committed Nov 9, 2023
1 parent d30bd47 commit 58979b1
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
99 changes: 99 additions & 0 deletions pyvelox/pyvelox.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once

#include <cassert>

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
Expand All @@ -27,6 +30,7 @@
#include <velox/parse/TypeResolver.h>
#include <velox/type/Type.h>
#include <velox/type/Variant.h>
#include <velox/vector/ComplexVector.h>
#include <velox/vector/DictionaryVector.h>
#include <velox/vector/FlatVector.h>
#include "folly/json.h"
Expand Down Expand Up @@ -109,6 +113,32 @@ inline velox::variant pyToVariant(const py::handle& obj, const Type& dtype) {
}
}

inline void checkRowVectorBounds(const RowVectorPtr& v, vector_size_t idx) {
if (idx < 0 || size_t(idx) >= v->childrenSize()) {
throw std::out_of_range("Index out of range");
}
}

bool compareRowVector(const RowVectorPtr& u, const RowVectorPtr& v) {
CompareFlags compFlags;
compFlags.nullHandlingMode = CompareFlags::NullHandlingMode::NoStop;
compFlags.equalsOnly = true;
if (u->size() != v->size()) {
return false;
}
for (size_t i = 0; i < u->size(); i++) {
if (u->compare(v.get(), i, i, compFlags) != 0) {
return false;
}
}

return true;
}

inline std::string rowVectorToString(const RowVectorPtr& vector) {
return vector->toString(0, vector->size());
}

static VectorPtr pyToConstantVector(
const py::handle& obj,
vector_size_t length,
Expand Down Expand Up @@ -504,6 +534,75 @@ static void addVectorBindings(
std::move(baseVector),
PyVeloxContext::getSingletonInstance().pool());
});

m.def(
"row_vector",
[](std::vector<std::string>& names,
std::vector<VectorPtr>& children,
const std::optional<py::dict>& nullabilityDict) {
if (children.size() == 0 || names.size() == 0) {
throw py::value_error("RowVector must have children.");
}
std::vector<std::shared_ptr<const Type>> childTypes;
childTypes.reserve(children.size());

size_t vectorSize = children[0]->size();
for (int i = 0; i < children.size(); i++) {
if (i > 0 && children[i]->size() != vectorSize) {
PyErr_SetString(PyExc_ValueError, "Each child must have same size");
throw py::error_already_set();
}
childTypes.push_back(children[i]->type());
}
auto rowType = ROW(std::move(names), std::move(childTypes));

BufferPtr nullabilityBuffer = nullptr;
if (nullabilityDict.has_value()) {
auto nullabilityValues = nullabilityDict.value();
nullabilityBuffer = AlignedBuffer::allocate<bool>(
vectorSize, PyVeloxContext::getSingletonInstance().pool(), true);
for (const auto&& item : nullabilityValues) {
auto row = item.first;
auto nullability = item.second;
if (!py::isinstance<py::int_>(row) ||
!py::isinstance<py::bool_>(nullability)) {
throw py::type_error(
"Nullability must be a dictionary, rowId in int and nullability in boolean.");
}
int rowId = py::cast<int>(row);
if (!(rowId >= 0 && rowId < vectorSize)) {
throw py::type_error("Nullability index out of bounds.");
}
bool nullabilityVal = py::cast<bool>(nullability);
bits::setBit(
nullabilityBuffer->asMutable<uint64_t>(),
rowId,
bits::kNull ? nullabilityVal : !nullabilityVal);
}
}

return std::make_shared<RowVector>(
PyVeloxContext::getSingletonInstance().pool(),
rowType,
nullabilityBuffer,
vectorSize,
children);
},
py::arg("names"),
py::arg("children"),
py::arg("nullability") = std::nullopt);

py::class_<RowVector, BaseVector, RowVectorPtr>(
m, "RowVector", py::module_local(asModuleLocalDefinitions))
.def(
"__len__",
[](RowVectorPtr& v) {
return v->childrenSize() > 0 ? v->childAt(0)->size() : 0;
})
.def("__str__", [](RowVectorPtr& v) { return rowVectorToString(v); })
.def("__eq__", [](RowVectorPtr& u, RowVectorPtr& v) {
return compareRowVector(u, v);
});
}

static void addExpressionBindings(
Expand Down
62 changes: 62 additions & 0 deletions pyvelox/test/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,65 @@ def test_roundtrip_conversion(self):
self.assertTrue(velox_vector.dtype(), expected_type)
for i in range(0, len(data)):
self.assertEqual(velox_vector[i], data[i])

def test_row_vector_basic(self):
vals = [
pv.from_list([1, 2, 3]),
pv.from_list([4.0, 5.0, 6.0]),
pv.from_list(["a", "b", "c"]),
]

col_names = ["x", "y", "z"]
rw = pv.row_vector(col_names, vals)
rw_str = str(rw)
expected_str = "0: {1, 4, a}\n1: {2, 5, b}\n2: {3, 6, c}"
assert expected_str == rw_str

def test_row_vector_with_nulls(self):
vals = [
pv.from_list([1, 2, 3, 1, 2]),
pv.from_list([4, 5, 6, 4, 5]),
pv.from_list([7, 8, 9, 7, 8]),
pv.from_list([10, 11, 12, 10, 11]),
]

col_names = ["a", "b", "c", "d"]
rw = pv.row_vector(col_names, vals, {0: True, 2: True})
rw_str = str(rw)
expected_str = (
"0: null\n1: {2, 5, 8, 11}\n2: null\n3: {1, 4, 7, 10}\n4: {2, 5, 8, 11}"
)
assert expected_str == rw_str

def test_row_vector_comparison(self):
u = [
pv.from_list([1, 2, 3]),
pv.from_list([7, 4, 9]),
pv.from_list([10, 11, 12]),
]

v = [
pv.from_list([1, 2, 3]),
pv.from_list([7, 8, 9]),
pv.from_list([10, 11, 12]),
]

w = [
pv.from_list([1, 2, 3]),
pv.from_list([7, 8, 9]),
]

u_names = ["a", "b", "c"]
w_names = ["x", "y"]
u_rw = pv.row_vector(u_names, u)
v_rw = pv.row_vector(u_names, v)
w_rw = pv.row_vector(w_names, w)
y_rw = pv.row_vector(u_names, u)
x1_rw = pv.row_vector(u_names, u, {0: True, 2: True})
x2_rw = pv.row_vector(u_names, u, {0: True, 2: True})

assert u_rw != w_rw # num of children doesn't match
assert u_rw != v_rw # data doesn't match
assert u_rw == y_rw # data match
assert x1_rw == x2_rw # with null
assert x1_rw != u_rw # with and without null

0 comments on commit 58979b1

Please sign in to comment.