Skip to content

Commit

Permalink
initial solution for circular refs and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
tcdsv committed Oct 21, 2023
1 parent a146d1e commit 170f0e5
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 14 deletions.
81 changes: 67 additions & 14 deletions flatten/merge_allof.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strings"

"github.com/getkin/kin-openapi/openapi3"
"github.com/tufin/oasdiff/utils"
)

const (
Expand Down Expand Up @@ -54,12 +53,17 @@ type SchemaCollection struct {
}

type state struct {
visitedSchemas utils.VisitedRefs
// map processed circular refs to their result schema
circularRefs map[string]*openapi3.Schema

// map original schemas to their result schema
mergedSchemas map[*openapi3.Schema]*openapi3.Schema
}

func newState() *state {
return &state{
visitedSchemas: utils.VisitedRefs{},
circularRefs: map[string]*openapi3.Schema{},
mergedSchemas: map[*openapi3.Schema]*openapi3.Schema{},
}
}

Expand All @@ -73,22 +77,61 @@ func Merge(schema openapi3.SchemaRef) (*openapi3.Schema, error) {

// Merge replaces objects under AllOf with a flattened equivalent
func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.SchemaRef, error) {

// handle circular refs
if baseSchemaRef.Ref != "" {

// If the reference is circular and has already been processed, return its resulting schema.
result, ok := state.circularRefs[baseSchemaRef.Ref]
if ok {
return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil
}

// If a circular reference is found, return a schema that closes the loop.
result, ok = state.mergedSchemas[baseSchemaRef.Value]
if ok {
state.circularRefs[baseSchemaRef.Ref] = result
return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil
}
}

mergedSchema := openapi3.NewSchema()
result := openapi3.NewSchemaRef(baseSchemaRef.Ref, mergedSchema)

// map original schema to result
state.mergedSchemas[baseSchemaRef.Value] = mergedSchema

baseSchema := baseSchemaRef.Value
allOfSchemas, err := getAllOfSchemas(state, baseSchema.AllOf)

if err != nil {
return nil, err
}

schemaRefs := openapi3.SchemaRefs{&baseSchemaRef}
schemaRefs = append(schemaRefs, allOfSchemas...)
result, err := flattenSchemas(state, schemaRefs)
schemaRefsToFlatten := openapi3.SchemaRefs{&baseSchemaRef}

// in case that AllOf has circular reference, it is not flattened.
if isCircular(state, allOfSchemas) {
mergedSchema.AllOf = allOfSchemas
} else {
schemaRefsToFlatten = append(schemaRefsToFlatten, allOfSchemas...)
}

_, err = flattenSchemas(state, schemaRefsToFlatten, mergedSchema)
if err != nil {
return nil, err
}
return result, nil
}

return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil
func isCircular(state *state, schemaRefs openapi3.SchemaRefs) bool {
for _, s := range schemaRefs {
_, ok := state.circularRefs[s.Ref]
if ok {
return true
}
}
return false
}

func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.SchemaRefs, error) {
Expand All @@ -107,8 +150,7 @@ func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.Sch
return srefs, nil
}

func flattenSchemas(state *state, schemas []*openapi3.SchemaRef) (*openapi3.Schema, error) {
result := openapi3.NewSchema()
func flattenSchemas(state *state, schemas []*openapi3.SchemaRef, result *openapi3.Schema) (*openapi3.Schema, error) {
collection := collect(schemas)

result.Title = collection.Title[0]
Expand Down Expand Up @@ -236,7 +278,7 @@ func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaColle
return schema, nil
}

res, err := flattenSchemas(state, items)
res, err := flattenSchemas(state, items, openapi3.NewSchema())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -339,7 +381,7 @@ func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, colle

var schemaRef *openapi3.SchemaRef
if len(additionalSchemas) > 0 {
result, err := flattenSchemas(state, additionalSchemas)
result, err := flattenSchemas(state, additionalSchemas, openapi3.NewSchema())
if err != nil {
return nil, err
}
Expand All @@ -356,6 +398,9 @@ func resolveNonFalseProps(state *state, schema *openapi3.Schema, collection *Sch
return nil, err
}
propsToMerge := getNonFalsePropsKeys(collection)
if len(propsToMerge) == 0 {
return schema, nil
}
return mergeProps(state, result, collection, propsToMerge)
}

Expand Down Expand Up @@ -425,7 +470,15 @@ func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollect

result := make(openapi3.Schemas)
for prop, schemas := range propsToSchemasMap {
mergedProp, err := flattenSchemas(state, schemas)

if len(schemas) == 1 {
result[prop] = schemas[0]
continue
}

// TODO: If schemas of some property contain a loop, do not flatten them.

mergedProp, err := flattenSchemas(state, schemas, openapi3.NewSchema())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -735,7 +788,7 @@ func getCombinations(groups []openapi3.SchemaRefs) []openapi3.SchemaRefs {
func mergeCombinations(state *state, combinations []openapi3.SchemaRefs) ([]*openapi3.Schema, error) {
merged := []*openapi3.Schema{}
for _, combination := range combinations {
schema, err := flattenSchemas(state, combination)
schema, err := flattenSchemas(state, combination, openapi3.NewSchema())
if err != nil {
continue
}
Expand Down Expand Up @@ -812,7 +865,7 @@ func mergeSchemaRefs(state *state, sr []openapi3.SchemaRefs) ([]openapi3.SchemaR
if err != nil {
return result, err
}
r = append(r, openapi3.NewSchemaRef("", merged.Value))
r = append(r, merged)
}
result = append(result, r)
}
Expand Down
45 changes: 45 additions & 0 deletions flatten/merge_allof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1651,3 +1651,48 @@ func TestMerge_Required(t *testing.T) {
})
}
}

func TestMerge_SimpleCircularAllOf(t *testing.T) {
ctx := context.Background()
sl := openapi3.NewLoader()
doc, err := sl.LoadFromFile("testdata/circular1.yaml")
require.NoError(t, err, "loading test file")
err = doc.Validate(ctx)
require.NoError(t, err)
result, err := flatten.Merge(*doc.Components.Schemas["Circular_1"])
require.NoError(t, err)
require.NotEmpty(t, result.AllOf)
require.Equal(t, "#/components/schemas/Circular_1", result.AllOf[0].Ref)
require.Equal(t, result, result.AllOf[0].Value)
}

func TestMerge_CircularAllOfProps(t *testing.T) {

ctx := context.Background()
sl := openapi3.NewLoader()
doc, err := sl.LoadFromFile("testdata/circular1.yaml")

require.NoError(t, err, "loading test file")
err = doc.Validate(ctx)
require.NoError(t, err)
result, err := flatten.Merge(*doc.Components.Schemas["Circular_2"])
require.NoError(t, err)

require.Len(t, result.AllOf, 2)
require.Equal(t, result, result.AllOf[0].Value)
require.Equal(t, "#/components/schemas/Circular_2", result.AllOf[0].Ref)
require.Equal(t, result, result.AllOf[1].Value.Properties["test"].Value)
require.Equal(t, "#/components/schemas/Circular_2", result.AllOf[1].Value.Properties["test"].Ref)

result, err = flatten.Merge(*doc.Components.Schemas["Circular_3"])
require.NoError(t, err)
require.Equal(t, "#/components/schemas/Circular_3", result.Properties["test"].Value.AllOf[0].Ref)
require.Equal(t, result, result.Properties["test"].Value.AllOf[0].Value)

result, err = flatten.Merge(*doc.Components.Schemas["Circular_4"])
require.NoError(t, err)
require.Equal(t, "#/components/schemas/Circular_4", result.AllOf[0].Ref)
require.Equal(t, result, result.AllOf[0].Value)
require.Equal(t, "#/components/schemas/Circular_4", result.AllOf[1].Value.Properties["test"].Value.Properties["b"].Ref)
require.Equal(t, result, result.AllOf[1].Value.Properties["test"].Value.Properties["b"].Value)
}
62 changes: 62 additions & 0 deletions flatten/testdata/circular1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
openapi: 3.0.0
info:
title: Circular Reference Example
version: 1.0.0
paths:
/endpoint:
get:
summary: Example endpoint
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/Circular_1'

components:
schemas:
Circular_1:
type: object
allOf:
- $ref: '#/components/schemas/Circular_1'
required:
- name
description: simple circular allof

Circular_2:
type: object
allOf:
- $ref: '#/components/schemas/Circular_2'
- type: object
properties:
test:
$ref: '#/components/schemas/Circular_2'

Circular_3:
type: object
properties:
test:
type: object
allOf:
- $ref: '#/components/schemas/Circular_3'

Circular_4:
type: object
allOf:
- $ref: '#/components/schemas/Circular_4'
- type: object
properties:
test:
type: object
allOf:
- type: object
properties:
a:
type: object
description: a
- type: object
properties:
b:
$ref: '#/components/schemas/Circular_4'
description: b

0 comments on commit 170f0e5

Please sign in to comment.