Skip to content

Commit

Permalink
Handle recursive schemas correctly
Browse files Browse the repository at this point in the history
This change checks for cycles when resolving schemas by keeping a cache
of each fully built schema object in a given example. If an item is in
the cache but not yet completely built, it indicates a cycle, and the
example generator bails out.

Sometimes, the recursion isn't actually necessary (e.g., for a property
that isn't required), in which case we try to intelligently omit
recursive schema references.

This should resolve danielgtaylor#45, although without a link to the spec they're
using it's difficult to be 100% sure.
  • Loading branch information
impl committed Sep 29, 2019
1 parent 79bc0ab commit e4ad0a3
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 19 deletions.
3 changes: 3 additions & 0 deletions apisprout.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var (
// ErrNoExample is sent when no example was found for an operation.
ErrNoExample = errors.New("No example found")

// ErrRecursive is when a schema is impossible to represent because it infinitely recurses.
ErrRecursive = errors.New("Recursive schema")

// ErrCannotMarshal is set when an example cannot be marshalled.
ErrCannotMarshal = errors.New("Cannot marshal example")

Expand Down
101 changes: 82 additions & 19 deletions example.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,76 @@ func excludeFromMode(mode Mode, schema *openapi3.Schema) bool {
return false
}

// OpenAPIExample creates an example structure from an OpenAPI 3 schema
// object, which is an extended subset of JSON Schema.
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#schemaObject
func OpenAPIExample(mode Mode, schema *openapi3.Schema) (interface{}, error) {
// isRequired checks whether a key is actually required.
func isRequired(schema *openapi3.Schema, key string) bool {
for _, req := range schema.Required {
if req == key {
return true
}
}

return false
}

type cachedSchema struct {
pending bool
out interface{}
}

func openAPIExample(mode Mode, schema *openapi3.Schema, cache map[*openapi3.Schema]*cachedSchema) (out interface{}, err error) {
if ex, ok := getSchemaExample(schema); ok {
return ex, nil
}

cached, ok := cache[schema]
if !ok {
cached = &cachedSchema{
pending: true,
}
cache[schema] = cached
} else if cached.pending {
return nil, ErrRecursive
} else {
return cached.out, nil
}

defer func() {
cached.pending = false
cached.out = out
}()

// Handle combining keywords
if len(schema.OneOf) > 0 {
return OpenAPIExample(mode, schema.OneOf[0].Value)
var ex interface{}
var err error

for _, candidate := range schema.OneOf {
ex, err = openAPIExample(mode, candidate.Value, cache)
if err == nil {
break
}
}

return ex, err
}
if len(schema.AnyOf) > 0 {
return OpenAPIExample(mode, schema.AnyOf[0].Value)
var ex interface{}
var err error

for _, candidate := range schema.AnyOf {
ex, err = openAPIExample(mode, candidate.Value, cache)
if err == nil {
break
}
}

return ex, err
}
if len(schema.AllOf) > 0 {
example := map[string]interface{}{}

for _, allOf := range schema.AllOf {
candidate, err := OpenAPIExample(mode, allOf.Value)
candidate, err := openAPIExample(mode, allOf.Value, cache)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -188,9 +238,9 @@ func OpenAPIExample(mode Mode, schema *openapi3.Schema) (interface{}, error) {
example := []interface{}{}

if schema.Items != nil && schema.Items.Value != nil {
ex, err := OpenAPIExample(mode, schema.Items.Value)
ex, err := openAPIExample(mode, schema.Items.Value, cache)
if err != nil {
return nil, fmt.Errorf("can't get example for array item")
return nil, fmt.Errorf("can't get example for array item: %+v", err)
}

example = append(example, ex)
Expand All @@ -209,24 +259,30 @@ func OpenAPIExample(mode Mode, schema *openapi3.Schema) (interface{}, error) {
continue
}

ex, err := OpenAPIExample(mode, v.Value)
if err != nil {
return nil, fmt.Errorf("can't get example for '%s'", k)
ex, err := openAPIExample(mode, v.Value, cache)
if err == ErrRecursive {
if isRequired(schema, k) {
return nil, fmt.Errorf("can't get example for '%s': %+v", k, err)
}
} else if err != nil {
return nil, fmt.Errorf("can't get example for '%s': %+v", k, err)
} else {
example[k] = ex
}

example[k] = ex
}

if schema.AdditionalProperties != nil && schema.AdditionalProperties.Value != nil {
addl := schema.AdditionalProperties.Value

if !excludeFromMode(mode, addl) {
ex, err := OpenAPIExample(mode, addl)
if err != nil {
return nil, fmt.Errorf("can't get example for additional properties")
ex, err := openAPIExample(mode, addl, cache)
if err == ErrRecursive {
// We just won't add this if it's recursive.
} else if err != nil {
return nil, fmt.Errorf("can't get example for additional properties: %+v", err)
} else {
example["additionalPropertyName"] = ex
}

example["additionalPropertyName"] = ex
}
}

Expand All @@ -235,3 +291,10 @@ func OpenAPIExample(mode Mode, schema *openapi3.Schema) (interface{}, error) {

return nil, ErrNoExample
}

// OpenAPIExample creates an example structure from an OpenAPI 3 schema
// object, which is an extended subset of JSON Schema.
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#schemaObject
func OpenAPIExample(mode Mode, schema *openapi3.Schema) (interface{}, error) {
return openAPIExample(mode, schema, make(map[*openapi3.Schema]*cachedSchema))
}
69 changes: 69 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@ package main

import (
"encoding/json"
"io/ioutil"
"os"
"path"
"strings"
"testing"

"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func exampleFixture(t *testing.T, name string) string {
f, err := os.Open(path.Join("testdata/example", name))
require.NoError(t, err)
defer f.Close()

b, err := ioutil.ReadAll(f)
require.NoError(t, err)

return string(b)
}

var schemaTests = []struct {
name string
in string
Expand Down Expand Up @@ -505,3 +520,57 @@ func TestGenExample(t *testing.T) {
})
}
}

func TestRecursiveSchema(t *testing.T) {
loader := openapi3.NewSwaggerLoader()

tests := []struct {
name string
in string
schema string
out string
}{
{
"Valid recursive schema",
exampleFixture(t, "recursive_ok.yml"),
"Test",
`{"something": "Hello"}`,
},
{
"Infinitely recursive schema",
exampleFixture(t, "recursive_infinite.yml"),
"Test",
``,
},
{
"Seeing the same schema twice non-recursively",
exampleFixture(t, "recursive_seen_twice.yml"),
"Test",
`{"ref_a": {"spud": "potato"}, "ref_b": {"spud": "potato"}}`,
},
{
"Cyclical dependencies",
exampleFixture(t, "recursive_cycles.yml"),
"Front",
``,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
swagger, err := loader.LoadSwaggerFromData([]byte(test.in))
require.NoError(t, err)

ex, err := OpenAPIExample(ModeResponse, swagger.Components.Schemas[test.schema].Value)
if test.out == "" {
assert.Error(t, err)
assert.Nil(t, ex)
} else {
assert.Nil(t, err)
// Expected to match the output.
var expected interface{}
json.Unmarshal([]byte(test.out), &expected)
assert.EqualValues(t, expected, ex)
}
})
}
}
16 changes: 16 additions & 0 deletions testdata/example/recursive_cycles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
components:
schemas:
Front:
type: object
required:
- back
properties:
back:
$ref: '#/components/schemas/Back'
Back:
type: object
required:
- front
properties:
front:
$ref: '#/components/schemas/Front'
9 changes: 9 additions & 0 deletions testdata/example/recursive_infinite.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
components:
schemas:
Test:
type: object
required:
- test
properties:
test:
$ref: '#/components/schemas/Test'
10 changes: 10 additions & 0 deletions testdata/example/recursive_ok.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
components:
schemas:
Test:
type: object
properties:
something:
type: string
example: Hello
test:
$ref: '#/components/schemas/Test'
18 changes: 18 additions & 0 deletions testdata/example/recursive_seen_twice.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
components:
schemas:
Ref:
type: object
properties:
spud:
type: string
example: "potato"
Test:
type: object
required:
- ref_a
- ref_b
properties:
ref_a:
$ref: '#/components/schemas/Ref'
ref_b:
$ref: '#/components/schemas/Ref'

0 comments on commit e4ad0a3

Please sign in to comment.