diff --git a/core/parser.go b/core/parser.go index ee2d2cfbf1..300f4411a4 100644 --- a/core/parser.go +++ b/core/parser.go @@ -51,8 +51,8 @@ type Parser interface { NewFilterFromString(collectionType string, body string) (immutable.Option[request.Filter], error) // ParseSDL parses an SDL string into a set of collection descriptions. - ParseSDL(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) + ParseSDL(ctx context.Context, schemaString string) ([]client.CollectionDefinition, error) // Adds the given schema to this parser's model. - SetSchema(ctx context.Context, txn datastore.Txn, collections []client.CollectionDescription) error + SetSchema(ctx context.Context, txn datastore.Txn, collections []client.CollectionDefinition) error } diff --git a/db/collection.go b/db/collection.go index 3c789ac224..330bfb8d88 100644 --- a/db/collection.go +++ b/db/collection.go @@ -584,12 +584,17 @@ func (db *db) setDefaultSchemaVersion( return err } - cols, err := db.getCollectionDescriptions(ctx, txn) + cols, err := db.getAllCollections(ctx, txn) if err != nil { return err } - return db.parser.SetSchema(ctx, txn, cols) + definitions := make([]client.CollectionDefinition, len(cols)) + for i, col := range cols { + definitions[i] = col + } + + return db.parser.SetSchema(ctx, txn, definitions) } func (db *db) setDefaultSchemaVersionExplicit( diff --git a/db/schema.go b/db/schema.go index 910f44f8c1..ff583c7cf4 100644 --- a/db/schema.go +++ b/db/schema.go @@ -39,24 +39,29 @@ func (db *db) addSchema( txn datastore.Txn, schemaString string, ) ([]client.CollectionDescription, error) { - existingDescriptions, err := db.getCollectionDescriptions(ctx, txn) + existingCollections, err := db.getAllCollections(ctx, txn) if err != nil { return nil, err } + existingDefinitions := make([]client.CollectionDefinition, len(existingCollections)) + for i := range existingCollections { + existingDefinitions[i] = existingCollections[i] + } + newDescriptions, err := db.parser.ParseSDL(ctx, schemaString) if err != nil { return nil, err } - err = db.parser.SetSchema(ctx, txn, append(existingDescriptions, newDescriptions...)) + err = db.parser.SetSchema(ctx, txn, append(existingDefinitions, newDescriptions...)) if err != nil { return nil, err } returnDescriptions := make([]client.CollectionDescription, len(newDescriptions)) for i, desc := range newDescriptions { - col, err := db.createCollection(ctx, txn, desc) + col, err := db.createCollection(ctx, txn, desc.Description()) if err != nil { return nil, err } @@ -67,12 +72,17 @@ func (db *db) addSchema( } func (db *db) loadSchema(ctx context.Context, txn datastore.Txn) error { - descriptions, err := db.getCollectionDescriptions(ctx, txn) + collections, err := db.getAllCollections(ctx, txn) if err != nil { return err } - return db.parser.SetSchema(ctx, txn, descriptions) + definitions := make([]client.CollectionDefinition, len(collections)) + for i := range collections { + definitions[i] = collections[i] + } + + return db.parser.SetSchema(ctx, txn, definitions) } func (db *db) getCollectionDescriptions( @@ -138,21 +148,26 @@ func (db *db) patchSchema(ctx context.Context, txn datastore.Txn, patchString st return err } - newDescriptions := []client.CollectionDescription{} + newCollections := []client.CollectionDefinition{} for _, desc := range newDescriptionsByName { - newDescriptions = append(newDescriptions, desc) + col, err := db.newCollection(desc) + if err != nil { + return err + } + + newCollections = append(newCollections, col) } - for i, desc := range newDescriptions { - col, err := db.updateCollection(ctx, txn, collectionsByName, newDescriptionsByName, desc, setAsDefaultVersion) + for i, col := range newCollections { + col, err := db.updateCollection(ctx, txn, collectionsByName, newDescriptionsByName, col.Description(), setAsDefaultVersion) if err != nil { return err } - newDescriptions[i] = col.Description() + newCollections[i] = col } - return db.parser.SetSchema(ctx, txn, newDescriptions) + return db.parser.SetSchema(ctx, txn, newCollections) } func (db *db) getCollectionsByName( diff --git a/request/graphql/parser.go b/request/graphql/parser.go index ddd13d9e62..743c3eab97 100644 --- a/request/graphql/parser.go +++ b/request/graphql/parser.go @@ -104,13 +104,13 @@ func (p *parser) Parse(ast *ast.Document) (*request.Request, []error) { } func (p *parser) ParseSDL(ctx context.Context, schemaString string) ( - []client.CollectionDescription, + []client.CollectionDefinition, error, ) { return schema.FromString(ctx, schemaString) } -func (p *parser) SetSchema(ctx context.Context, txn datastore.Txn, collections []client.CollectionDescription) error { +func (p *parser) SetSchema(ctx context.Context, txn datastore.Txn, collections []client.CollectionDefinition) error { schemaManager, err := schema.NewSchemaManager() if err != nil { return err diff --git a/request/graphql/schema/collection.go b/request/graphql/schema/collection.go index 00287c4454..d0daf499fd 100644 --- a/request/graphql/schema/collection.go +++ b/request/graphql/schema/collection.go @@ -24,9 +24,32 @@ import ( "github.com/graphql-go/graphql/language/source" ) +type collectionDefinition struct { + collection client.CollectionDescription + schema client.SchemaDescription +} + +var _ client.CollectionDefinition = (*collectionDefinition)(nil) + +func (c *collectionDefinition) Description() client.CollectionDescription { + return c.collection +} +func (c *collectionDefinition) Name() string { + return c.collection.Name +} +func (c *collectionDefinition) Schema() client.SchemaDescription { + return c.schema +} +func (c *collectionDefinition) ID() uint32 { + return c.collection.ID +} +func (c *collectionDefinition) SchemaID() string { + return c.schema.SchemaID +} + // FromString parses a GQL SDL string into a set of collection descriptions. func FromString(ctx context.Context, schemaString string) ( - []client.CollectionDescription, + []client.CollectionDefinition, error, ) { source := source.NewSource(&source.Source{ @@ -47,11 +70,11 @@ func FromString(ctx context.Context, schemaString string) ( // fromAst parses a GQL AST into a set of collection descriptions. func fromAst(ctx context.Context, doc *ast.Document) ( - []client.CollectionDescription, + []client.CollectionDefinition, error, ) { relationManager := NewRelationManager() - descriptions := []client.CollectionDescription{} + descriptions := []collectionDefinition{} for _, def := range doc.Definitions { switch defType := def.(type) { @@ -77,7 +100,12 @@ func fromAst(ctx context.Context, doc *ast.Document) ( return nil, err } - return descriptions, nil + definitions := make([]client.CollectionDefinition, len(descriptions)) + for i := range descriptions { + definitions[i] = &descriptions[i] + } + + return definitions, nil } // fromAstDefinition parses a AST object definition into a set of collection descriptions. @@ -85,7 +113,7 @@ func fromAstDefinition( ctx context.Context, relationManager *RelationManager, def *ast.ObjectDefinition, -) (client.CollectionDescription, error) { +) (collectionDefinition, error) { fieldDescriptions := []client.FieldDescription{ { Name: request.KeyFieldName, @@ -98,7 +126,7 @@ func fromAstDefinition( for _, field := range def.Fields { tmpFieldsDescriptions, err := fieldsFromAST(field, relationManager, def) if err != nil { - return client.CollectionDescription{}, err + return collectionDefinition{}, err } fieldDescriptions = append(fieldDescriptions, tmpFieldsDescriptions...) @@ -107,7 +135,7 @@ func fromAstDefinition( if directive.Name.Value == types.IndexDirectiveLabel { index, err := fieldIndexFromAST(field, directive) if err != nil { - return client.CollectionDescription{}, err + return collectionDefinition{}, err } indexDescriptions = append(indexDescriptions, index) } @@ -129,19 +157,26 @@ func fromAstDefinition( if directive.Name.Value == types.IndexDirectiveLabel { index, err := indexFromAST(directive) if err != nil { - return client.CollectionDescription{}, err + return collectionDefinition{}, err } indexDescriptions = append(indexDescriptions, index) } } - return client.CollectionDescription{ - Name: def.Name.Value, - Schema: client.SchemaDescription{ + return collectionDefinition{ + collection: client.CollectionDescription{ + Name: def.Name.Value, + Indexes: indexDescriptions, + Schema: client.SchemaDescription{ // temp + Name: def.Name.Value, + Fields: fieldDescriptions, + }, + }, + + schema: client.SchemaDescription{ Name: def.Name.Value, Fields: fieldDescriptions, }, - Indexes: indexDescriptions, }, nil } @@ -424,9 +459,9 @@ func getRelationshipName( return genRelationName(hostName, targetName) } -func finalizeRelations(relationManager *RelationManager, descriptions []client.CollectionDescription) error { +func finalizeRelations(relationManager *RelationManager, descriptions []collectionDefinition) error { for _, description := range descriptions { - for i, field := range description.Schema.Fields { + for i, field := range description.schema.Fields { if field.RelationType == 0 || field.RelationType&client.Relation_Type_INTERNAL_ID != 0 { continue } @@ -447,7 +482,7 @@ func finalizeRelations(relationManager *RelationManager, descriptions []client.C } field.RelationType = rel.Kind() | fieldRelationType - description.Schema.Fields[i] = field + description.schema.Fields[i] = field } } diff --git a/request/graphql/schema/generate.go b/request/graphql/schema/generate.go index e30693b3de..27c0beb9e6 100644 --- a/request/graphql/schema/generate.go +++ b/request/graphql/schema/generate.go @@ -47,7 +47,7 @@ func (m *SchemaManager) NewGenerator() *Generator { // Generate generates the query-op and mutation-op type definitions from // the given CollectionDescriptions. -func (g *Generator) Generate(ctx context.Context, collections []client.CollectionDescription) ([]*gql.Object, error) { +func (g *Generator) Generate(ctx context.Context, collections []client.CollectionDefinition) ([]*gql.Object, error) { typeMapBeforeMutation := g.manager.schema.TypeMap() typesBeforeMutation := make(map[string]any, len(typeMapBeforeMutation)) @@ -79,7 +79,7 @@ func (g *Generator) Generate(ctx context.Context, collections []client.Collectio // generate generates the query-op and mutation-op type definitions from // the given CollectionDescriptions. -func (g *Generator) generate(ctx context.Context, collections []client.CollectionDescription) ([]*gql.Object, error) { +func (g *Generator) generate(ctx context.Context, collections []client.CollectionDefinition) ([]*gql.Object, error) { // build base types defs, err := g.buildTypes(ctx, collections) if err != nil { @@ -354,7 +354,7 @@ func (g *Generator) createExpandedFieldList( // extract and return the correct gql.Object type(s) func (g *Generator) buildTypes( ctx context.Context, - collections []client.CollectionDescription, + collections []client.CollectionDefinition, ) ([]*gql.Object, error) { // @todo: Check for duplicate named defined types in the TypeMap // get all the defined types from the AST @@ -364,15 +364,15 @@ func (g *Generator) buildTypes( // Copy the loop variable before usage within the loop or it // will be reassigned before the thunk is run collection := c - fieldDescriptions := collection.Schema.Fields + fieldDescriptions := collection.Schema().Fields // check if type exists - if _, ok := g.manager.schema.TypeMap()[collection.Name]; ok { - return nil, NewErrSchemaTypeAlreadyExist(collection.Name) + if _, ok := g.manager.schema.TypeMap()[collection.Name()]; ok { + return nil, NewErrSchemaTypeAlreadyExist(collection.Name()) } objconf := gql.ObjectConfig{ - Name: collection.Name, + Name: collection.Name(), } // Wrap field definition in a thunk so we can @@ -435,9 +435,9 @@ func (g *Generator) buildTypes( Type: gql.Boolean, } - gqlType, ok := g.manager.schema.TypeMap()[collection.Name] + gqlType, ok := g.manager.schema.TypeMap()[collection.Name()] if !ok { - return nil, NewErrObjectNotFoundDuringThunk(collection.Name) + return nil, NewErrObjectNotFoundDuringThunk(collection.Name()) } fields[request.GroupFieldName] = &gql.Field{ diff --git a/request/graphql/schema/index_test.go b/request/graphql/schema/index_test.go index 379b84647d..7b7b43ecb4 100644 --- a/request/graphql/schema/index_test.go +++ b/request/graphql/schema/index_test.go @@ -276,9 +276,9 @@ func parseIndexAndTest(t *testing.T, testCase indexTestCase) { cols, err := FromString(ctx, testCase.sdl) assert.NoError(t, err, testCase.description) assert.Equal(t, len(cols), 1, testCase.description) - assert.Equal(t, len(cols[0].Indexes), len(testCase.targetDescriptions), testCase.description) + assert.Equal(t, len(cols[0].Description().Indexes), len(testCase.targetDescriptions), testCase.description) - for i, d := range cols[0].Indexes { + for i, d := range cols[0].Description().Indexes { assert.Equal(t, testCase.targetDescriptions[i], d, testCase.description) } }