diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..3964a91 --- /dev/null +++ b/convert.go @@ -0,0 +1,8 @@ +package sqlx + +import ( + _ "unsafe" +) + +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 8ec6a13..beaaa43 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -207,8 +207,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { - alloc := reflect.New(Deref(v.Type())) - v.Set(alloc) + v.Set(reflect.New(v.Type().Elem())) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) diff --git a/sqlx.go b/sqlx.go index 8259a4f..dda1ce6 100644 --- a/sqlx.go +++ b/sqlx.go @@ -624,7 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + err := fieldsByTraversal(v, r.fields, r.values) if err != nil { return err } @@ -784,7 +784,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -957,7 +957,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -1023,7 +1023,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1032,23 +1032,38 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{} for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) - continue - } - f := reflectx.FieldByIndexes(v, traversal) - if ptrs { - values[i] = f.Addr().Interface() + } else if len(traversal) == 1 { + values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() } else { - values[i] = f.Interface() + // reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs. + // Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct. + // That way we can support LEFT JOINs with optional nested structs. + traversal := traversal + values[i] = optDest(func() interface{} { + return reflectx.FieldByIndexes(v, traversal).Addr().Interface() + }) } } return nil } -func missingFields(transversals [][]int) (field int, err error) { - for i, t := range transversals { +func missingFields(traversals [][]int) (field int, err error) { + for i, t := range traversals { if len(t) == 0 { return i, errors.New("missing field") } } return 0, nil } + +// optDest will only forward the Scan to the nested value if +// the database value is not nil. +type optDest func() interface{} + +// Scan implements sql.Scanner. +func (dest optDest) Scan(src interface{}) error { + if src == nil { + return nil + } + return convertAssign(dest(), src) +} diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 91c5cba..c5e81bc 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -437,12 +437,17 @@ func TestNamedQueryContext(t *testing.T) { "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL + ); + CREATE TABLE persondetails ( + email text NULL, + notes text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; + drop table persondetails; `, } @@ -643,6 +648,229 @@ func TestNamedQueryContext(t *testing.T) { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } + + rows.Close() + + type Owner struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + q4 := `INSERT INTO place (id, name) VALUES (2, :name)` + _, err = db.NamedExecContext(ctx, q4, pl) + if err != nil { + log.Fatal(err) + } + + id = 2 + pp.Place.ID = id + + q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q5, pp) + if err != nil { + log.Fatal(err) + } + + pp3 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email" + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner != nil { + t.Error("Expected `Owner` to be nil") + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } + + rows.Close() + + pp4 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp4) + if err != nil { + t.Error(err) + } + if pp4.Owner == nil { + t.Error("Expected `Owner` to not be nil") + } + if pp4.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp4.Owner.FirstName) + } + if pp4.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp4.Owner.LastName) + } + if pp4.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp4.Place.Name.String) + } + if pp4.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp4.Place.ID) + } + } + + type Details struct { + Email string `db:"email"` + Notes string `db:"notes"` + } + + type OwnerDetails struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Details *Details `db:"details"` + } + + type PlaceOwnerDetails struct { + Place Place `db:"place"` + Owner *OwnerDetails `db:"owner"` + } + + pp5 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON false + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp5) + if err != nil { + t.Error(err) + } + if pp5.Owner == nil { + t.Error("Expected `Owner`, to not be nil") + } + if pp5.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp5.Owner.FirstName) + } + if pp5.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp5.Owner.LastName) + } + if pp5.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp5.Place.Name.String) + } + if pp5.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp5.Place.ID) + } + if pp5.Owner.Details != nil { + t.Error("Expected `Details` to be nil") + } + } + + details := Details{ + Email: pp.Email.String, + Notes: "this is a test person", + } + + q6 := `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)` + _, err = db.NamedExecContext(ctx, q6, details) + if err != nil { + log.Fatal(err) + } + + pp6 := &PlaceOwnerDetails{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id AS "place.id", + place.name AS "place.name", + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + persondetails.email "owner.details.email", + persondetails.notes "owner.details.notes" + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON persondetails.email = placeperson.email + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp6) + if err != nil { + t.Error(err) + } + if pp6.Owner == nil { + t.Error("Expected `Owner` to not be nil") + } + if pp6.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp6.Owner.FirstName) + } + if pp6.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp6.Owner.LastName) + } + if pp6.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp6.Place.Name.String) + } + if pp6.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp6.Place.ID) + } + if pp6.Owner.Details == nil { + t.Error("Expected `Details` to not be nil") + } + if pp6.Owner.Details.Email != details.Email { + t.Errorf("Expected details email of %v, got %v", details.Email, pp6.Owner.Details.Email) + } + if pp6.Owner.Details.Notes != details.Notes { + t.Errorf("Expected details notes of %v, got %v", details.Notes, pp6.Owner.Details.Notes) + } + } }) }