From 7e1bfb7d4e2d01370625cb258b391ada18c224ca Mon Sep 17 00:00:00 2001 From: Ben Hubbard Date: Mon, 5 Aug 2019 15:29:10 -0400 Subject: [PATCH 1/4] Fixes PLIN-2287 Improve Picard Filtering --- delete.go | 15 ++++++--- filter.go | 15 +++++---- filter_test.go | 72 +++++++++++++++++++++++++++++++++++++++++- query/build.go | 20 ++++++++---- query/build_test.go | 8 +++-- query/hydrate.go | 11 +++---- queryparts/table.go | 49 ++++++++-------------------- queryparts/where.go | 13 +++----- reflectutil/reflect.go | 39 ----------------------- tags/tags.go | 12 ++++--- 10 files changed, 139 insertions(+), 115 deletions(-) diff --git a/delete.go b/delete.go index 556579b..503c826 100644 --- a/delete.go +++ b/delete.go @@ -5,7 +5,6 @@ import ( "reflect" sq "github.com/Masterminds/squirrel" - "github.com/skuid/picard/queryparts" "github.com/skuid/picard/reflectutil" "github.com/skuid/picard/query" @@ -22,7 +21,13 @@ func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { return 0, err } - tbl, err := query.Build(porm.multitenancyValue, model, queryparts.SelectFilter{}, nil) + metadata, err := tags.GetTableMetadata(model) + if err != nil { + return 0, err + } + + tbl, err := query.Build(porm.multitenancyValue, model, nil, nil, metadata) + if err != nil { return 0, err } @@ -31,7 +36,7 @@ func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { lookupPks := make([]interface{}, 0) if len(associations) > 0 { - _, pk := reflectutil.ReflectTableInfo(reflect.TypeOf(model)) + pk := metadata.GetPrimaryKeyColumnName() results, err := porm.FilterModel(FilterRequest{ FilterModel: model, Associations: associations, @@ -41,8 +46,8 @@ func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { } for _, result := range results { - val, ok := reflectutil.GetPK(reflect.ValueOf(result)) - if ok { + val := getValueFromLookupString(reflect.ValueOf(result), metadata.GetPrimaryKeyFieldName()) + if val.IsValid() { lookupPks = append(lookupPks, val.Interface()) } } diff --git a/filter.go b/filter.go index f94db19..90180aa 100644 --- a/filter.go +++ b/filter.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/skuid/picard/query" qp "github.com/skuid/picard/queryparts" - "github.com/skuid/picard/reflectutil" "github.com/skuid/picard/stringutil" "github.com/skuid/picard/tags" ) @@ -16,10 +15,11 @@ import ( // FilterRequest holds information about a request to filter on a model type FilterRequest struct { FilterModel interface{} + FieldFilters []qp.FieldFilter Associations []tags.Association OrderBy []qp.OrderByRequest Runner sq.BaseRunner - Fields qp.SelectFilter + //Fields []string // For use later whe we implement selecting specific columns } func addOrderBy(builder sq.SelectBuilder, orderBy []qp.OrderByRequest, filterMetadata *tags.TableMetadata, tableAlias string) sq.SelectBuilder { @@ -39,7 +39,7 @@ func addOrderBy(builder sq.SelectBuilder, orderBy []qp.OrderByRequest, filterMet func (p PersistenceORM) getSingleFilterResults(request FilterRequest, filterMetadata *tags.TableMetadata) ([]*reflect.Value, error) { filterModel := request.FilterModel - tbl, err := query.Build(p.multitenancyValue, filterModel, request.Fields, request.Associations) + tbl, err := query.Build(p.multitenancyValue, filterModel, request.FieldFilters, request.Associations, filterMetadata) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (p PersistenceORM) getMultiFilterResults(request FilterRequest, filterMetad for i := 0; i < modelVal.Len(); i++ { val := modelVal.Index(i) - ftbl, err := query.Build(mtVal, val.Interface(), request.Fields, request.Associations) + ftbl, err := query.Build(mtVal, val.Interface(), request.FieldFilters, request.Associations, filterMetadata) if err != nil { return nil, err } @@ -163,13 +163,14 @@ func (p PersistenceORM) FilterModel(request FilterRequest) ([]interface{}, error if foreignKey != nil { for _, result := range results { newFilter := reflect.Indirect(reflect.New(childType)) - pkval, ok := reflectutil.GetPK(*result) - if !ok { + pkval := getValueFromLookupString(*result, childMetadata.GetPrimaryKeyFieldName()) + + if !pkval.IsValid() { return nil, fmt.Errorf("Missing 'primary_key' tag on type '%v'", result.Type().Name()) } if fmf := newFilter.FieldByName(foreignKey.FieldName); fmf.CanSet() { - fmf.Set(*pkval) + fmf.Set(pkval) } else { return nil, fmt.Errorf("'foreign_key' field '%s' on 'child' type '%v' is not settable", foreignKey.FieldName, newFilter.Type()) } diff --git a/filter_test.go b/filter_test.go index 71c7eed..8b8ebbd 100644 --- a/filter_test.go +++ b/filter_test.go @@ -1301,7 +1301,7 @@ func TestFilterModel(t *testing.T) { Name: "Children", OrderBy: []qp.OrderByRequest{ { - Field: "Name", + Field: "Name", Descending: true, }, }, @@ -1451,6 +1451,76 @@ func TestFilterModel(t *testing.T) { mock.ExpectCommit() }, }, + { + "filter request with additional field filters item", + FilterRequest{ + FilterModel: testdata.ToyModel{}, + FieldFilters: []qp.FieldFilter{ + { + FieldName: "Name", + FilterValue: "Lego", + }, + }, + }, + []interface{}{}, + func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.organization_id AS "t0.organization_id", + t0.name AS "t0.name", + t0.parent_id AS "t0.parent_id" + FROM toymodel AS t0 + WHERE t0.organization_id = $1 AND t0.name = $2 + `)). + WithArgs(orgID, "Lego"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.organization_id", + "t0.name", + "t0.parent_id", + }), + ) + mock.ExpectCommit() + }, + }, + { + "filter request with additional field filters array", + FilterRequest{ + FilterModel: testdata.ToyModel{}, + FieldFilters: []qp.FieldFilter{ + { + FieldName: "Name", + FilterValue: []string{"Lego", "Matchbox Car", "Nintendo"}, + }, + }, + }, + []interface{}{}, + func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.organization_id AS "t0.organization_id", + t0.name AS "t0.name", + t0.parent_id AS "t0.parent_id" + FROM toymodel AS t0 + WHERE t0.organization_id = $1 AND t0.name IN ($2,$3,$4) + `)). + WithArgs(orgID, "Lego", "Matchbox Car", "Nintendo"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.organization_id", + "t0.name", + "t0.parent_id", + }), + ) + mock.ExpectCommit() + }, + }, } for _, tc := range testCases { diff --git a/query/build.go b/query/build.go index 2d7a2ac..8b98171 100644 --- a/query/build.go +++ b/query/build.go @@ -14,7 +14,7 @@ import ( Build takes the filter model and returns a query object. It takes the multitenancy value, current reflected value, and any tags */ -func Build(multitenancyVal, model interface{}, fields qp.SelectFilter, associations []tags.Association) (*qp.Table, error) { +func Build(multitenancyVal, model interface{}, filters []qp.FieldFilter, associations []tags.Association, filterMetadata *tags.TableMetadata) (*qp.Table, error) { val, err := stringutil.GetStructValue(model) if err != nil { @@ -23,7 +23,7 @@ func Build(multitenancyVal, model interface{}, fields qp.SelectFilter, associati typ := val.Type() - tbl, err := buildQuery(multitenancyVal, typ, &val, fields, associations, false, 0) + tbl, err := buildQuery(multitenancyVal, typ, &val, filters, associations, false, 0, filterMetadata) if err != nil { return nil, err } @@ -61,14 +61,16 @@ func buildQuery( multitenancyVal interface{}, modelType reflect.Type, modelVal *reflect.Value, - selectFilter qp.SelectFilter, + filters []qp.FieldFilter, associations []tags.Association, onlyJoin bool, counter int, + filterMetadata *tags.TableMetadata, ) (*qp.Table, error) { // Inspect current reflected value, and add select/where clauses - tableName, pkName := reflectutil.ReflectTableInfo(modelType) + pkName := filterMetadata.GetPrimaryKeyColumnName() + tableName := filterMetadata.GetTableName() tbl := NewIndexed(tableName, counter) @@ -128,8 +130,9 @@ func buildQuery( if ok || childOnlyJoin { // Get type, load it as a model so we can build it out refTyp := relatedVal.Type() + refMetadata := tags.TableMetadataFromType(refTyp) - refTbl, err := buildQuery(multitenancyVal, refTyp, &relatedVal, qp.SelectFilter{}, association.Associations, childOnlyJoin, counter+1) + refTbl, err := buildQuery(multitenancyVal, refTyp, &relatedVal, nil, association.Associations, childOnlyJoin, counter+1, refMetadata) if err != nil { return nil, err } @@ -163,8 +166,11 @@ func buildQuery( tbl.AddColumns(cols) - if tableName == selectFilter.TableName && len(selectFilter.Values) > 0 { - tbl.AddWhereIn(selectFilter.FieldName, selectFilter.Values) + if filters != nil && len(filters) > 0 && modelVal != nil { + for _, filter := range filters { + fieldMetadata := filterMetadata.GetField(filter.FieldName) + tbl.AddWhere(fieldMetadata.GetColumnName(), filter.FilterValue) + } } return tbl, nil diff --git a/query/build_test.go b/query/build_test.go index 802e8d7..c7b56a9 100644 --- a/query/build_test.go +++ b/query/build_test.go @@ -1,7 +1,6 @@ package query import ( - "github.com/skuid/picard/queryparts" "testing" "github.com/skuid/picard/metadata" @@ -165,7 +164,12 @@ func TestQueryBuilder(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { assert := assert.New(t) - tbl, err := Build(orgID, tc.model, queryparts.SelectFilter{}, tc.assoc) + metadata, err := tags.GetTableMetadata(tc.model) + if err != nil { + t.Fatal(err) + } + + tbl, err := Build(orgID, tc.model, nil, tc.assoc, metadata) assert.NoError(err) actual, actualArgs, err := tbl.ToSQL() diff --git a/query/hydrate.go b/query/hydrate.go index 959ef14..9bc0aca 100644 --- a/query/hydrate.go +++ b/query/hydrate.go @@ -9,8 +9,8 @@ import ( "reflect" "github.com/skuid/picard/crypto" - "github.com/skuid/picard/stringutil" qp "github.com/skuid/picard/queryparts" + "github.com/skuid/picard/stringutil" "github.com/skuid/picard/tags" ) @@ -26,9 +26,8 @@ func Hydrate(filterModel interface{}, aliasMap map[string]qp.FieldDescriptor, ro // Get the models type and picard tags typ := modelVal.Type() - meta := tags.TableMetadataFromType(typ) - mappedCols, err := mapRows2Cols(meta, aliasMap, rows) + mappedCols, err := mapRows2Cols(aliasMap, rows) if err != nil { return nil, err } @@ -61,7 +60,7 @@ func hydrate(typ reflect.Type, mapped map[string]map[string]interface{}, counter if err != nil { return nil, err } - + if field.IsFK() { refTyp := field.GetRelatedType() // Recursively hydrate this reference field @@ -119,7 +118,7 @@ func setFieldValue(model *reflect.Value, field tags.FieldMetadata, value interfa model.FieldByName(field.GetName()).Set(reflect.ValueOf(value)) } } - + return nil } @@ -150,7 +149,7 @@ This function would return something like: */ -func mapRows2Cols(meta *tags.TableMetadata, aliasMap map[string]qp.FieldDescriptor, rows *sql.Rows) ([]map[string]map[string]interface{}, error) { +func mapRows2Cols(aliasMap map[string]qp.FieldDescriptor, rows *sql.Rows) ([]map[string]map[string]interface{}, error) { results := make([]map[string]map[string]interface{}, 0) cols, err := rows.Columns() diff --git a/queryparts/table.go b/queryparts/table.go index eba9e6e..a96fd28 100644 --- a/queryparts/table.go +++ b/queryparts/table.go @@ -3,6 +3,7 @@ package queryparts import ( "fmt" "strings" + sql "github.com/Masterminds/squirrel" ) @@ -16,16 +17,15 @@ a query by calling tbl := New("my_table") */ type Table struct { - root *Table - Counter int - Alias string - Name string - columns []string - lookups map[string]interface{} - Joins []Join - Wheres []Where + root *Table + Counter int + Alias string + Name string + columns []string + lookups map[string]interface{} + Joins []Join + Wheres []Where MultiTenancy Where - WhereIns []WhereIn } /* @@ -65,17 +65,10 @@ func (t *Table) AddWhere(field string, val interface{}) { }) } -func (t *Table) AddWhereIn(field string, val [] interface{}) { - t.WhereIns = append(t.WhereIns, WhereIn{ - Field: field, - Val: val, - }) -} - func (t *Table) AddMultitenancyWhere(field string, val interface{}) { t.MultiTenancy = Where{ Field: field, - Val: val, + Val: val, } } @@ -153,7 +146,6 @@ func (t *Table) Columns() []string { return cols } - /* FieldAliases returns a map of all columns on a table and that table's joins. */ @@ -196,8 +188,7 @@ func (t *Table) BuildSQL() sql.SelectBuilder { if t.MultiTenancy != (Where{}) { bld = bld.Where( sql.Eq{ - fmt.Sprintf(AliasedField, t.Alias, t.MultiTenancy.Field): - t.MultiTenancy.Val, + fmt.Sprintf(AliasedField, t.Alias, t.MultiTenancy.Field): t.MultiTenancy.Val, }, ) } @@ -206,15 +197,6 @@ func (t *Table) BuildSQL() sql.SelectBuilder { bld = bld.Where(sql.Eq{fmt.Sprintf(AliasedField, t.Alias, where.Field): where.Val}) } - for _, whereIn := range t.WhereIns { - placeholders := make([]string, len(whereIn.Val)) - for i, _ := range whereIn.Val { - placeholders[i] = "?" - } - parens := "(" + strings.Join(placeholders, ",") + ")" - bld = bld.Where(whereIn.Field + " IN " + parens, whereIn.Val...) - } - for _, join := range t.Joins { bld = sqlizeJoin(bld, join) } @@ -233,8 +215,7 @@ func (t *Table) DeleteSQL() sql.DeleteBuilder { if t.MultiTenancy != (Where{}) { bld = bld.Where( sql.Eq{ - fmt.Sprintf(AliasedField, t.Alias, t.MultiTenancy.Field): - t.MultiTenancy.Val, + fmt.Sprintf(AliasedField, t.Alias, t.MultiTenancy.Field): t.MultiTenancy.Val, }, ) } @@ -256,8 +237,7 @@ func sqlizeJoin(bld sql.SelectBuilder, join Join) sql.SelectBuilder { jc = sql.And{ jc, sql.Eq{ - fmt.Sprintf(AliasedField, join.Table.Alias, where.Field): - where.Val, + fmt.Sprintf(AliasedField, join.Table.Alias, where.Field): where.Val, }, } } @@ -273,7 +253,6 @@ func sqlizeJoin(bld sql.SelectBuilder, join Join) sql.SelectBuilder { } bld = bld.JoinClause(jc) - for _, where := range join.Table.Wheres { bld = bld.Where(sql.Eq{fmt.Sprintf(AliasedField, join.Table.Alias, where.Field): where.Val}) } @@ -284,4 +263,4 @@ func sqlizeJoin(bld sql.SelectBuilder, join Join) sql.SelectBuilder { return bld -} \ No newline at end of file +} diff --git a/queryparts/where.go b/queryparts/where.go index 57ee799..99911f3 100644 --- a/queryparts/where.go +++ b/queryparts/where.go @@ -8,13 +8,8 @@ type Where struct { Val interface{} } -type WhereIn struct { - Field string - Val []interface{} +// FieldFilter defines an arbitrary filter on a FilterRequest +type FieldFilter struct { + FieldName string + FilterValue interface{} } - -type SelectFilter struct { - TableName string - FieldName string - Values []interface{} -} \ No newline at end of file diff --git a/reflectutil/reflect.go b/reflectutil/reflect.go index ce00d7d..a951a0a 100644 --- a/reflectutil/reflect.go +++ b/reflectutil/reflect.go @@ -2,8 +2,6 @@ package reflectutil import ( "reflect" - - "github.com/skuid/picard/tags" ) // IsZeroValue returns true if the value provided is the zero value for its type @@ -13,40 +11,3 @@ func IsZeroValue(v reflect.Value) bool { } return false } - -// GetPK returns the primary key for a struct -func GetPK(val reflect.Value) (*reflect.Value, bool) { - typ := val.Type() - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - ptags := tags.GetStructTagsMap(field, "picard") - if _, isPK := ptags["primary_key"]; isPK { - fv := val.FieldByName(field.Name) - return &fv, true - } - } - - return nil, false -} - -/* -ReflectTableInfo will return the table name and primary key name from the type -*/ -func ReflectTableInfo(typ reflect.Type) (string, string) { - var tblName string - var primaryKey string - - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - typName := field.Type.Name() - ptags := tags.GetStructTagsMap(field, "picard") - if typName == "Metadata" { - tblName = ptags["tablename"] - } - if _, isPK := ptags["primary_key"]; isPK { - primaryKey = ptags["column"] - } - } - - return tblName, primaryKey -} diff --git a/tags/tags.go b/tags/tags.go index 1bc5ebf..b9d7c7b 100644 --- a/tags/tags.go +++ b/tags/tags.go @@ -316,12 +316,16 @@ func (tm TableMetadata) GetField(fieldName string) FieldMetadata { func GetTableMetadata(data interface{}) (*TableMetadata, error) { // Verify that we've been passed valid input t := reflect.TypeOf(data) - if t.Kind() != reflect.Slice { - return nil, errors.New("Can only upsert slices") + var tableMetadata *TableMetadata + + if t.Kind() == reflect.Slice { + tableMetadata = TableMetadataFromType(t.Elem()) + } else if t.Kind() == reflect.Struct { + tableMetadata = TableMetadataFromType(t) + } else { + return nil, errors.New("Can only get metadata structs or slices of structs") } - tableMetadata := TableMetadataFromType(t.Elem()) - if tableMetadata.tableName == "" { return nil, errors.New("No table name specified in struct metadata") } From f7c56ba2d7ef53d84072e42ac0a4a94c654552b8 Mon Sep 17 00:00:00 2001 From: Ben Hubbard Date: Mon, 5 Aug 2019 16:37:41 -0400 Subject: [PATCH 2/4] Don't get metadata on hydrate --- filter.go | 4 ++-- query/hydrate.go | 10 +++++----- query/hydrate_test.go | 8 +++++++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/filter.go b/filter.go index 90180aa..6cb8e82 100644 --- a/filter.go +++ b/filter.go @@ -50,7 +50,7 @@ func (p PersistenceORM) getSingleFilterResults(request FilterRequest, filterMeta return nil, err } aliasMap := tbl.FieldAliases() - return query.Hydrate(filterModel, aliasMap, rows) + return query.Hydrate(filterModel, aliasMap, rows, filterMetadata) } func (p PersistenceORM) getMultiFilterResults(request FilterRequest, filterMetadata *tags.TableMetadata) ([]*reflect.Value, error) { @@ -111,7 +111,7 @@ func (p PersistenceORM) getMultiFilterResults(request FilterRequest, filterMetad return nil, err } aliasMap := tbl.FieldAliases() - return query.Hydrate(filterModel, aliasMap, rows) + return query.Hydrate(filterModel, aliasMap, rows, filterMetadata) } func (p PersistenceORM) getFilterResults(request FilterRequest, filterMetadata *tags.TableMetadata) ([]*reflect.Value, error) { diff --git a/query/hydrate.go b/query/hydrate.go index 9bc0aca..09690a6 100644 --- a/query/hydrate.go +++ b/query/hydrate.go @@ -18,7 +18,7 @@ import ( Hydrate takes the rows and pops them into the correct struct, in the correct order. This is usually called after you've built and executed the query model. */ -func Hydrate(filterModel interface{}, aliasMap map[string]qp.FieldDescriptor, rows *sql.Rows) ([]*reflect.Value, error) { +func Hydrate(filterModel interface{}, aliasMap map[string]qp.FieldDescriptor, rows *sql.Rows, meta *tags.TableMetadata) ([]*reflect.Value, error) { modelVal, err := stringutil.GetStructValue(filterModel) if err != nil { return nil, err @@ -34,7 +34,7 @@ func Hydrate(filterModel interface{}, aliasMap map[string]qp.FieldDescriptor, ro hydrateds := make([]*reflect.Value, 0, len(mappedCols)) for _, mapped := range mappedCols { - hydrated, err := hydrate(typ, mapped, 0) + hydrated, err := hydrate(typ, mapped, 0, meta) if err != nil { return nil, err @@ -45,8 +45,7 @@ func Hydrate(filterModel interface{}, aliasMap map[string]qp.FieldDescriptor, ro return hydrateds, nil } -func hydrate(typ reflect.Type, mapped map[string]map[string]interface{}, counter int) (*reflect.Value, error) { - meta := tags.TableMetadataFromType(typ) +func hydrate(typ reflect.Type, mapped map[string]map[string]interface{}, counter int, meta *tags.TableMetadata) (*reflect.Value, error) { model := reflect.Indirect(reflect.New(typ)) @@ -63,8 +62,9 @@ func hydrate(typ reflect.Type, mapped map[string]map[string]interface{}, counter if field.IsFK() { refTyp := field.GetRelatedType() + fkField := meta.GetForeignKeyField(field.GetName()) // Recursively hydrate this reference field - refValHydrated, err := hydrate(refTyp, mapped, counter+1) + refValHydrated, err := hydrate(refTyp, mapped, counter+1, fkField.TableMetadata) if err != nil { return nil, err } diff --git a/query/hydrate_test.go b/query/hydrate_test.go index 6b0a56f..3ebb0d1 100644 --- a/query/hydrate_test.go +++ b/query/hydrate_test.go @@ -6,6 +6,7 @@ import ( "github.com/skuid/picard/crypto" qp "github.com/skuid/picard/queryparts" + "github.com/skuid/picard/tags" "github.com/DATA-DOG/go-sqlmock" sql "github.com/Masterminds/squirrel" @@ -331,8 +332,13 @@ func TestHydrate(t *testing.T) { t.Errorf("there were unmet sqlmock expectations:\n%s", err) } + metadata, err := tags.GetTableMetadata(tc.model) + if err != nil { + t.Fatal(err) + } + // Testing our Hydrate function - actuals, err := Hydrate(tc.model, tc.aliasMap, rows) + actuals, err := Hydrate(tc.model, tc.aliasMap, rows, metadata) assert.NoError(err) for i, actual := range actuals { assert.Equal(tc.expected[i], actual.Interface().(field)) From 0dfdde1ed35f69bad69ecf29773eb0acdcb1dc40 Mon Sep 17 00:00:00 2001 From: Ben Hubbard Date: Tue, 6 Aug 2019 12:24:14 -0400 Subject: [PATCH 3/4] Add SelectFields and FieldFilters to Associations --- delete.go | 64 ++++-------- delete_test.go | 18 +--- filter.go | 8 +- filter_test.go | 240 ++++++++++++++++++++++++++++++++++++++++---- query/build.go | 34 +++---- query/build_test.go | 2 +- tags/tags.go | 12 ++- 7 files changed, 277 insertions(+), 101 deletions(-) diff --git a/delete.go b/delete.go index 503c826..b22eff0 100644 --- a/delete.go +++ b/delete.go @@ -16,17 +16,20 @@ import ( // Returns the number of rows affected or an error. func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { - associations, err := getAssociationsFromModel(model) + metadata, err := tags.GetTableMetadata(model) if err != nil { return 0, err } - metadata, err := tags.GetTableMetadata(model) + hasAssociations, err := hasAssociations(model, metadata) if err != nil { return 0, err } - tbl, err := query.Build(porm.multitenancyValue, model, nil, nil, metadata) + pkField := metadata.GetPrimaryKeyFieldName() + pkColumn := metadata.GetPrimaryKeyColumnName() + + tbl, err := query.Build(porm.multitenancyValue, model, nil, nil, nil, metadata) if err != nil { return 0, err @@ -35,25 +38,24 @@ func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { dSQL := tbl.DeleteSQL() lookupPks := make([]interface{}, 0) - if len(associations) > 0 { - pk := metadata.GetPrimaryKeyColumnName() + if hasAssociations { results, err := porm.FilterModel(FilterRequest{ FilterModel: model, - Associations: associations, + SelectFields: []string{pkField}, }) if err != nil { return 0, err } for _, result := range results { - val := getValueFromLookupString(reflect.ValueOf(result), metadata.GetPrimaryKeyFieldName()) + val := getValueFromLookupString(reflect.ValueOf(result), pkField) if val.IsValid() { lookupPks = append(lookupPks, val.Interface()) } } dSQL = dSQL.Where( sq.Eq{ - fmt.Sprintf("%s.%s", tbl.Alias, pk): lookupPks, + fmt.Sprintf("%s.%s", tbl.Alias, pkColumn): lookupPks, }, ) } @@ -79,51 +81,25 @@ func (porm PersistenceORM) DeleteModel(model interface{}) (int64, error) { return results.RowsAffected() } -func getAssociationsFromModel(model interface{}) ([]tags.Association, error) { - +func hasAssociations(model interface{}, metadata *tags.TableMetadata) (bool, error) { val, err := stringutil.GetStructValue(model) - if err != nil { - return nil, err + return false, err } - return getAssociationsFromValue(val) -} - -func getAssociationsFromValue(val reflect.Value) ([]tags.Association, error) { - associations := make([]tags.Association, 0) - if val.Kind() != reflect.Struct { - return nil, fmt.Errorf("Model must be a struct in order to get associations. It was a %v instead", val.Kind()) + return false, fmt.Errorf("Model must be a struct in order to get associations. It was a %v instead", val.Kind()) } - for i := 0; i < val.Type().NumField(); i++ { - structField := val.Type().Field(i) - ptags := tags.GetStructTagsMap(structField, "picard") - - _, isFK := ptags["foreign_key"] - - if isFK { - if relatedName, ok := ptags["related"]; ok { - - relatedVal := val.FieldByName(relatedName) + for _, fkField := range metadata.GetForeignKeys() { + relatedName := fkField.RelatedFieldName - if !reflectutil.IsZeroValue(relatedVal) { - fieldAssoc := tags.Association{ - Name: relatedName, - } - - childAssocs, err := getAssociationsFromValue(relatedVal) - if err != nil { - return nil, err - } - fieldAssoc.Associations = append(fieldAssoc.Associations, childAssocs...) - - associations = append(associations, fieldAssoc) - } + if relatedName != "" { + relatedVal := val.FieldByName(relatedName) + if !reflectutil.IsZeroValue(relatedVal) { + return true, nil } } } - - return associations, nil + return false, nil } diff --git a/delete_test.go b/delete_test.go index fc2a5f7..2ae1ea8 100644 --- a/delete_test.go +++ b/delete_test.go @@ -43,8 +43,8 @@ func TestDeleteModel(t *testing.T) { `)). WithArgs( "00000000-0000-0000-0000-000000000001", - "00000000-0000-0000-0000-000000000555", - ). + "00000000-0000-0000-0000-000000000555", + ). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, @@ -63,22 +63,10 @@ func TestDeleteModel(t *testing.T) { mock.ExpectQuery(testdata.FmtSQLRegex(` SELECT t0.id AS "t0.id", - t0.organization_id AS "t0.organization_id", - t0.name AS "t0.name", - t0.nullable_lookup AS "t0.nullable_lookup", - t0.type AS "t0.type", - t0.is_active AS "t0.is_active", - t0.parent_id AS "t0.parent_id", - t0.config AS "t0.config", - t0.created_by_id AS "t0.created_by_id", - t0.updated_by_id AS "t0.updated_by_id", - t0.created_at AS "t0.created_at", - t0.updated_at AS "t0.updated_at", t1.id AS "t1.id", - t1.organization_id AS "t1.organization_id", t1.name AS "t1.name" FROM testobject AS t0 - LEFT JOIN parenttest AS t1 ON + JOIN parenttest AS t1 ON (t1.id = t0.parent_id AND t1.organization_id = $1) WHERE t0.organization_id = $2 AND diff --git a/filter.go b/filter.go index 6cb8e82..f9d0ef5 100644 --- a/filter.go +++ b/filter.go @@ -19,7 +19,7 @@ type FilterRequest struct { Associations []tags.Association OrderBy []qp.OrderByRequest Runner sq.BaseRunner - //Fields []string // For use later whe we implement selecting specific columns + SelectFields []string } func addOrderBy(builder sq.SelectBuilder, orderBy []qp.OrderByRequest, filterMetadata *tags.TableMetadata, tableAlias string) sq.SelectBuilder { @@ -39,7 +39,7 @@ func addOrderBy(builder sq.SelectBuilder, orderBy []qp.OrderByRequest, filterMet func (p PersistenceORM) getSingleFilterResults(request FilterRequest, filterMetadata *tags.TableMetadata) ([]*reflect.Value, error) { filterModel := request.FilterModel - tbl, err := query.Build(p.multitenancyValue, filterModel, request.FieldFilters, request.Associations, filterMetadata) + tbl, err := query.Build(p.multitenancyValue, filterModel, request.FieldFilters, request.Associations, request.SelectFields, filterMetadata) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (p PersistenceORM) getMultiFilterResults(request FilterRequest, filterMetad for i := 0; i < modelVal.Len(); i++ { val := modelVal.Index(i) - ftbl, err := query.Build(mtVal, val.Interface(), request.FieldFilters, request.Associations, filterMetadata) + ftbl, err := query.Build(mtVal, val.Interface(), request.FieldFilters, request.Associations, request.SelectFields, filterMetadata) if err != nil { return nil, err } @@ -208,6 +208,8 @@ func (p PersistenceORM) FilterModel(request FilterRequest) ([]interface{}, error Associations: association.Associations, OrderBy: association.OrderBy, Runner: request.Runner, + FieldFilters: association.FieldFilters, + SelectFields: association.SelectFields, }) if err != nil { return nil, err diff --git a/filter_test.go b/filter_test.go index 8b8ebbd..db888f5 100644 --- a/filter_test.go +++ b/filter_test.go @@ -1098,7 +1098,7 @@ func TestDoFilterSelectWithJSONBField(t *testing.T) { func TestFilterModel(t *testing.T) { orgID := "00000000-0000-0000-0000-000000000001" - parentId := "00000000-0000-0000-0000-000000000002" + parentID := "00000000-0000-0000-0000-000000000002" testCases := []struct { description string filterRequest FilterRequest @@ -1109,7 +1109,7 @@ func TestFilterModel(t *testing.T) { "basic filter", FilterRequest{ FilterModel: testdata.ToyModel{ - ParentID: parentId, + ParentID: parentID, }, }, []interface{}{ @@ -1117,7 +1117,7 @@ func TestFilterModel(t *testing.T) { ID: "00000000-0000-0000-0000-000000000011", OrganizationID: orgID, Name: "lego", - ParentID: parentId, + ParentID: parentID, }, }, func(mock sqlmock.Sqlmock) { @@ -1131,7 +1131,7 @@ func TestFilterModel(t *testing.T) { FROM toymodel AS t0 WHERE t0.organization_id = $1 AND t0.parent_id = $2 `)). - WithArgs(orgID, parentId). + WithArgs(orgID, parentID). WillReturnRows( sqlmock.NewRows([]string{ "t0.id", @@ -1143,7 +1143,7 @@ func TestFilterModel(t *testing.T) { "00000000-0000-0000-0000-000000000011", orgID, "lego", - parentId, + parentID, ), ) mock.ExpectCommit() @@ -1153,7 +1153,7 @@ func TestFilterModel(t *testing.T) { "basic filter with no returns", FilterRequest{ FilterModel: testdata.ToyModel{ - ParentID: parentId, + ParentID: parentID, }, }, []interface{}{}, @@ -1168,7 +1168,7 @@ func TestFilterModel(t *testing.T) { FROM toymodel AS t0 WHERE t0.organization_id = $1 AND t0.parent_id = $2 `)). - WithArgs(orgID, parentId). + WithArgs(orgID, parentID). WillReturnRows( sqlmock.NewRows([]string{ "t0.id", @@ -1323,7 +1323,7 @@ func TestFilterModel(t *testing.T) { }, []interface{}{ testdata.ParentModel{ - ID: parentId, + ID: parentID, OrganizationID: orgID, Name: "pops", ParentID: "00000000-0000-0000-0000-000000000004", @@ -1332,13 +1332,13 @@ func TestFilterModel(t *testing.T) { ID: "00000000-0000-0000-0000-000000000012", OrganizationID: orgID, Name: "Betty", - ParentID: parentId, + ParentID: parentID, }, { ID: "00000000-0000-0000-0000-000000000011", OrganizationID: orgID, Name: "Alex", - ParentID: parentId, + ParentID: parentID, }, }, Animals: []testdata.PetModel{ @@ -1346,13 +1346,13 @@ func TestFilterModel(t *testing.T) { ID: "00000000-0000-0000-0000-000000000031", OrganizationID: orgID, Name: "Cheerios", - ParentID: parentId, + ParentID: parentID, }, { ID: "00000000-0000-0000-0000-000000000032", OrganizationID: orgID, Name: "Pinkerton", - ParentID: parentId, + ParentID: parentID, }, }, }, @@ -1378,7 +1378,7 @@ func TestFilterModel(t *testing.T) { "t0.name", "t0.parent_id", }).AddRow( - parentId, + parentID, orgID, "pops", "00000000-0000-0000-0000-000000000004", @@ -1395,7 +1395,7 @@ func TestFilterModel(t *testing.T) { WHERE t0.organization_id = $1 AND ((t0.parent_id = $2)) ORDER BY t0.name DESC `)). - WithArgs(orgID, parentId). + WithArgs(orgID, parentID). WillReturnRows( sqlmock.NewRows([]string{ "t0.id", @@ -1407,13 +1407,13 @@ func TestFilterModel(t *testing.T) { "00000000-0000-0000-0000-000000000012", orgID, "Betty", - parentId, + parentID, ). AddRow( "00000000-0000-0000-0000-000000000011", orgID, "Alex", - parentId, + parentID, ), ) // Pets/Animals @@ -1427,7 +1427,7 @@ func TestFilterModel(t *testing.T) { WHERE t0.organization_id = $1 AND ((t0.parent_id = $2)) ORDER BY t0.name `)). - WithArgs(orgID, parentId). + WithArgs(orgID, parentID). WillReturnRows( sqlmock.NewRows([]string{ "t0.id", @@ -1439,13 +1439,13 @@ func TestFilterModel(t *testing.T) { "00000000-0000-0000-0000-000000000031", orgID, "Cheerios", - parentId, + parentID, ). AddRow( "00000000-0000-0000-0000-000000000032", orgID, "Pinkerton", - parentId, + parentID, ), ) mock.ExpectCommit() @@ -1521,6 +1521,208 @@ func TestFilterModel(t *testing.T) { mock.ExpectCommit() }, }, + { + "filter request with additional field filters array and select fields specified", + FilterRequest{ + FilterModel: testdata.ToyModel{}, + FieldFilters: []qp.FieldFilter{ + { + FieldName: "Name", + FilterValue: []string{"Lego", "Matchbox Car", "Nintendo"}, + }, + }, + SelectFields: []string{"ID", "Name"}, + }, + []interface{}{}, + func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.name AS "t0.name" + FROM toymodel AS t0 + WHERE t0.organization_id = $1 AND t0.name IN ($2,$3,$4) + `)). + WithArgs(orgID, "Lego", "Matchbox Car", "Nintendo"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.organization_id", + "t0.name", + "t0.parent_id", + }), + ) + mock.ExpectCommit() + }, + }, + { + "happy path for single parent filter with eager loading parent - also add selectfields and field filter on association", + FilterRequest{ + FilterModel: testdata.ParentModel{ + Name: "pops", + }, + Associations: []tags.Association{ + { + Name: "GrandParent", + SelectFields: []string{"ID", "Name"}, + FieldFilters: []qp.FieldFilter{ + { + FieldName: "Name", + FilterValue: "grandpops", + }, + }, + }, + }, + }, + []interface{}{ + testdata.ParentModel{ + ID: "00000000-0000-0000-0000-000000000002", + OrganizationID: orgID, + Name: "pops", + ParentID: "00000000-0000-0000-0000-000000000023", + GrandParent: testdata.GrandParentModel{ + ID: "00000000-0000-0000-0000-000000000023", + Name: "grandpops", + }, + }, + }, + func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.organization_id AS "t0.organization_id", + t0.name AS "t0.name", + t0.parent_id AS "t0.parent_id", + t1.id AS "t1.id", + t1.name AS "t1.name" + FROM parentmodel AS t0 + LEFT JOIN grandparentmodel AS t1 ON + (t1.id = t0.parent_id AND t1.organization_id = $1) + WHERE + t0.organization_id = $2 AND + t0.name = $3 AND + t1.name = $4 + `)). + WithArgs(orgID, orgID, "pops", "grandpops"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.organization_id", + "t0.name", + "t0.parent_id", + "t1.id", + "t1.name", + }). + AddRow( + "00000000-0000-0000-0000-000000000002", + orgID, + "pops", + "00000000-0000-0000-0000-000000000023", + "00000000-0000-0000-0000-000000000023", + "grandpops", + ), + ) + mock.ExpectCommit() + }, + }, + { + "happy path for filtering children with selectfields and fieldfilters", + FilterRequest{ + FilterModel: testdata.ParentModel{ + Name: "pops", + }, + Associations: []tags.Association{ + { + Name: "Children", + SelectFields: []string{"ID", "Name", "ParentID"}, + FieldFilters: []qp.FieldFilter{ + { + FieldName: "Name", + FilterValue: []string{"kiddo", "another_kid"}, + }, + }, + }, + }, + }, + []interface{}{ + testdata.ParentModel{ + ID: "00000000-0000-0000-0000-000000000002", + OrganizationID: orgID, + Name: "pops", + ParentID: "00000000-0000-0000-0000-000000000004", + Children: []testdata.ChildModel{ + { + ID: "00000000-0000-0000-0000-000000000011", + Name: "kiddo", + ParentID: "00000000-0000-0000-0000-000000000002", + }, + { + ID: "00000000-0000-0000-0000-000000000012", + Name: "another_kid", + ParentID: "00000000-0000-0000-0000-000000000002", + }, + }, + }, + }, + func(mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.organization_id AS "t0.organization_id", + t0.name AS "t0.name", + t0.parent_id AS "t0.parent_id" + FROM parentmodel AS t0 + WHERE t0.organization_id = $1 AND t0.name = $2 + `)). + WithArgs(orgID, "pops"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.organization_id", + "t0.name", + "t0.parent_id", + }). + AddRow( + "00000000-0000-0000-0000-000000000002", + orgID, + "pops", + "00000000-0000-0000-0000-000000000004", + ), + ) + + // parent is vtestdata.ParentModel + mock.ExpectQuery(testdata.FmtSQLRegex(` + SELECT + t0.id AS "t0.id", + t0.name AS "t0.name", + t0.parent_id AS "t0.parent_id" + FROM childmodel AS t0 + WHERE + t0.organization_id = $1 AND ((t0.parent_id = $2 AND t0.name IN ($3,$4))) + `)). + WithArgs(orgID, "00000000-0000-0000-0000-000000000002", "kiddo", "another_kid"). + WillReturnRows( + sqlmock.NewRows([]string{ + "t0.id", + "t0.name", + "t0.parent_id", + }). + AddRow( + "00000000-0000-0000-0000-000000000011", + "kiddo", + "00000000-0000-0000-0000-000000000002", + ). + AddRow( + "00000000-0000-0000-0000-000000000012", + "another_kid", + "00000000-0000-0000-0000-000000000002", + ), + ) + mock.ExpectCommit() + }, + }, } for _, tc := range testCases { diff --git a/query/build.go b/query/build.go index 8b98171..734c598 100644 --- a/query/build.go +++ b/query/build.go @@ -14,7 +14,7 @@ import ( Build takes the filter model and returns a query object. It takes the multitenancy value, current reflected value, and any tags */ -func Build(multitenancyVal, model interface{}, filters []qp.FieldFilter, associations []tags.Association, filterMetadata *tags.TableMetadata) (*qp.Table, error) { +func Build(multitenancyVal, model interface{}, filters []qp.FieldFilter, associations []tags.Association, selectFields []string, filterMetadata *tags.TableMetadata) (*qp.Table, error) { val, err := stringutil.GetStructValue(model) if err != nil { @@ -23,7 +23,7 @@ func Build(multitenancyVal, model interface{}, filters []qp.FieldFilter, associa typ := val.Type() - tbl, err := buildQuery(multitenancyVal, typ, &val, filters, associations, false, 0, filterMetadata) + tbl, err := buildQuery(multitenancyVal, typ, &val, filters, associations, selectFields, false, 0, filterMetadata) if err != nil { return nil, err } @@ -63,6 +63,7 @@ func buildQuery( modelVal *reflect.Value, filters []qp.FieldFilter, associations []tags.Association, + selectFields []string, onlyJoin bool, counter int, filterMetadata *tags.TableMetadata, @@ -77,28 +78,26 @@ func buildQuery( cols := make([]string, 0, modelType.NumField()) seen := make(map[string]bool) - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) + for _, field := range filterMetadata.GetFields() { notZero := false var val reflect.Value if modelVal != nil { - val = modelVal.FieldByName(field.Name) + val = modelVal.FieldByName(field.GetName()) notZero = !reflectutil.IsZeroValue(val) } - ptags := tags.GetStructTagsMap(field, "picard") - column, hasColumn := ptags["column"] - _, isMultitenancyColumn := ptags["multitenancy_key"] - _, isFk := ptags["foreign_key"] - _, isPrimaryKey := ptags["primary_key"] + column := field.GetColumnName() + isMultitenancyColumn := field.IsMultitenancyKey() + isFk := field.IsFK() + isPrimaryKey := field.IsPrimaryKey() addCol := true - if onlyJoin && !isPrimaryKey { + if selectFields != nil && !stringutil.StringSliceContainsKey(selectFields, field.GetName()) { addCol = false } - if !hasColumn { - continue + if onlyJoin && !isPrimaryKey { + addCol = false } switch { @@ -109,7 +108,7 @@ func buildQuery( } tbl.AddMultitenancyWhere(column, multitenancyVal) case isFk: - relatedName := ptags["related"] + relatedName := field.GetRelatedName() relatedVal := modelVal.FieldByName(relatedName) association, ok := getAssociation(associations, relatedName) @@ -132,12 +131,12 @@ func buildQuery( refTyp := relatedVal.Type() refMetadata := tags.TableMetadataFromType(refTyp) - refTbl, err := buildQuery(multitenancyVal, refTyp, &relatedVal, nil, association.Associations, childOnlyJoin, counter+1, refMetadata) + refTbl, err := buildQuery(multitenancyVal, refTyp, &relatedVal, association.FieldFilters, association.Associations, association.SelectFields, childOnlyJoin, counter+1, refMetadata) if err != nil { return nil, err } - joinField := ptags["column"] + joinField := column direction := "left" if childOnlyJoin { @@ -147,8 +146,7 @@ func buildQuery( } case notZero: - _, isEncrypted := ptags["encrypted"] - if isEncrypted { + if field.IsEncrypted() { return nil, errors.New("cannot perform queries with where clauses on encrypted fields") } if !seen[column] { diff --git a/query/build_test.go b/query/build_test.go index c7b56a9..9a6b7ed 100644 --- a/query/build_test.go +++ b/query/build_test.go @@ -169,7 +169,7 @@ func TestQueryBuilder(t *testing.T) { t.Fatal(err) } - tbl, err := Build(orgID, tc.model, nil, tc.assoc, metadata) + tbl, err := Build(orgID, tc.model, nil, tc.assoc, nil, metadata) assert.NoError(err) actual, actualArgs, err := tbl.ToSQL() diff --git a/tags/tags.go b/tags/tags.go index b9d7c7b..22ae742 100644 --- a/tags/tags.go +++ b/tags/tags.go @@ -16,6 +16,8 @@ type Association struct { Name string Associations []Association OrderBy []qp.OrderByRequest + SelectFields []string + FieldFilters []qp.FieldFilter } // Lookup structure @@ -89,11 +91,16 @@ func (fm FieldMetadata) IsPrimaryKey() bool { return fm.isPrimaryKey } -// IsReference function +// IsFK function func (fm FieldMetadata) IsFK() bool { return fm.isFK } +// IsMultitenancyKey function +func (fm FieldMetadata) IsMultitenancyKey() bool { + return fm.isMultitenancyKey +} + // GetName function func (fm FieldMetadata) GetName() string { return fm.name @@ -109,14 +116,17 @@ func (fm FieldMetadata) GetFieldType() reflect.Type { return fm.fieldType } +// GetRelatedType function func (fm FieldMetadata) GetRelatedType() reflect.Type { return fm.relatedField.Type } +// GetRelatedName function func (fm FieldMetadata) GetRelatedName() string { return fm.relatedField.Name } +// IsEncrypted function func (fm FieldMetadata) IsEncrypted() bool { return fm.isEncrypted } From 3edb46f9fca92d8b0a48bd8763d7bd4cb47835e8 Mon Sep 17 00:00:00 2001 From: Ben Hubbard Date: Tue, 6 Aug 2019 14:05:58 -0400 Subject: [PATCH 4/4] fix doc signature --- query/build.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/query/build.go b/query/build.go index 734c598..e9a1d9e 100644 --- a/query/build.go +++ b/query/build.go @@ -52,10 +52,17 @@ to generate the SQL. It takes - modelType: This is the reflected type of the struct used for this table's load. It is used to figure out which columns to select, joins to add, and wheres. - modelVal: This is an instance of the struct, holding any lookup values +- filters: Additional filters to add to this query. This allows for more complex conditions + than a simple modelFilter can provide. - associations: List of associations to load. For references, this will add the join to the table at the correct level. +- selectFields: List of fields to add to the select clause of the query. If this is null, + add all fields with columns specified to the query. +- onlyJoin: If the association wasn't asked for, but there is a value in the related structure, just join but don't + add the fields to the select. - counter: because record keeping and aliasing is hard, we have to keep track of which join we're currently looking at during the recursions. +- filterMetadata: Metadata about struct that was passed in in modelVal */ func buildQuery( multitenancyVal interface{},