diff --git a/find_test.go b/find_test.go index 9fbff8a..a4b8cf9 100644 --- a/find_test.go +++ b/find_test.go @@ -1295,3 +1295,119 @@ func TestFindWithStorerImplementation(t *testing.T) { equals(t, *customStorerItem, results[0]) }) } + +type queryMatchTest struct { + Key int `badgerholdKey:"Key"` + Age int + Color string + Created time.Time +} + +func TestComplexQueryMatch(t *testing.T) { + testWrap(t, func(store *badgerhold.Store, t *testing.T) { + item := queryMatchTest{ + Key: 1, + Age: 2, + Color: "color", + Created: time.UnixMicro(0), + } + query := badgerhold.Where("Key").Eq(1).And("Age").Eq(3).Or(badgerhold.Where("Key").Eq(2).And("Age").Eq(2)) + if m, err := query.Matches(store, item); m || err != nil { + t.Errorf("wanted %+v to not match %+v, but got %v, %v", query, item, m, err) + } + query = badgerhold.Where("Key").Eq(1).And("Age").Eq(3).Or(badgerhold.Where("Key").Eq(1).And("Age").Eq(2)) + if m, err := query.Matches(store, item); !m || err != nil { + t.Errorf("wanted %+v to match %+v, but got %v, %v", query, item, m, err) + } + query = badgerhold.Where("Key").Eq(1).And("Age").Eq(1).Or(badgerhold.Where("Key").Eq(2).And("Age").Eq(2).Or(badgerhold.Where("Key").Eq(1).And("Age").Eq(2))) + if m, err := query.Matches(store, item); !m || err != nil { + t.Errorf("wanted %+v to match %+v, but got %v, %v", query, item, m, err) + } + }) +} + +func TestQueryMatch(t *testing.T) { + testWrap(t, func(store *badgerhold.Store, t *testing.T) { + item := queryMatchTest{ + Key: 1, + Age: 2, + Color: "color", + Created: time.UnixMicro(0), + } + for _, tc := range []struct { + query *badgerhold.Query + wantMatch bool + title string + }{ + { + query: badgerhold.Where("Key").Eq(1), + wantMatch: true, + title: "SingleKeyFieldMatch", + }, + { + query: badgerhold.Where("Key").Eq(2), + wantMatch: false, + title: "SingleKeyFieldMismatch", + }, + { + query: badgerhold.Where("Age").Eq(2), + wantMatch: true, + title: "SingleIntFieldMatch", + }, + { + query: badgerhold.Where("Age").Eq(3), + wantMatch: false, + title: "SingleIntFieldMatch", + }, + { + query: badgerhold.Where("Key").Eq(1).And("Color").Eq("color"), + wantMatch: true, + title: "MultiFieldAndMatch", + }, + { + query: badgerhold.Where("Key").Eq(1).And("Color").Eq("notcolor"), + wantMatch: false, + title: "MultiFieldAndMismatch", + }, + { + query: badgerhold.Where("Key").Eq(2).Or(badgerhold.Where("Color").Eq("color")), + wantMatch: true, + title: "MultiFieldOrMatch", + }, + { + query: badgerhold.Where("Key").Eq(2).Or(badgerhold.Where("Color").Eq("notcolor")), + wantMatch: false, + title: "MultiFieldOrMismatch", + }, + { + query: badgerhold.Where("Created").Eq(time.UnixMicro(0)), + wantMatch: true, + title: "SingleTimeFieldMatch", + }, + { + query: badgerhold.Where("Created").Eq(time.UnixMicro(1)), + wantMatch: false, + title: "SingleTimeFieldMismatch", + }, + } { + t.Run(tc.title+"StructReceiver", func(t *testing.T) { + gotMatch, err := tc.query.Matches(store, item) + if err != nil { + t.Fatal(err) + } + if gotMatch != tc.wantMatch { + t.Errorf("wanted %+v to return %v for %+v, got %v", tc.query, tc.wantMatch, item, gotMatch) + } + }) + t.Run(tc.title+"PtrReceiver", func(t *testing.T) { + gotMatch, err := tc.query.Matches(store, &item) + if err != nil { + t.Fatal(err) + } + if gotMatch != tc.wantMatch { + t.Errorf("wanted %+v to return %v for %+v, got %v", tc.query, tc.wantMatch, &item, gotMatch) + } + }) + } + }) +} diff --git a/query.go b/query.go index 4a4deb9..d0ac923 100644 --- a/query.go +++ b/query.go @@ -291,6 +291,39 @@ func (q *Query) Or(query *Query) *Query { return q } +// Matches returns whether the provided data matches the query. +// Will match all field criteria, including nested OR queries, but ignores limits, skips, sort orders, etc. +func (q *Query) Matches(s *Store, data interface{}) (bool, error) { + var key []byte + dataVal := reflect.ValueOf(data) + for dataVal.Kind() == reflect.Ptr { + dataVal = dataVal.Elem() + } + data = dataVal.Interface() + storer := s.newStorer(data) + if keyField, ok := getKeyField(dataVal.Type()); ok { + fieldValue := dataVal.FieldByName(keyField.Name) + var err error + key, err = s.encodeKey(fieldValue.Interface(), storer.Type()) + if err != nil { + return false, err + } + } + return q.matches(s, key, dataVal, data) +} + +func (q *Query) matches(s *Store, key []byte, value reflect.Value, data interface{}) (bool, error) { + if result, err := q.matchesAllFields(s, key, value, data); result || err != nil { + return result, err + } + for _, orQuery := range q.ors { + if result, err := orQuery.matches(s, key, value, data); result || err != nil { + return result, err + } + } + return false, nil +} + func (q *Query) matchesAllFields(s *Store, key []byte, value reflect.Value, currentRow interface{}) (bool, error) { if q.IsEmpty() { return true, nil