diff --git a/bindings/go/limbo_test.go b/bindings/go/limbo_test.go index 6aaf3f95c..1a787a149 100644 --- a/bindings/go/limbo_test.go +++ b/bindings/go/limbo_test.go @@ -2,6 +2,7 @@ package limbo_test import ( "database/sql" + "fmt" "testing" _ "limbo" @@ -76,7 +77,7 @@ func TestQuery(t *testing.T) { } defer rows.Close() - expectedCols := []string{"foo", "bar"} + expectedCols := []string{"foo", "bar", "baz"} cols, err := rows.Columns() if err != nil { t.Fatalf("Error getting columns: %v", err) @@ -93,13 +94,15 @@ func TestQuery(t *testing.T) { for rows.Next() { var a int var b string - err = rows.Scan(&a, &b) + var c []byte + err = rows.Scan(&a, &b, &c) if err != nil { t.Fatalf("Error scanning row: %v", err) } - if a != i || b != rowsMap[i] { - t.Fatalf("Expected %d, %s, got %d, %s", i, rowsMap[i], a, b) + if a != i || b != rowsMap[i] || string(c) != rowsMap[i] { + t.Fatalf("Expected %d, %s, %s, got %d, %s, %b", i, rowsMap[i], rowsMap[i], a, b, c) } + fmt.Println("RESULTS: ", a, b, string(c)) i++ } @@ -111,7 +114,7 @@ func TestQuery(t *testing.T) { var rowsMap = map[int]string{1: "hello", 2: "world", 3: "foo", 4: "bar", 5: "baz"} func createTable(conn *sql.DB) error { - insert := "CREATE TABLE test (foo INT, bar TEXT);" + insert := "CREATE TABLE test (foo INT, bar TEXT, baz BLOB);" stmt, err := conn.Prepare(insert) if err != nil { return err @@ -123,13 +126,13 @@ func createTable(conn *sql.DB) error { func insertData(conn *sql.DB) error { for i := 1; i <= 5; i++ { - insert := "INSERT INTO test (foo, bar) VALUES (?, ?);" + insert := "INSERT INTO test (foo, bar, baz) VALUES (?, ?, ?);" stmt, err := conn.Prepare(insert) if err != nil { return err } defer stmt.Close() - if _, err = stmt.Exec(i, rowsMap[i]); err != nil { + if _, err = stmt.Exec(i, rowsMap[i], []byte(rowsMap[i])); err != nil { return err } } diff --git a/bindings/go/limbo_unix.go b/bindings/go/limbo_unix.go index 69464fc2d..ecafa5a85 100644 --- a/bindings/go/limbo_unix.go +++ b/bindings/go/limbo_unix.go @@ -35,7 +35,7 @@ func loadLibrary() error { for _, path := range paths { libPath := filepath.Join(path, libraryName) if _, err := os.Stat(libPath); err == nil { - slib, dlerr := purego.Dlopen(libPath, purego.RTLD_LAZY) + slib, dlerr := purego.Dlopen(libPath, purego.RTLD_NOW|purego.RTLD_GLOBAL) if dlerr != nil { return fmt.Errorf("failed to load library at %s: %w", libPath, dlerr) } diff --git a/bindings/go/rs_src/rows.rs b/bindings/go/rs_src/rows.rs index 189ff84f8..19b526c74 100644 --- a/bindings/go/rs_src/rows.rs +++ b/bindings/go/rs_src/rows.rs @@ -65,8 +65,7 @@ pub extern "C" fn rows_get_value(ctx: *mut c_void, col_idx: usize) -> *const c_v if let Some(ref cursor) = ctx.cursor { if let Some(value) = cursor.values.get(col_idx) { - let val = LimboValue::from_value(value); - return val.to_ptr(); + return LimboValue::from_value(value).to_ptr(); } } std::ptr::null() diff --git a/bindings/go/rs_src/types.rs b/bindings/go/rs_src/types.rs index 334aa6e77..11c9b251f 100644 --- a/bindings/go/rs_src/types.rs +++ b/bindings/go/rs_src/types.rs @@ -30,28 +30,22 @@ pub enum ValueType { #[repr(C)] pub struct LimboValue { - pub value_type: ValueType, - pub value: ValueUnion, + value_type: ValueType, + value: ValueUnion, } #[repr(C)] -pub union ValueUnion { - pub int_val: i64, - pub real_val: f64, - pub text_ptr: *const c_char, - pub blob_ptr: *const c_void, +union ValueUnion { + int_val: i64, + real_val: f64, + text_ptr: *const c_char, + blob_ptr: *const c_void, } #[repr(C)] -pub struct Blob { - pub data: *const u8, - pub len: usize, -} - -impl Blob { - pub fn to_ptr(&self) -> *const c_void { - self as *const Blob as *const c_void - } +struct Blob { + data: *const u8, + len: i64, } pub struct AllocPool { @@ -97,12 +91,12 @@ impl ValueUnion { } fn from_bytes(b: &[u8]) -> Self { + let blob = Box::new(Blob { + data: b.as_ptr(), + len: b.len() as i64, + }); ValueUnion { - blob_ptr: Blob { - data: b.as_ptr(), - len: b.len(), - } - .to_ptr(), + blob_ptr: Box::into_raw(blob) as *const c_void, } } @@ -140,12 +134,12 @@ impl ValueUnion { pub fn to_bytes(&self) -> &[u8] { let blob = unsafe { self.blob_ptr as *const Blob }; let blob = unsafe { &*blob }; - unsafe { std::slice::from_raw_parts(blob.data, blob.len) } + unsafe { std::slice::from_raw_parts(blob.data, blob.len as usize) } } } impl LimboValue { - pub fn new(value_type: ValueType, value: ValueUnion) -> Self { + fn new(value_type: ValueType, value: ValueUnion) -> Self { LimboValue { value_type, value } } @@ -204,15 +198,9 @@ impl LimboValue { if unsafe { self.value.blob_ptr.is_null() } { return limbo_core::Value::Null; } - let blob_ptr = unsafe { self.value.blob_ptr as *const Blob }; - if blob_ptr.is_null() { - limbo_core::Value::Null - } else { - let blob = unsafe { &*blob_ptr }; - let data = unsafe { std::slice::from_raw_parts(blob.data, blob.len) }; - let borrowed = pool.add_blob(data.to_vec()); - limbo_core::Value::Blob(borrowed) - } + let bytes = self.value.to_bytes(); + let borrowed = pool.add_blob(bytes.to_vec()); + limbo_core::Value::Blob(borrowed) } ValueType::Null => limbo_core::Value::Null, } diff --git a/bindings/go/stmt.go b/bindings/go/stmt.go index 5f3632810..02ad3c3eb 100644 --- a/bindings/go/stmt.go +++ b/bindings/go/stmt.go @@ -59,7 +59,8 @@ func (ls *limboStmt) Close() error { } func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { - argArray, err := buildArgs(args) + argArray, cleanup, err := buildArgs(args) + defer cleanup() if err != nil { return nil, err } @@ -87,7 +88,8 @@ func (ls *limboStmt) Exec(args []driver.Value) (driver.Result, error) { } func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { - queryArgs, err := buildArgs(args) + queryArgs, cleanup, err := buildArgs(args) + defer cleanup() if err != nil { return nil, err } @@ -105,7 +107,8 @@ func (st *limboStmt) Query(args []driver.Value) (driver.Rows, error) { func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { stripped := namedValueToValue(args) - argArray, err := getArgsPtr(stripped) + argArray, cleanup, err := getArgsPtr(stripped) + defer cleanup() if err != nil { return nil, err } @@ -132,7 +135,8 @@ func (ls *limboStmt) ExecContext(ctx context.Context, query string, args []drive } func (ls *limboStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - queryArgs, err := buildNamedArgs(args) + queryArgs, allocs, err := buildNamedArgs(args) + defer allocs() if err != nil { return nil, err } diff --git a/bindings/go/types.go b/bindings/go/types.go index b391d6f0a..78fb96153 100644 --- a/bindings/go/types.go +++ b/bindings/go/types.go @@ -3,6 +3,7 @@ package limbo import ( "database/sql/driver" "fmt" + "runtime" "unsafe" ) @@ -77,6 +78,7 @@ const ( FfiRowsGetValue string = "rows_get_value" FfiFreeColumns string = "free_columns" FfiFreeCString string = "free_string" + FfiFreeBlob string = "free_blob" ) // convert a namedValue slice into normal values until named parameters are supported @@ -88,7 +90,7 @@ func namedValueToValue(named []driver.NamedValue) []driver.Value { return out } -func buildNamedArgs(named []driver.NamedValue) ([]limboValue, error) { +func buildNamedArgs(named []driver.NamedValue) ([]limboValue, func(), error) { args := namedValueToValue(named) return buildArgs(args) } @@ -123,14 +125,14 @@ func (vt valueType) String() string { // struct to pass Go values over FFI type limboValue struct { Type valueType - _ [4]byte // padding to align Value to 8 bytes + _ [4]byte Value [8]byte } // struct to pass byte slices over FFI type Blob struct { Data uintptr - Len uint + Len int64 } // convert a limboValue to a native Go value @@ -146,9 +148,11 @@ func toGoValue(valPtr uintptr) interface{} { return *(*float64)(unsafe.Pointer(&val.Value)) case textVal: textPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + defer freeCString(textPtr) return GoString(textPtr) case blobVal: blobPtr := *(*uintptr)(unsafe.Pointer(&val.Value)) + defer freeBlob(blobPtr) return toGoBlob(blobPtr) case nullVal: return nil @@ -157,15 +161,15 @@ func toGoValue(valPtr uintptr) interface{} { } } -func getArgsPtr(args []driver.Value) (uintptr, error) { +func getArgsPtr(args []driver.Value) (uintptr, func(), error) { if len(args) == 0 { - return 0, nil + return 0, nil, nil } - argSlice, err := buildArgs(args) + argSlice, allocs, err := buildArgs(args) if err != nil { - return 0, err + return 0, allocs, err } - return uintptr(unsafe.Pointer(&argSlice[0])), nil + return uintptr(unsafe.Pointer(&argSlice[0])), allocs, nil } // convert a byte slice to a Blob type that can be sent over FFI @@ -173,11 +177,10 @@ func makeBlob(b []byte) *Blob { if len(b) == 0 { return nil } - blob := &Blob{ + return &Blob{ Data: uintptr(unsafe.Pointer(&b[0])), - Len: uint(len(b)), + Len: int64(len(b)), } - return blob } // converts a blob received via FFI to a native Go byte slice @@ -186,7 +189,37 @@ func toGoBlob(blobPtr uintptr) []byte { return nil } blob := (*Blob)(unsafe.Pointer(blobPtr)) - return unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) + if blob.Data == 0 || blob.Len == 0 { + return nil + } + data := unsafe.Slice((*byte)(unsafe.Pointer(blob.Data)), blob.Len) + copied := make([]byte, len(data)) + copy(copied, data) + return copied +} + +var freeBlobFunc func(uintptr) + +func freeBlob(blobPtr uintptr) { + if blobPtr == 0 { + return + } + if freeBlobFunc == nil { + getFfiFunc(&freeBlobFunc, FfiFreeBlob) + } + freeBlobFunc(blobPtr) +} + +var freeStringFunc func(uintptr) + +func freeCString(cstrPtr uintptr) { + if cstrPtr == 0 { + return + } + if freeStringFunc == nil { + getFfiFunc(&freeStringFunc, FfiFreeCString) + } + freeStringFunc(cstrPtr) } func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { @@ -207,7 +240,10 @@ func cArrayToGoStrings(arrayPtr uintptr, length uint) []string { } // convert a Go slice of driver.Value to a slice of limboValue that can be sent over FFI -func buildArgs(args []driver.Value) ([]limboValue, error) { +// for Blob types, we have to pin them so they are not garbage collected before they can be copied +// into a buffer on the Rust side, so we return a function to unpin them that can be deferred after this call +func buildArgs(args []driver.Value) ([]limboValue, func(), error) { + pinner := new(runtime.Pinner) argSlice := make([]limboValue, len(args)) for i, v := range args { limboVal := limboValue{} @@ -225,27 +261,16 @@ func buildArgs(args []driver.Value) ([]limboValue, error) { cstr := CString(val) *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(cstr)) case []byte: - argSlice[i].Type = blobVal + limboVal.Type = blobVal blob := makeBlob(val) + pinner.Pin(blob) *(*uintptr)(unsafe.Pointer(&limboVal.Value)) = uintptr(unsafe.Pointer(blob)) default: - return nil, fmt.Errorf("unsupported type: %T", v) + return nil, pinner.Unpin, fmt.Errorf("unsupported type: %T", v) } argSlice[i] = limboVal } - return argSlice, nil -} - -func storeInt64(data *[8]byte, val int64) { - *(*int64)(unsafe.Pointer(data)) = val -} - -func storeFloat64(data *[8]byte, val float64) { - *(*float64)(unsafe.Pointer(data)) = val -} - -func storePointer(data *[8]byte, ptr *byte) { - *(*uintptr)(unsafe.Pointer(data)) = uintptr(unsafe.Pointer(ptr)) + return argSlice, pinner.Unpin, nil } /* Credit below (Apache2 License) to: