Skip to content

Commit

Permalink
Add startAtParameterIndex argument to Convert() (#4)
Browse files Browse the repository at this point in the history
This argument can be used to control which $X argument the generator
starts at. This is useful for when you want to add your own arguments to
your query as well.

Co-authored-by: Koen Bollen <[email protected]>
  • Loading branch information
erikdubbelboer and koenbollen authored May 3, 2024
1 parent de85a91 commit ec4961e
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 32 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ func main() {

// Convert a filter query to a WHERE clause and values:
input := []byte(`{"title": "Jurassic Park"}`)
where, values, err := converter.Convert(input)
conditions, values, err := converter.Convert(input, 1) // 1 is the starting index for params, $1, $2, ...
if err != nil {
// handle error
}
fmt.Println(where, values) // ("title" = $1), ["Jurassic Park"]
fmt.Println(conditions, values) // ("title" = $1), ["Jurassic Park"]

db, _ := sql.Open("postgres", "...")
db.QueryRow("SELECT * FROM movies WHERE " + where, values...)
db.QueryRow("SELECT * FROM movies WHERE " + conditions, values...)
}
```
(See [examples/](examples/) for more examples)
Expand Down
37 changes: 36 additions & 1 deletion examples/basic_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package examples

import (
"database/sql"
"fmt"

"github.com/poki/mongodb-filter-to-postgres/filter"
Expand All @@ -16,7 +17,7 @@ func ExampleNewConverter() {
"$gte": "2020-01-01T00:00:00Z"
}
}`
conditions, values, err := converter.Convert([]byte(mongoFilterQuery))
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 1)
if err != nil {
// handle error
}
Expand All @@ -27,3 +28,37 @@ func ExampleNewConverter() {
// (("created_at" >= $1) AND ("meta"->>'name' = $2))
// []interface {}{"2020-01-01T00:00:00Z", "John"}
}

func ExampleNewConverter_nonIsolatedConditions() {
converter := filter.NewConverter()

mongoFilterQuery := `{
"$or": [
{ "email": "[email protected]" },
{ "name": {"$regex": "^John.*^" },
]
}`
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 3)
if err != nil {
// handle error
}

query := `
SELECT *
FROM users
WHERE
disabled_at IS NOT NULL
AND role = $1
AND verified_at > $2
AND ` + conditions + `
LIMIT 10
`

role := "user"
verifiedAt := "2020-01-01T00:00:00Z"
values = append([]any{role, verifiedAt}, values...)

db, _ := sql.Open("postgres", "...")
rows := db.QueryRow(query, values...)
_ = rows // actually use rows
}
2 changes: 1 addition & 1 deletion examples/readme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func ExampleNewConverter_readme() {
}
]
}`
conditions, values, err := converter.Convert([]byte(mongoFilterQuery))
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 1)
if err != nil {
// handle error
panic(err)
Expand Down
19 changes: 13 additions & 6 deletions filter/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,21 @@ func NewConverter(options ...Option) *Converter {
}

// Convert converts a MongoDB filter query into SQL conditions and values.
func (c *Converter) Convert(query []byte) (string, []any, error) {
//
// startAtParameterIndex is the index to start the parameter numbering at.
// Passing X will make the first indexed parameter $X, the second $X+1, and so on.
func (c *Converter) Convert(query []byte, startAtParameterIndex int) (conditions string, values []any, err error) {
if startAtParameterIndex < 1 {
return "", nil, fmt.Errorf("startAtParameterIndex must be greater than 0")
}

var mongoFilter map[string]any
err := json.Unmarshal(query, &mongoFilter)
err = json.Unmarshal(query, &mongoFilter)
if err != nil {
return "", nil, err
}

conditions, values, err := c.convertFilter(mongoFilter, 0)
conditions, values, err = c.convertFilter(mongoFilter, startAtParameterIndex)
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -126,8 +133,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
if !isScalarSlice(v[operator]) {
return "", nil, fmt.Errorf("invalid value for $in operator (must array of primatives): %v", v[operator])
}
paramIndex++
inner = append(inner, fmt.Sprintf("(%s = ANY($%d))", c.columnName(key), paramIndex))
paramIndex++
if c.arrayDriver != nil {
v[operator] = c.arrayDriver(v[operator])
}
Expand All @@ -138,8 +145,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
if !ok {
return "", nil, fmt.Errorf("unknown operator: %s", operator)
}
paramIndex++
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
paramIndex++
values = append(values, value)
}
}
Expand All @@ -149,8 +156,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
}
conditions = append(conditions, innerResult)
default:
paramIndex++
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key), paramIndex))
paramIndex++
values = append(values, value)
}
}
Expand Down
32 changes: 31 additions & 1 deletion filter/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,11 @@ func TestConverter_Convert(t *testing.T) {
nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := filter.NewConverter(tt.option)
conditions, values, err := c.Convert([]byte(tt.input))
conditions, values, err := c.Convert([]byte(tt.input), 1)
if err != nil && (tt.err == nil || err.Error() != tt.err.Error()) {
t.Errorf("Converter.Convert() error = %v, wantErr %v", err, tt.err)
return
Expand All @@ -215,3 +216,32 @@ func TestConverter_Convert(t *testing.T) {
})
}
}

func TestConverter_Convert_startAtParameterIndex(t *testing.T) {
c := filter.NewConverter()
conditions, values, err := c.Convert([]byte(`{"name": "John", "password": "secret"}`), 10)
if err != nil {
t.Fatal(err)
}
if want := `(("name" = $10) AND ("password" = $11))`; conditions != want {
t.Errorf("Converter.Convert() conditions = %v, want %v", conditions, want)
}
if !reflect.DeepEqual(values, []any{"John", "secret"}) {
t.Errorf("Converter.Convert() values = %v, want %v", values, []any{"John"})
}

_, _, err = c.Convert([]byte(`{"name": "John"}`), 0)
if want := "startAtParameterIndex must be greater than 0"; err == nil || err.Error() != want {
t.Errorf("Converter.Convert(..., 0) error = nil, wantErr %q", want)
}

_, _, err = c.Convert([]byte(`{"name": "John"}`), -123)
if want := "startAtParameterIndex must be greater than 0"; err == nil || err.Error() != want {
t.Errorf("Converter.Convert(..., -123) error = nil, wantErr %q", want)
}

_, _, err = c.Convert([]byte(`{"name": "John"}`), 1234551231231231231)
if err != nil {
t.Errorf("Converter.Convert(..., 1234551231231231231) error = %v, want nil", err)
}
}
10 changes: 5 additions & 5 deletions fuzz/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func FuzzConverter(f *testing.F) {

f.Fuzz(func(t *testing.T, in string) {
c := filter.NewConverter(filter.WithArrayDriver(pq.Array))
where, _, err := c.Convert([]byte(in))
if err == nil && where != "" {
j, err := pg_query.ParseToJSON("SELECT * FROM test WHERE 1 AND " + where)
conditions, _, err := c.Convert([]byte(in), 1)
if err == nil && conditions != "" {
j, err := pg_query.ParseToJSON("SELECT * FROM test WHERE 1 AND " + conditions)
if err != nil {
t.Fatalf("%q %q %v", in, where, err)
t.Fatalf("%q %q %v", in, conditions, err)
}

if strings.Contains(j, "CommentStmt") {
t.Fatal(where, "CommentStmt found")
t.Fatal(conditions, "CommentStmt found")
}
}
})
Expand Down
30 changes: 15 additions & 15 deletions integration/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func TestIntegration_ReadmeExample(t *testing.T) {
]
}`

where, values, err := c.Convert([]byte(in))
conditions, values, err := c.Convert([]byte(in), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(`
SELECT id
FROM lobbies
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -124,15 +124,15 @@ func TestIntegration_InAny_PQ(t *testing.T) {
in := `{
"role": { "$in": ["guest", "user"] }
}`
where, values, err := c.Convert([]byte(in))
conditions, values, err := c.Convert([]byte(in), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(`
SELECT id
FROM users
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -189,15 +189,15 @@ func TestIntegration_InAny_PGX(t *testing.T) {
in := `{
"role": { "$in": ["guest", "user"] }
}`
where, values, err := c.Convert([]byte(in))
conditions, values, err := c.Convert([]byte(in), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(ctx, `
SELECT id
FROM users
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -296,15 +296,15 @@ func TestIntegration_BasicOperators(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := filter.NewConverter(filter.WithArrayDriver(pq.Array))
where, values, err := c.Convert([]byte(tt.input))
conditions, values, err := c.Convert([]byte(tt.input), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(`
SELECT id
FROM players
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
if tt.expectedError == nil {
Expand All @@ -325,7 +325,7 @@ func TestIntegration_BasicOperators(t *testing.T) {
}

if !reflect.DeepEqual(players, tt.expectedPlayers) {
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
}
})
}
Expand Down Expand Up @@ -384,15 +384,15 @@ func TestIntegration_NestedJSONB(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := filter.NewConverter(filter.WithArrayDriver(pq.Array), filter.WithNestedJSONB("metadata", "name", "level", "class"))
where, values, err := c.Convert([]byte(tt.input))
conditions, values, err := c.Convert([]byte(tt.input), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(`
SELECT id
FROM players
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
t.Fatal(err)
Expand All @@ -408,7 +408,7 @@ func TestIntegration_NestedJSONB(t *testing.T) {
}

if !reflect.DeepEqual(players, tt.expectedPlayers) {
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
}
})
}
Expand Down Expand Up @@ -452,15 +452,15 @@ func TestIntegration_Logic(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := filter.NewConverter(filter.WithArrayDriver(pq.Array), filter.WithNestedJSONB("metadata", "name", "level", "class"))
where, values, err := c.Convert([]byte(tt.input))
conditions, values, err := c.Convert([]byte(tt.input), 1)
if err != nil {
t.Fatal(err)
}

rows, err := db.Query(`
SELECT id
FROM players
WHERE `+where+`;
WHERE `+conditions+`;
`, values...)
if err != nil {
t.Fatal(err)
Expand All @@ -476,7 +476,7 @@ func TestIntegration_Logic(t *testing.T) {
}

if !reflect.DeepEqual(players, tt.expectedPlayers) {
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
}
})
}
Expand Down

0 comments on commit ec4961e

Please sign in to comment.