Skip to content

Commit

Permalink
Created a method Query.Match that returns whether a query matches a g…
Browse files Browse the repository at this point in the history
…iven data item.
  • Loading branch information
zond committed Feb 7, 2024
1 parent 51583a6 commit 97aca2b
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 0 deletions.
116 changes: 116 additions & 0 deletions find_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1295,3 +1295,119 @@ func TestFindWithStorerImplementation(t *testing.T) {
equals(t, *customStorerItem, results[0])
})
}

type queryMatchTest struct {
Key int `badgerholdKey:"Key"`
Age int
Color string
Created time.Time
}

func TestComplexQueryMatch(t *testing.T) {
testWrap(t, func(store *badgerhold.Store, t *testing.T) {
item := queryMatchTest{
Key: 1,
Age: 2,
Color: "color",
Created: time.UnixMicro(0),
}
query := badgerhold.Where("Key").Eq(1).And("Age").Eq(3).Or(badgerhold.Where("Key").Eq(2).And("Age").Eq(2))
if m, err := query.Matches(store, item); m || err != nil {
t.Errorf("wanted %+v to not match %+v, but got %v, %v", query, item, m, err)
}
query = badgerhold.Where("Key").Eq(1).And("Age").Eq(3).Or(badgerhold.Where("Key").Eq(1).And("Age").Eq(2))
if m, err := query.Matches(store, item); !m || err != nil {
t.Errorf("wanted %+v to match %+v, but got %v, %v", query, item, m, err)
}
query = badgerhold.Where("Key").Eq(1).And("Age").Eq(1).Or(badgerhold.Where("Key").Eq(2).And("Age").Eq(2).Or(badgerhold.Where("Key").Eq(1).And("Age").Eq(2)))
if m, err := query.Matches(store, item); !m || err != nil {
t.Errorf("wanted %+v to match %+v, but got %v, %v", query, item, m, err)
}
})
}

func TestQueryMatch(t *testing.T) {
testWrap(t, func(store *badgerhold.Store, t *testing.T) {
item := queryMatchTest{
Key: 1,
Age: 2,
Color: "color",
Created: time.UnixMicro(0),
}
for _, tc := range []struct {
query *badgerhold.Query
wantMatch bool
title string
}{
{
query: badgerhold.Where("Key").Eq(1),
wantMatch: true,
title: "SingleKeyFieldMatch",
},
{
query: badgerhold.Where("Key").Eq(2),
wantMatch: false,
title: "SingleKeyFieldMismatch",
},
{
query: badgerhold.Where("Age").Eq(2),
wantMatch: true,
title: "SingleIntFieldMatch",
},
{
query: badgerhold.Where("Age").Eq(3),
wantMatch: false,
title: "SingleIntFieldMatch",
},
{
query: badgerhold.Where("Key").Eq(1).And("Color").Eq("color"),
wantMatch: true,
title: "MultiFieldAndMatch",
},
{
query: badgerhold.Where("Key").Eq(1).And("Color").Eq("notcolor"),
wantMatch: false,
title: "MultiFieldAndMismatch",
},
{
query: badgerhold.Where("Key").Eq(2).Or(badgerhold.Where("Color").Eq("color")),
wantMatch: true,
title: "MultiFieldOrMatch",
},
{
query: badgerhold.Where("Key").Eq(2).Or(badgerhold.Where("Color").Eq("notcolor")),
wantMatch: false,
title: "MultiFieldOrMismatch",
},
{
query: badgerhold.Where("Created").Eq(time.UnixMicro(0)),
wantMatch: true,
title: "SingleTimeFieldMatch",
},
{
query: badgerhold.Where("Created").Eq(time.UnixMicro(1)),
wantMatch: false,
title: "SingleTimeFieldMismatch",
},
} {
t.Run(tc.title+"StructReceiver", func(t *testing.T) {
gotMatch, err := tc.query.Matches(store, item)
if err != nil {
t.Fatal(err)
}
if gotMatch != tc.wantMatch {
t.Errorf("wanted %+v to return %v for %+v, got %v", tc.query, tc.wantMatch, item, gotMatch)
}
})
t.Run(tc.title+"PtrReceiver", func(t *testing.T) {
gotMatch, err := tc.query.Matches(store, &item)
if err != nil {
t.Fatal(err)
}
if gotMatch != tc.wantMatch {
t.Errorf("wanted %+v to return %v for %+v, got %v", tc.query, tc.wantMatch, &item, gotMatch)
}
})
}
})
}
33 changes: 33 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,39 @@ func (q *Query) Or(query *Query) *Query {
return q
}

// Matches returns whether the provided data matches the query.
// Will match all field criteria, including nested OR queries, but ignores limits, skips, sort orders, etc.
func (q *Query) Matches(s *Store, data interface{}) (bool, error) {
var key []byte
dataVal := reflect.ValueOf(data)
for dataVal.Kind() == reflect.Ptr {
dataVal = dataVal.Elem()
}
data = dataVal.Interface()
storer := s.newStorer(data)
if keyField, ok := getKeyField(dataVal.Type()); ok {
fieldValue := dataVal.FieldByName(keyField.Name)
var err error
key, err = s.encodeKey(fieldValue.Interface(), storer.Type())
if err != nil {
return false, err
}
}
return q.matches(s, key, dataVal, data)
}

func (q *Query) matches(s *Store, key []byte, value reflect.Value, data interface{}) (bool, error) {
if result, err := q.matchesAllFields(s, key, value, data); result || err != nil {
return result, err
}
for _, orQuery := range q.ors {
if result, err := orQuery.matches(s, key, value, data); result || err != nil {
return result, err
}
}
return false, nil
}

func (q *Query) matchesAllFields(s *Store, key []byte, value reflect.Value, currentRow interface{}) (bool, error) {
if q.IsEmpty() {
return true, nil
Expand Down

0 comments on commit 97aca2b

Please sign in to comment.