Skip to content

Commit

Permalink
mergeInternal - return SchemaRef instead of Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
tcdsv committed Oct 11, 2023
1 parent 45dafb4 commit a146d1e
Showing 1 changed file with 95 additions and 96 deletions.
191 changes: 95 additions & 96 deletions flatten/merge_allof.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,54 +64,62 @@ func newState() *state {
}

func Merge(schema openapi3.SchemaRef) (*openapi3.Schema, error) {
return mergeInternal(newState(), schema)
result, err := mergeInternal(newState(), schema)
if err != nil {
return nil, err
}
return result.Value, nil
}

// Merge replaces objects under AllOf with a flattened equivalent
func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.Schema, error) {
func mergeInternal(state *state, baseSchemaRef openapi3.SchemaRef) (*openapi3.SchemaRef, error) {
baseSchema := baseSchemaRef.Value
allOfSchemas, err := getAllOfSchemas(state, *baseSchema)
allOfSchemas, err := getAllOfSchemas(state, baseSchema.AllOf)

if err != nil {
return &openapi3.Schema{}, err
return nil, err
}
schemas := []*openapi3.Schema{baseSchema}
schemas = append(schemas, allOfSchemas...)
result, err := flattenSchemas(state, schemas)

schemaRefs := openapi3.SchemaRefs{&baseSchemaRef}
schemaRefs = append(schemaRefs, allOfSchemas...)
result, err := flattenSchemas(state, schemaRefs)

if err != nil {
return &openapi3.Schema{}, err
return nil, err
}
return result, nil

return openapi3.NewSchemaRef(baseSchemaRef.Ref, result), nil
}

func getAllOfSchemas(state *state, schema openapi3.Schema) ([]*openapi3.Schema, error) {
schemas := []*openapi3.Schema{}
if schema.AllOf == nil {
return schemas, nil
func getAllOfSchemas(state *state, schemaRefs openapi3.SchemaRefs) (openapi3.SchemaRefs, error) {

srefs := openapi3.SchemaRefs{}
if schemaRefs == nil {
return srefs, nil
}
for _, schema := range schema.AllOf {
merged, err := mergeInternal(state, *schema)
for _, sref := range schemaRefs {
merged, err := mergeInternal(state, *sref)
if err != nil {
return schemas, err
return nil, err
}
schemas = append(schemas, merged)
srefs = append(srefs, merged)
}
return schemas, nil
return srefs, nil
}

func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema, error) {
func flattenSchemas(state *state, schemas []*openapi3.SchemaRef) (*openapi3.Schema, error) {
result := openapi3.NewSchema()
collection := collect(schemas)

result.Title = collection.Title[0]
result.Description = collection.Description[0]
result, err := resolveFormat(result, &collection)
if err != nil {
return result, err
return nil, err
}
result, err = resolveType(result, &collection)
if err != nil {
return result, err
return nil, err
}
result = resolveNumberRange(result, &collection)
result.MinLength = findMaxValue(collection.MinLength)
Expand All @@ -127,7 +135,7 @@ func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema,

enums, err := resolveEnum(collection.Enum)
if err != nil {
return result, err
return nil, err
}
result.Enum = enums
result = resolveMultipleOf(result, &collection)
Expand All @@ -139,22 +147,22 @@ func flattenSchemas(state *state, schemas []*openapi3.Schema) (*openapi3.Schema,
result.UniqueItems = resolveUniqueItems(collection.UniqueItems)
result, err = resolveProperties(state, result, &collection)
if err != nil {
return result, err
return nil, err
}

result, err = resolveOneOf(state, result, &collection)
if err != nil {
return result, err
return nil, err
}

result, err = resolveAnyOf(state, result, &collection)
if err != nil {
return result, err
return nil, err
}

result, err = resolveNot(state, result, &collection)
if err != nil {
return result, err
return nil, err
}

return result, nil
Expand Down Expand Up @@ -216,10 +224,11 @@ func resolveNumberRange(schema *openapi3.Schema, collection *SchemaCollection) *
}

func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) {
items := []*openapi3.Schema{}
for _, s := range collection.Items {
if s != nil {
items = append(items, s.Value)

items := openapi3.SchemaRefs{}
for _, sref := range collection.Items {
if sref != nil {
items = append(items, sref)
}
}
if len(items) == 0 {
Expand All @@ -229,12 +238,9 @@ func resolveItems(state *state, schema *openapi3.Schema, collection *SchemaColle

res, err := flattenSchemas(state, items)
if err != nil {
return schema, err
}
ref := openapi3.SchemaRef{
Value: res,
return nil, err
}
schema.Items = &ref
schema.Items = openapi3.NewSchemaRef("", res)
return schema, nil
}

Expand Down Expand Up @@ -324,22 +330,20 @@ func resolveFalseAdditionalProps(schema *openapi3.Schema, collection *SchemaColl

// if there are additionalProperties which are Schemas, they are merged to a single Schema.
func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) {
additionalSchemas := []*openapi3.Schema{}
additionalSchemas := openapi3.SchemaRefs{}
for _, ap := range collection.AdditionalProperties {
if ap.Schema != nil && ap.Schema.Value != nil {
additionalSchemas = append(additionalSchemas, ap.Schema.Value)
if ap.Schema != nil {
additionalSchemas = append(additionalSchemas, ap.Schema)
}
}

var schemaRef *openapi3.SchemaRef
if len(additionalSchemas) > 0 {
result, err := flattenSchemas(state, additionalSchemas)
if err != nil {
return schema, err
}
schemaRef = &openapi3.SchemaRef{
Value: result,
return nil, err
}
schemaRef = openapi3.NewSchemaRef("", result)
}
schema.AdditionalProperties.Has = nil
schema.AdditionalProperties.Schema = schemaRef
Expand All @@ -349,7 +353,7 @@ func resolveNonFalseAdditionalProps(state *state, schema *openapi3.Schema, colle
func resolveNonFalseProps(state *state, schema *openapi3.Schema, collection *SchemaCollection) (*openapi3.Schema, error) {
result, err := resolveNonFalseAdditionalProps(state, schema, collection)
if err != nil {
return &openapi3.Schema{}, err
return nil, err
}
propsToMerge := getNonFalsePropsKeys(collection)
return mergeProps(state, result, collection, propsToMerge)
Expand Down Expand Up @@ -406,13 +410,13 @@ func resolveProperties(state *state, schema *openapi3.Schema, collection *Schema
}

func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollection, propsToMerge []string) (*openapi3.Schema, error) {
propsToSchemasMap := map[string][]*openapi3.Schema{}
propsToSchemasMap := map[string]openapi3.SchemaRefs{}
for _, schema := range collection.Properties {
for propKey, schemaRef := range schema {
if containsString(propsToMerge, propKey) {
propMergedSchema, err := mergeInternal(state, *schemaRef)
if err != nil {
return &openapi3.Schema{}, err
return nil, err
}
propsToSchemasMap[propKey] = append(propsToSchemasMap[propKey], propMergedSchema)
}
Expand All @@ -423,12 +427,9 @@ func mergeProps(state *state, schema *openapi3.Schema, collection *SchemaCollect
for prop, schemas := range propsToSchemasMap {
mergedProp, err := flattenSchemas(state, schemas)
if err != nil {
return &openapi3.Schema{}, err
}
ref := openapi3.SchemaRef{
Value: mergedProp,
return nil, err
}
result[prop] = &ref
result[prop] = openapi3.NewSchemaRef("", mergedProp)
}

if len(result) == 0 {
Expand All @@ -452,7 +453,7 @@ func resolveEnum(values [][]interface{}) ([]interface{}, error) {
}
intersection = findIntersectionOfArrays(nonEmptyEnum)
if len(intersection) == 0 {
return intersection, errors.New("unable to resolve Enum conflict: intersection of values must be non-empty")
return nil, errors.New("unable to resolve Enum conflict: intersection of values must be non-empty")
}
return intersection, nil
}
Expand Down Expand Up @@ -673,37 +674,40 @@ func resolveRequired(values [][]string) []string {
return uniqueValues
}

func collect(schemas []*openapi3.Schema) SchemaCollection {
func collect(schemas []*openapi3.SchemaRef) SchemaCollection {
collection := SchemaCollection{}
for _, s := range schemas {
collection.Not = append(collection.Not, s.Not)
collection.AnyOf = append(collection.AnyOf, s.AnyOf)
collection.OneOf = append(collection.OneOf, s.OneOf)
collection.Title = append(collection.Title, s.Title)
collection.Type = append(collection.Type, s.Type)
collection.Format = append(collection.Format, s.Format)
collection.Description = append(collection.Description, s.Description)
collection.Enum = append(collection.Enum, s.Enum)
collection.UniqueItems = append(collection.UniqueItems, s.UniqueItems)
collection.ExclusiveMin = append(collection.ExclusiveMin, s.ExclusiveMin)
collection.ExclusiveMax = append(collection.ExclusiveMax, s.ExclusiveMax)
collection.Min = append(collection.Min, s.Min)
collection.Max = append(collection.Max, s.Max)
collection.MultipleOf = append(collection.MultipleOf, s.MultipleOf)
collection.MinLength = append(collection.MinLength, s.MinLength)
collection.MaxLength = append(collection.MaxLength, s.MaxLength)
collection.Pattern = append(collection.Pattern, s.Pattern)
collection.MinItems = append(collection.MinItems, s.MinItems)
collection.MaxItems = append(collection.MaxItems, s.MaxItems)
collection.Items = append(collection.Items, s.Items)
collection.Required = append(collection.Required, s.Required)
collection.Properties = append(collection.Properties, s.Properties)
collection.MinProps = append(collection.MinProps, s.MinProps)
collection.MaxProps = append(collection.MaxProps, s.MaxProps)
collection.AdditionalProperties = append(collection.AdditionalProperties, s.AdditionalProperties)
collection.Nullable = append(collection.Nullable, s.Nullable)
collection.ReadOnly = append(collection.ReadOnly, s.ReadOnly)
collection.WriteOnly = append(collection.WriteOnly, s.WriteOnly)
if s == nil {
continue
}
collection.Not = append(collection.Not, s.Value.Not)
collection.AnyOf = append(collection.AnyOf, s.Value.AnyOf)
collection.OneOf = append(collection.OneOf, s.Value.OneOf)
collection.Title = append(collection.Title, s.Value.Title)
collection.Type = append(collection.Type, s.Value.Type)
collection.Format = append(collection.Format, s.Value.Format)
collection.Description = append(collection.Description, s.Value.Description)
collection.Enum = append(collection.Enum, s.Value.Enum)
collection.UniqueItems = append(collection.UniqueItems, s.Value.UniqueItems)
collection.ExclusiveMin = append(collection.ExclusiveMin, s.Value.ExclusiveMin)
collection.ExclusiveMax = append(collection.ExclusiveMax, s.Value.ExclusiveMax)
collection.Min = append(collection.Min, s.Value.Min)
collection.Max = append(collection.Max, s.Value.Max)
collection.MultipleOf = append(collection.MultipleOf, s.Value.MultipleOf)
collection.MinLength = append(collection.MinLength, s.Value.MinLength)
collection.MaxLength = append(collection.MaxLength, s.Value.MaxLength)
collection.Pattern = append(collection.Pattern, s.Value.Pattern)
collection.MinItems = append(collection.MinItems, s.Value.MinItems)
collection.MaxItems = append(collection.MaxItems, s.Value.MaxItems)
collection.Items = append(collection.Items, s.Value.Items)
collection.Required = append(collection.Required, s.Value.Required)
collection.Properties = append(collection.Properties, s.Value.Properties)
collection.MinProps = append(collection.MinProps, s.Value.MinProps)
collection.MaxProps = append(collection.MaxProps, s.Value.MaxProps)
collection.AdditionalProperties = append(collection.AdditionalProperties, s.Value.AdditionalProperties)
collection.Nullable = append(collection.Nullable, s.Value.Nullable)
collection.ReadOnly = append(collection.ReadOnly, s.Value.ReadOnly)
collection.WriteOnly = append(collection.WriteOnly, s.Value.WriteOnly)
}
return collection
}
Expand Down Expand Up @@ -731,18 +735,14 @@ func getCombinations(groups []openapi3.SchemaRefs) []openapi3.SchemaRefs {
func mergeCombinations(state *state, combinations []openapi3.SchemaRefs) ([]*openapi3.Schema, error) {
merged := []*openapi3.Schema{}
for _, combination := range combinations {
schemas := []*openapi3.Schema{}
for _, ref := range combination {
schemas = append(schemas, ref.Value)
}
schema, err := flattenSchemas(state, schemas)
schema, err := flattenSchemas(state, combination)
if err != nil {
continue
}
merged = append(merged, schema)
}
if len(merged) == 0 {
return merged, errors.New("unable to resolve combined schema")
return nil, errors.New("unable to resolve combined schema")
}
return merged, nil
}
Expand All @@ -752,23 +752,22 @@ func resolveNot(state *state, schema *openapi3.Schema, collection *SchemaCollect
if len(refs) == 0 {
return schema, nil
}
result := openapi3.SchemaRefs{}
for _, ref := range refs {
merged, err := mergeInternal(state, *ref)
if err != nil {
return &openapi3.Schema{}, err
return nil, err
}
ref.Value = merged
result = append(result, merged)
}

if len(refs) == 1 {
schema.Not = refs[0]
if len(result) == 1 {
schema.Not = result[0]
return schema, nil
}
schema.Not = &openapi3.SchemaRef{
Value: &openapi3.Schema{
AnyOf: refs,
},
}
schema.Not = openapi3.NewSchemaRef("", &openapi3.Schema{
AnyOf: result,
})
return schema, nil
}

Expand Down Expand Up @@ -813,7 +812,7 @@ func mergeSchemaRefs(state *state, sr []openapi3.SchemaRefs) ([]openapi3.SchemaR
if err != nil {
return result, err
}
r = append(r, &openapi3.SchemaRef{Value: merged})
r = append(r, openapi3.NewSchemaRef("", merged.Value))
}
result = append(result, r)
}
Expand Down

0 comments on commit a146d1e

Please sign in to comment.