From a146d1eebbb6ba3c836af62451c4e766a2ff66c8 Mon Sep 17 00:00:00 2001 From: tcdsv Date: Wed, 11 Oct 2023 18:46:41 +0300 Subject: [PATCH] mergeInternal - return SchemaRef instead of Schema --- flatten/merge_allof.go | 191 ++++++++++++++++++++--------------------- 1 file changed, 95 insertions(+), 96 deletions(-) diff --git a/flatten/merge_allof.go b/flatten/merge_allof.go index a90459ff..968afafe 100644 --- a/flatten/merge_allof.go +++ b/flatten/merge_allof.go @@ -64,42 +64,50 @@ func newState() *state { } func Merge(schema openapi3.SchemaRef) (*openapi3.Schema, error) { - return mergeInternal(newState(), schema) + result, err := mergeInternal(newState(), schema) + if err != nil { + return nil, err + } + return result.Value, nil } // Merge replaces objects under AllOf with a flattened equivalent -func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.Schema, error) { +func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.SchemaRef, error) { baseSchema := baseSchemaRef.Value - allOfSchemas, err := getAllOfSchemas(state, *baseSchema) + allOfSchemas, err := getAllOfSchemas(state, baseSchema.AllOf) if err != nil { - return &openapi3.Schema{}, err + return nil, err } - schemas := []*openapi3.Schema{baseSchema} - schemas = append(schemas, allOfSchemas...) - result, err := flattenSchemas(state, schemas) + + schemaRefs := openapi3.SchemaRefs{&baseSchemaRef} + schemaRefs = append(schemaRefs, allOfSchemas...) + result, err := flattenSchemas(state, schemaRefs) + if err != nil { - return &openapi3.Schema{}, err + return nil, err } - return result, nil + + return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil } -func getAllOfSchemas(state *state, schema openapi3.Schema) ([]*openapi3.Schema, error) { - schemas := []*openapi3.Schema{} - if schema.AllOf == nil { - return schemas, nil +func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.SchemaRefs, error) { + + srefs := openapi3.SchemaRefs{} + if schemaRefs == nil { + return srefs, nil } - for _, schema := range schema.AllOf { - merged, err := mergeInternal(state, *schema) + for _, sref := range schemaRefs { + merged, err := mergeInternal(state, *sref) if err != nil { - return schemas, err + return nil, err } - schemas = append(schemas, merged) + srefs = append(srefs, merged) } - return schemas, nil + return srefs, nil } -func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema, error) { +func flattenSchemas(state *state, schemas []*openapi3.SchemaRef) (*openapi3.Schema, error) { result := openapi3.NewSchema() collection := collect(schemas) @@ -107,11 +115,11 @@ func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema, result.Description = collection.Description[0] result, err := resolveFormat(result, &collection) if err != nil { - return result, err + return nil, err } result, err = resolveType(result, &collection) if err != nil { - return result, err + return nil, err } result = resolveNumberRange(result, &collection) result.MinLength = findMaxValue(collection.MinLength) @@ -127,7 +135,7 @@ func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema, enums, err := resolveEnum(collection.Enum) if err != nil { - return result, err + return nil, err } result.Enum = enums result = resolveMultipleOf(result, &collection) @@ -139,22 +147,22 @@ func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema, result.UniqueItems = resolveUniqueItems(collection.UniqueItems) result, err = resolveProperties(state, result, &collection) if err != nil { - return result, err + return nil, err } result, err = resolveOneOf(state, result, &collection) if err != nil { - return result, err + return nil, err } result, err = resolveAnyOf(state, result, &collection) if err != nil { - return result, err + return nil, err } result, err = resolveNot(state, result, &collection) if err != nil { - return result, err + return nil, err } return result, nil @@ -216,10 +224,11 @@ func resolveNumberRange(schema *openapi3.Schema, collection *SchemaCollection) * } func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) { - items := []*openapi3.Schema{} - for _, s := range collection.Items { - if s != nil { - items = append(items, s.Value) + + items := openapi3.SchemaRefs{} + for _, sref := range collection.Items { + if sref != nil { + items = append(items, sref) } } if len(items) == 0 { @@ -229,12 +238,9 @@ func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaColle res, err := flattenSchemas(state, items) if err != nil { - return schema, err - } - ref := openapi3.SchemaRef{ - Value: res, + return nil, err } - schema.Items = &ref + schema.Items = openapi3.NewSchemaRef("", res) return schema, nil } @@ -324,10 +330,10 @@ func resolveFalseAdditionalProps(schema *openapi3.Schema, collection *SchemaColl // if there are additionalProperties which are Schemas, they are merged to a single Schema. func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) { - additionalSchemas := []*openapi3.Schema{} + additionalSchemas := openapi3.SchemaRefs{} for _, ap := range collection.AdditionalProperties { - if ap.Schema != nil && ap.Schema.Value != nil { - additionalSchemas = append(additionalSchemas, ap.Schema.Value) + if ap.Schema != nil { + additionalSchemas = append(additionalSchemas, ap.Schema) } } @@ -335,11 +341,9 @@ func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, colle if len(additionalSchemas) > 0 { result, err := flattenSchemas(state, additionalSchemas) if err != nil { - return schema, err - } - schemaRef = &openapi3.SchemaRef{ - Value: result, + return nil, err } + schemaRef = openapi3.NewSchemaRef("", result) } schema.AdditionalProperties.Has = nil schema.AdditionalProperties.Schema = schemaRef @@ -349,7 +353,7 @@ func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, colle func resolveNonFalseProps(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) { result, err := resolveNonFalseAdditionalProps(state, schema, collection) if err != nil { - return &openapi3.Schema{}, err + return nil, err } propsToMerge := getNonFalsePropsKeys(collection) return mergeProps(state, result, collection, propsToMerge) @@ -406,13 +410,13 @@ func resolveProperties(state *state, schema *openapi3.Schema, collection *Schema } func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollection, propsToMerge []string) (*openapi3.Schema, error) { - propsToSchemasMap := map[string][]*openapi3.Schema{} + propsToSchemasMap := map[string]openapi3.SchemaRefs{} for _, schema := range collection.Properties { for propKey, schemaRef := range schema { if containsString(propsToMerge, propKey) { propMergedSchema, err := mergeInternal(state, *schemaRef) if err != nil { - return &openapi3.Schema{}, err + return nil, err } propsToSchemasMap[propKey] = append(propsToSchemasMap[propKey], propMergedSchema) } @@ -423,12 +427,9 @@ func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollect for prop, schemas := range propsToSchemasMap { mergedProp, err := flattenSchemas(state, schemas) if err != nil { - return &openapi3.Schema{}, err - } - ref := openapi3.SchemaRef{ - Value: mergedProp, + return nil, err } - result[prop] = &ref + result[prop] = openapi3.NewSchemaRef("", mergedProp) } if len(result) == 0 { @@ -452,7 +453,7 @@ func resolveEnum(values [][]interface{}) ([]interface{}, error) { } intersection = findIntersectionOfArrays(nonEmptyEnum) if len(intersection) == 0 { - return intersection, errors.New("unable to resolve Enum conflict: intersection of values must be non-empty") + return nil, errors.New("unable to resolve Enum conflict: intersection of values must be non-empty") } return intersection, nil } @@ -673,37 +674,40 @@ func resolveRequired(values [][]string) []string { return uniqueValues } -func collect(schemas []*openapi3.Schema) SchemaCollection { +func collect(schemas []*openapi3.SchemaRef) SchemaCollection { collection := SchemaCollection{} for _, s := range schemas { - collection.Not = append(collection.Not, s.Not) - collection.AnyOf = append(collection.AnyOf, s.AnyOf) - collection.OneOf = append(collection.OneOf, s.OneOf) - collection.Title = append(collection.Title, s.Title) - collection.Type = append(collection.Type, s.Type) - collection.Format = append(collection.Format, s.Format) - collection.Description = append(collection.Description, s.Description) - collection.Enum = append(collection.Enum, s.Enum) - collection.UniqueItems = append(collection.UniqueItems, s.UniqueItems) - collection.ExclusiveMin = append(collection.ExclusiveMin, s.ExclusiveMin) - collection.ExclusiveMax = append(collection.ExclusiveMax, s.ExclusiveMax) - collection.Min = append(collection.Min, s.Min) - collection.Max = append(collection.Max, s.Max) - collection.MultipleOf = append(collection.MultipleOf, s.MultipleOf) - collection.MinLength = append(collection.MinLength, s.MinLength) - collection.MaxLength = append(collection.MaxLength, s.MaxLength) - collection.Pattern = append(collection.Pattern, s.Pattern) - collection.MinItems = append(collection.MinItems, s.MinItems) - collection.MaxItems = append(collection.MaxItems, s.MaxItems) - collection.Items = append(collection.Items, s.Items) - collection.Required = append(collection.Required, s.Required) - collection.Properties = append(collection.Properties, s.Properties) - collection.MinProps = append(collection.MinProps, s.MinProps) - collection.MaxProps = append(collection.MaxProps, s.MaxProps) - collection.AdditionalProperties = append(collection.AdditionalProperties, s.AdditionalProperties) - collection.Nullable = append(collection.Nullable, s.Nullable) - collection.ReadOnly = append(collection.ReadOnly, s.ReadOnly) - collection.WriteOnly = append(collection.WriteOnly, s.WriteOnly) + if s == nil { + continue + } + collection.Not = append(collection.Not, s.Value.Not) + collection.AnyOf = append(collection.AnyOf, s.Value.AnyOf) + collection.OneOf = append(collection.OneOf, s.Value.OneOf) + collection.Title = append(collection.Title, s.Value.Title) + collection.Type = append(collection.Type, s.Value.Type) + collection.Format = append(collection.Format, s.Value.Format) + collection.Description = append(collection.Description, s.Value.Description) + collection.Enum = append(collection.Enum, s.Value.Enum) + collection.UniqueItems = append(collection.UniqueItems, s.Value.UniqueItems) + collection.ExclusiveMin = append(collection.ExclusiveMin, s.Value.ExclusiveMin) + collection.ExclusiveMax = append(collection.ExclusiveMax, s.Value.ExclusiveMax) + collection.Min = append(collection.Min, s.Value.Min) + collection.Max = append(collection.Max, s.Value.Max) + collection.MultipleOf = append(collection.MultipleOf, s.Value.MultipleOf) + collection.MinLength = append(collection.MinLength, s.Value.MinLength) + collection.MaxLength = append(collection.MaxLength, s.Value.MaxLength) + collection.Pattern = append(collection.Pattern, s.Value.Pattern) + collection.MinItems = append(collection.MinItems, s.Value.MinItems) + collection.MaxItems = append(collection.MaxItems, s.Value.MaxItems) + collection.Items = append(collection.Items, s.Value.Items) + collection.Required = append(collection.Required, s.Value.Required) + collection.Properties = append(collection.Properties, s.Value.Properties) + collection.MinProps = append(collection.MinProps, s.Value.MinProps) + collection.MaxProps = append(collection.MaxProps, s.Value.MaxProps) + collection.AdditionalProperties = append(collection.AdditionalProperties, s.Value.AdditionalProperties) + collection.Nullable = append(collection.Nullable, s.Value.Nullable) + collection.ReadOnly = append(collection.ReadOnly, s.Value.ReadOnly) + collection.WriteOnly = append(collection.WriteOnly, s.Value.WriteOnly) } return collection } @@ -731,18 +735,14 @@ func getCombinations(groups []openapi3.SchemaRefs) []openapi3.SchemaRefs { func mergeCombinations(state *state, combinations []openapi3.SchemaRefs) ([]*openapi3.Schema, error) { merged := []*openapi3.Schema{} for _, combination := range combinations { - schemas := []*openapi3.Schema{} - for _, ref := range combination { - schemas = append(schemas, ref.Value) - } - schema, err := flattenSchemas(state, schemas) + schema, err := flattenSchemas(state, combination) if err != nil { continue } merged = append(merged, schema) } if len(merged) == 0 { - return merged, errors.New("unable to resolve combined schema") + return nil, errors.New("unable to resolve combined schema") } return merged, nil } @@ -752,23 +752,22 @@ func resolveNot(state *state, schema *openapi3.Schema, collection *SchemaCollect if len(refs) == 0 { return schema, nil } + result := openapi3.SchemaRefs{} for _, ref := range refs { merged, err := mergeInternal(state, *ref) if err != nil { - return &openapi3.Schema{}, err + return nil, err } - ref.Value = merged + result = append(result, merged) } - if len(refs) == 1 { - schema.Not = refs[0] + if len(result) == 1 { + schema.Not = result[0] return schema, nil } - schema.Not = &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - AnyOf: refs, - }, - } + schema.Not = openapi3.NewSchemaRef("", &openapi3.Schema{ + AnyOf: result, + }) return schema, nil } @@ -813,7 +812,7 @@ func mergeSchemaRefs(state *state, sr []openapi3.SchemaRefs) ([]openapi3.SchemaR if err != nil { return result, err } - r = append(r, &openapi3.SchemaRef{Value: merged}) + r = append(r, openapi3.NewSchemaRef("", merged.Value)) } result = append(result, r) }