Skip to content

Commit

Permalink
Merge 'bindings/go Support blob types in query arguments, free non-gc…
Browse files Browse the repository at this point in the history
… allocations' from Preston Thorpe

This PR fixes/adds support for the Blob type and adds the appropriate
tests.
Types created on the Go side will be cleaned up rather quickly if
nothing is referencing them, so this approach uses `runtime.Pinner` to
pin the bytes in memory so the pointers will be valid when Rust uses
`from_raw_parts` and then owns a new vec. They are then cleaned up after
the FFI call with `pinner.Unpin`.

Closes #822
  • Loading branch information
penberg committed Jan 30, 2025
2 parents c7c3461 + d03ed35 commit d25ccf0
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 74 deletions.
17 changes: 10 additions & 7 deletions bindings/go/limbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package limbo_test

import (
"database/sql"
"fmt"
"testing"

_ "limbo"
Expand Down Expand Up @@ -76,7 +77,7 @@ func TestQuery(t *testing.T) {
}
defer rows.Close()

expectedCols := []string{"foo", "bar"}
expectedCols := []string{"foo", "bar", "baz"}
cols, err := rows.Columns()
if err != nil {
t.Fatalf("Error getting columns: %v", err)
Expand All @@ -93,13 +94,15 @@ func TestQuery(t *testing.T) {
for rows.Next() {
var a int
var b string
err = rows.Scan(&a, &b)
var c []byte
err = rows.Scan(&a, &b, &c)
if err != nil {
t.Fatalf("Error scanning row: %v", err)
}
if a != i || b != rowsMap[i] {
t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b)
if a != i || b != rowsMap[i] || string(c) != rowsMap[i] {
t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c)
}
fmt.Println("RESULTS: ", a, b, string(c))
i++
}

Expand All @@ -111,7 +114,7 @@ func TestQuery(t *testing.T) {
var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"}

func createTable(conn *sql.DB) error {
insert := "CREATE TABLE test (foo INT, bar TEXT);"
insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
Expand All @@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error {

func insertData(conn *sql.DB) error {
for i := 1; i <= 5; i++ {
insert := "INSERT INTO test (foo, bar) VALUES (?, ?);"
insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);"
stmt, err := conn.Prepare(insert)
if err != nil {
return err
}
defer stmt.Close()
if _, err = stmt.Exec(i, rowsMap[i]); err != nil {
if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion bindings/go/limbo_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func loadLibrary() error {
for _, path := range paths {
libPath := filepath.Join(path, libraryName)
if _, err := os.Stat(libPath); err == nil {
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY)
slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if dlerr != nil {
return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr)
}
Expand Down
3 changes: 1 addition & 2 deletions bindings/go/rs_src/rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v

if let Some(ref cursor) = ctx.cursor {
if let Some(value) = cursor.values.get(col_idx) {
let val = LimboValue::from_value(value);
return val.to_ptr();
return LimboValue::from_value(value).to_ptr();
}
}
std::ptr::null()
Expand Down
52 changes: 20 additions & 32 deletions bindings/go/rs_src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,22 @@ pub enum ValueType {

#[repr(C)]
pub struct LimboValue {
pub value_type: ValueType,
pub value: ValueUnion,
value_type: ValueType,
value: ValueUnion,
}

#[repr(C)]
pub union ValueUnion {
pub int_val: i64,
pub real_val: f64,
pub text_ptr: *const c_char,
pub blob_ptr: *const c_void,
union ValueUnion {
int_val: i64,
real_val: f64,
text_ptr: *const c_char,
blob_ptr: *const c_void,
}

#[repr(C)]
pub struct Blob {
pub data: *const u8,
pub len: usize,
}

impl Blob {
pub fn to_ptr(&self) -> *const c_void {
self as *const Blob as *const c_void
}
struct Blob {
data: *const u8,
len: i64,
}

pub struct AllocPool {
Expand Down Expand Up @@ -97,12 +91,12 @@ impl ValueUnion {
}

fn from_bytes(b: &[u8]) -> Self {
let blob = Box::new(Blob {
data: b.as_ptr(),
len: b.len() as i64,
});
ValueUnion {
blob_ptr: Blob {
data: b.as_ptr(),
len: b.len(),
}
.to_ptr(),
blob_ptr: Box::into_raw(blob) as *const c_void,
}
}

Expand Down Expand Up @@ -140,12 +134,12 @@ impl ValueUnion {
pub fn to_bytes(&self) -> &[u8] {
let blob = unsafe { self.blob_ptr as *const Blob };
let blob = unsafe { &*blob };
unsafe { std::slice::from_raw_parts(blob.data, blob.len) }
unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) }
}
}

impl LimboValue {
pub fn new(value_type: ValueType, value: ValueUnion) -> Self {
fn new(value_type: ValueType, value: ValueUnion) -> Self {
LimboValue { value_type, value }
}

Expand Down Expand Up @@ -204,15 +198,9 @@ impl LimboValue {
if unsafe { self.value.blob_ptr.is_null() } {
return limbo_core::Value::Null;
}
let blob_ptr = unsafe { self.value.blob_ptr as *const Blob };
if blob_ptr.is_null() {
limbo_core::Value::Null
} else {
let blob = unsafe { &*blob_ptr };
let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) };
let borrowed = pool.add_blob(data.to_vec());
limbo_core::Value::Blob(borrowed)
}
let bytes = self.value.to_bytes();
let borrowed = pool.add_blob(bytes.to_vec());
limbo_core::Value::Blob(borrowed)
}
ValueType::Null => limbo_core::Value::Null,
}
Expand Down
12 changes: 8 additions & 4 deletions bindings/go/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error {
}

func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
argArray, err := buildArgs(args)
argArray, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) {
}

func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {
queryArgs, err := buildArgs(args)
queryArgs, cleanup, err := buildArgs(args)
defer cleanup()
if err != nil {
return nil, err
}
Expand All @@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) {

func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
stripped := namedValueToValue(args)
argArray, err := getArgsPtr(stripped)
argArray, cleanup, err := getArgsPtr(stripped)
defer cleanup()
if err != nil {
return nil, err
}
Expand All @@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive
}

func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
queryArgs, err := buildNamedArgs(args)
queryArgs, allocs, err := buildNamedArgs(args)
defer allocs()
if err != nil {
return nil, err
}
Expand Down
81 changes: 53 additions & 28 deletions bindings/go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package limbo
import (
"database/sql/driver"
"fmt"
"runtime"
"unsafe"
)

Expand Down Expand Up @@ -77,6 +78,7 @@ const (
FfiRowsGetValue string = "rows_get_value"
FfiFreeColumns string = "free_columns"
FfiFreeCString string = "free_string"
FfiFreeBlob string = "free_blob"
)

// convert a namedValue slice into normal values until named parameters are supported
Expand All @@ -88,7 +90,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value {
return out
}

func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) {
func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) {
args := namedValueToValue(named)
return buildArgs(args)
}
Expand Down Expand Up @@ -123,14 +125,14 @@ func (vt valueType) String() string {
// struct to pass Go values over FFI
type limboValue struct {
Type valueType
_ [4]byte // padding to align Value to 8 bytes
_ [4]byte
Value [8]byte
}

// struct to pass byte slices over FFI
type Blob struct {
Data uintptr
Len uint
Len int64
}

// convert a limboValue to a native Go value
Expand All @@ -146,9 +148,11 @@ func toGoValue(valPtr uintptr) interface{} {
return *(*float64)(unsafe.Pointer(&val.Value))
case textVal:
textPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
defer freeCString(textPtr)
return GoString(textPtr)
case blobVal:
blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value))
defer freeBlob(blobPtr)
return toGoBlob(blobPtr)
case nullVal:
return nil
Expand All @@ -157,27 +161,26 @@ func toGoValue(valPtr uintptr) interface{} {
}
}

func getArgsPtr(args []driver.Value) (uintptr, error) {
func getArgsPtr(args []driver.Value) (uintptr, func(), error) {
if len(args) == 0 {
return 0, nil
return 0, nil, nil
}
argSlice, err := buildArgs(args)
argSlice, allocs, err := buildArgs(args)
if err != nil {
return 0, err
return 0, allocs, err
}
return uintptr(unsafe.Pointer(&argSlice[0])), nil
return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil
}

// convert a byte slice to a Blob type that can be sent over FFI
func makeBlob(b []byte) *Blob {
if len(b) == 0 {
return nil
}
blob := &Blob{
return &Blob{
Data: uintptr(unsafe.Pointer(&b[0])),
Len: uint(len(b)),
Len: int64(len(b)),
}
return blob
}

// converts a blob received via FFI to a native Go byte slice
Expand All @@ -186,7 +189,37 @@ func toGoBlob(blobPtr uintptr) []byte {
return nil
}
blob := (*Blob)(unsafe.Pointer(blobPtr))
return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
if blob.Data == 0 || blob.Len == 0 {
return nil
}
data := unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len)
copied := make([]byte, len(data))
copy(copied, data)
return copied
}

var freeBlobFunc func(uintptr)

func freeBlob(blobPtr uintptr) {
if blobPtr == 0 {
return
}
if freeBlobFunc == nil {
getFfiFunc(&freeBlobFunc, FfiFreeBlob)
}
freeBlobFunc(blobPtr)
}

var freeStringFunc func(uintptr)

func freeCString(cstrPtr uintptr) {
if cstrPtr == 0 {
return
}
if freeStringFunc == nil {
getFfiFunc(&freeStringFunc, FfiFreeCString)
}
freeStringFunc(cstrPtr)
}

func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
Expand All @@ -207,7 +240,10 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string {
}

// convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI
func buildArgs(args []driver.Value) ([]limboValue, error) {
// for Blob types, we have to pin them so they are not garbage collected before they can be copied
// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call
func buildArgs(args []driver.Value) ([]limboValue, func(), error) {
pinner := new(runtime.Pinner)
argSlice := make([]limboValue, len(args))
for i, v := range args {
limboVal := limboValue{}
Expand All @@ -225,27 +261,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) {
cstr := CString(val)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr))
case []byte:
argSlice[i].Type = blobVal
limboVal.Type = blobVal
blob := makeBlob(val)
pinner.Pin(blob)
*(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob))
default:
return nil, fmt.Errorf("unsupported type: %T", v)
return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v)
}
argSlice[i] = limboVal
}
return argSlice, nil
}

func storeInt64(data *[8]byte, val int64) {
*(*int64)(unsafe.Pointer(data)) = val
}

func storeFloat64(data *[8]byte, val float64) {
*(*float64)(unsafe.Pointer(data)) = val
}

func storePointer(data *[8]byte, ptr *byte) {
*(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr))
return argSlice, pinner.Unpin, nil
}

/* Credit below (Apache2 License) to:
Expand Down

0 comments on commit d25ccf0

Please sign in to comment.