From 170f0e5a4e4c3056b0b2f2b21da85e4b90dc0e1b Mon Sep 17 00:00:00 2001 From: tcdsv Date: Sat, 21 Oct 2023 18:48:29 +0300 Subject: [PATCH] initial solution for circular refs and testing --- flatten/merge_allof.go | 81 +++++++++++++++++++++++++++------ flatten/merge_allof_test.go | 45 ++++++++++++++++++ flatten/testdata/circular1.yaml | 62 +++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 14 deletions(-) create mode 100644 flatten/testdata/circular1.yaml diff --git a/flatten/merge_allof.go b/flatten/merge_allof.go index 968afafe..41d5d94f 100644 --- a/flatten/merge_allof.go +++ b/flatten/merge_allof.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/getkin/kin-openapi/openapi3" - "github.com/tufin/oasdiff/utils" ) const ( @@ -54,12 +53,17 @@ type SchemaCollection struct { } type state struct { - visitedSchemas utils.VisitedRefs + // map processed circular refs to their result schema + circularRefs map[string]*openapi3.Schema + + // map original schemas to their result schema + mergedSchemas map[*openapi3.Schema]*openapi3.Schema } func newState() *state { return &state{ - visitedSchemas: utils.VisitedRefs{}, + circularRefs: map[string]*openapi3.Schema{}, + mergedSchemas: map[*openapi3.Schema]*openapi3.Schema{}, } } @@ -73,6 +77,30 @@ func Merge(schema openapi3.SchemaRef) (*openapi3.Schema, error) { // Merge replaces objects under AllOf with a flattened equivalent func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.SchemaRef, error) { + + // handle circular refs + if baseSchemaRef.Ref != "" { + + // If the reference is circular and has already been processed, return its resulting schema. + result, ok := state.circularRefs[baseSchemaRef.Ref] + if ok { + return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil + } + + // If a circular reference is found, return a schema that closes the loop. + result, ok = state.mergedSchemas[baseSchemaRef.Value] + if ok { + state.circularRefs[baseSchemaRef.Ref] = result + return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil + } + } + + mergedSchema := openapi3.NewSchema() + result := openapi3.NewSchemaRef(baseSchemaRef.Ref, mergedSchema) + + // map original schema to result + state.mergedSchemas[baseSchemaRef.Value] = mergedSchema + baseSchema := baseSchemaRef.Value allOfSchemas, err := getAllOfSchemas(state, baseSchema.AllOf) @@ -80,15 +108,30 @@ func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.Sc return nil, err } - schemaRefs := openapi3.SchemaRefs{&baseSchemaRef} - schemaRefs = append(schemaRefs, allOfSchemas...) - result, err := flattenSchemas(state, schemaRefs) + schemaRefsToFlatten := openapi3.SchemaRefs{&baseSchemaRef} + + // in case that AllOf has circular reference, it is not flattened. + if isCircular(state, allOfSchemas) { + mergedSchema.AllOf = allOfSchemas + } else { + schemaRefsToFlatten = append(schemaRefsToFlatten, allOfSchemas...) + } + _, err = flattenSchemas(state, schemaRefsToFlatten, mergedSchema) if err != nil { return nil, err } + return result, nil +} - return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil +func isCircular(state *state, schemaRefs openapi3.SchemaRefs) bool { + for _, s := range schemaRefs { + _, ok := state.circularRefs[s.Ref] + if ok { + return true + } + } + return false } func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.SchemaRefs, error) { @@ -107,8 +150,7 @@ func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.Sch return srefs, nil } -func flattenSchemas(state *state, schemas []*openapi3.SchemaRef) (*openapi3.Schema, error) { - result := openapi3.NewSchema() +func flattenSchemas(state *state, schemas []*openapi3.SchemaRef, result *openapi3.Schema) (*openapi3.Schema, error) { collection := collect(schemas) result.Title = collection.Title[0] @@ -236,7 +278,7 @@ func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaColle return schema, nil } - res, err := flattenSchemas(state, items) + res, err := flattenSchemas(state, items, openapi3.NewSchema()) if err != nil { return nil, err } @@ -339,7 +381,7 @@ func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, colle var schemaRef *openapi3.SchemaRef if len(additionalSchemas) > 0 { - result, err := flattenSchemas(state, additionalSchemas) + result, err := flattenSchemas(state, additionalSchemas, openapi3.NewSchema()) if err != nil { return nil, err } @@ -356,6 +398,9 @@ func resolveNonFalseProps(state *state, schema *openapi3.Schema, collection *Sch return nil, err } propsToMerge := getNonFalsePropsKeys(collection) + if len(propsToMerge) == 0 { + return schema, nil + } return mergeProps(state, result, collection, propsToMerge) } @@ -425,7 +470,15 @@ func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollect result := make(openapi3.Schemas) for prop, schemas := range propsToSchemasMap { - mergedProp, err := flattenSchemas(state, schemas) + + if len(schemas) == 1 { + result[prop] = schemas[0] + continue + } + + // TODO: If schemas of some property contain a loop, do not flatten them. + + mergedProp, err := flattenSchemas(state, schemas, openapi3.NewSchema()) if err != nil { return nil, err } @@ -735,7 +788,7 @@ func getCombinations(groups []openapi3.SchemaRefs) []openapi3.SchemaRefs { func mergeCombinations(state *state, combinations []openapi3.SchemaRefs) ([]*openapi3.Schema, error) { merged := []*openapi3.Schema{} for _, combination := range combinations { - schema, err := flattenSchemas(state, combination) + schema, err := flattenSchemas(state, combination, openapi3.NewSchema()) if err != nil { continue } @@ -812,7 +865,7 @@ func mergeSchemaRefs(state *state, sr []openapi3.SchemaRefs) ([]openapi3.SchemaR if err != nil { return result, err } - r = append(r, openapi3.NewSchemaRef("", merged.Value)) + r = append(r, merged) } result = append(result, r) } diff --git a/flatten/merge_allof_test.go b/flatten/merge_allof_test.go index 6e1f626b..0f088887 100644 --- a/flatten/merge_allof_test.go +++ b/flatten/merge_allof_test.go @@ -1651,3 +1651,48 @@ func TestMerge_Required(t *testing.T) { }) } } + +func TestMerge_SimpleCircularAllOf(t *testing.T) { + ctx := context.Background() + sl := openapi3.NewLoader() + doc, err := sl.LoadFromFile("testdata/circular1.yaml") + require.NoError(t, err, "loading test file") + err = doc.Validate(ctx) + require.NoError(t, err) + result, err := flatten.Merge(*doc.Components.Schemas["Circular_1"]) + require.NoError(t, err) + require.NotEmpty(t, result.AllOf) + require.Equal(t, "#/components/schemas/Circular_1", result.AllOf[0].Ref) + require.Equal(t, result, result.AllOf[0].Value) +} + +func TestMerge_CircularAllOfProps(t *testing.T) { + + ctx := context.Background() + sl := openapi3.NewLoader() + doc, err := sl.LoadFromFile("testdata/circular1.yaml") + + require.NoError(t, err, "loading test file") + err = doc.Validate(ctx) + require.NoError(t, err) + result, err := flatten.Merge(*doc.Components.Schemas["Circular_2"]) + require.NoError(t, err) + + require.Len(t, result.AllOf, 2) + require.Equal(t, result, result.AllOf[0].Value) + require.Equal(t, "#/components/schemas/Circular_2", result.AllOf[0].Ref) + require.Equal(t, result, result.AllOf[1].Value.Properties["test"].Value) + require.Equal(t, "#/components/schemas/Circular_2", result.AllOf[1].Value.Properties["test"].Ref) + + result, err = flatten.Merge(*doc.Components.Schemas["Circular_3"]) + require.NoError(t, err) + require.Equal(t, "#/components/schemas/Circular_3", result.Properties["test"].Value.AllOf[0].Ref) + require.Equal(t, result, result.Properties["test"].Value.AllOf[0].Value) + + result, err = flatten.Merge(*doc.Components.Schemas["Circular_4"]) + require.NoError(t, err) + require.Equal(t, "#/components/schemas/Circular_4", result.AllOf[0].Ref) + require.Equal(t, result, result.AllOf[0].Value) + require.Equal(t, "#/components/schemas/Circular_4", result.AllOf[1].Value.Properties["test"].Value.Properties["b"].Ref) + require.Equal(t, result, result.AllOf[1].Value.Properties["test"].Value.Properties["b"].Value) +} diff --git a/flatten/testdata/circular1.yaml b/flatten/testdata/circular1.yaml new file mode 100644 index 00000000..071ac8a3 --- /dev/null +++ b/flatten/testdata/circular1.yaml @@ -0,0 +1,62 @@ +openapi: 3.0.0 +info: + title: Circular Reference Example + version: 1.0.0 +paths: + /endpoint: + get: + summary: Example endpoint + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/Circular_1' + +components: + schemas: + Circular_1: + type: object + allOf: + - $ref: '#/components/schemas/Circular_1' + required: + - name + description: simple circular allof + + Circular_2: + type: object + allOf: + - $ref: '#/components/schemas/Circular_2' + - type: object + properties: + test: + $ref: '#/components/schemas/Circular_2' + + Circular_3: + type: object + properties: + test: + type: object + allOf: + - $ref: '#/components/schemas/Circular_3' + + Circular_4: + type: object + allOf: + - $ref: '#/components/schemas/Circular_4' + - type: object + properties: + test: + type: object + allOf: + - type: object + properties: + a: + type: object + description: a + - type: object + properties: + b: + $ref: '#/components/schemas/Circular_4' + description: b