diff --git a/cli/collection.go b/cli/collection.go index e21c29283b..e0a3118f65 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -43,15 +43,17 @@ func MakeCollectionCommand(cfg *config.Config) *cobra.Command { store := mustGetStoreContext(cmd) var col client.Collection + var cols []client.Collection switch { case versionID != "": - col, err = store.GetCollectionByVersionID(cmd.Context(), versionID) + cols, err = store.GetCollectionsByVersionID(cmd.Context(), versionID) case schemaID != "": - col, err = store.GetCollectionBySchemaID(cmd.Context(), schemaID) + cols, err = store.GetCollectionsBySchemaID(cmd.Context(), schemaID) case name != "": col, err = store.GetCollectionByName(cmd.Context(), name) + cols = []client.Collection{col} default: return nil @@ -60,6 +62,38 @@ func MakeCollectionCommand(cfg *config.Config) *cobra.Command { if err != nil { return err } + + if schemaID != "" && versionID != "" && len(cols) > 0 { + if cols[0].SchemaID() != schemaID { + // If the a versionID has been provided that does not pair up with the given schemaID + // we should error and let the user know they have provided impossible params. + // We only need to check the first item - they will all be the same. + return NewErrSchemaVersionNotOfSchema(schemaID, versionID) + } + } + + if name != "" { + // Multiple params may have been specified, and in some cases both are needed. + // For example if a schema version and a collection name have been provided, + // we need to ensure that a collection at the requested version is returned. + // Likewise we need to ensure that if a collection name and schema id are provided, + // but there are none matching both, that nothing is returned. + fetchedCols := cols + cols = nil + for _, c := range fetchedCols { + if c.Name() == name { + cols = append(cols, c) + break + } + } + } + + if len(cols) != 1 { + // If more than one collection matches the given criteria we cannot set the context collection + return nil + } + col = cols[0] + if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { col = col.WithTxn(tx) } diff --git a/cli/errors.go b/cli/errors.go index a7d6cbd26b..bd42dfa153 100644 --- a/cli/errors.go +++ b/cli/errors.go @@ -14,17 +14,29 @@ import ( "github.com/sourcenetwork/defradb/errors" ) -const errInvalidLensConfig = "invalid lens configuration" +const ( + errInvalidLensConfig string = "invalid lens configuration" + errSchemaVersionNotOfSchema string = "the given schema version is from a different schema" +) var ( - ErrNoDocOrFile = errors.New("document or file must be defined") - ErrInvalidDocument = errors.New("invalid document") - ErrNoDocKeyOrFilter = errors.New("document key or filter must be defined") - ErrInvalidExportFormat = errors.New("invalid export format") - ErrNoLensConfig = errors.New("lens config cannot be empty") - ErrInvalidLensConfig = errors.New("invalid lens configuration") + ErrNoDocOrFile = errors.New("document or file must be defined") + ErrInvalidDocument = errors.New("invalid document") + ErrNoDocKeyOrFilter = errors.New("document key or filter must be defined") + ErrInvalidExportFormat = errors.New("invalid export format") + ErrNoLensConfig = errors.New("lens config cannot be empty") + ErrInvalidLensConfig = errors.New("invalid lens configuration") + ErrSchemaVersionNotOfSchema = errors.New(errSchemaVersionNotOfSchema) ) func NewErrInvalidLensConfig(inner error) error { return errors.Wrap(errInvalidLensConfig, inner) } + +func NewErrSchemaVersionNotOfSchema(schemaID string, schemaVersionID string) error { + return errors.New( + errSchemaVersionNotOfSchema, + errors.NewKV("SchemaID", schemaID), + errors.NewKV("SchemaVersionID", schemaVersionID), + ) +} diff --git a/client/db.go b/client/db.go index 57e96a3416..81a47b9e24 100644 --- a/client/db.go +++ b/client/db.go @@ -147,15 +147,15 @@ type Store interface { // If no matching collection is found an error will be returned. GetCollectionByName(context.Context, CollectionName) (Collection, error) - // GetCollectionBySchemaID attempts to retrieve a collection matching the given schema ID. + // GetCollectionsBySchemaID attempts to retrieve all collections using the given schema ID. // - // If no matching collection is found an error will be returned. - GetCollectionBySchemaID(context.Context, string) (Collection, error) + // If no matching collection is found an empty set will be returned. + GetCollectionsBySchemaID(context.Context, string) ([]Collection, error) - // GetCollectionBySchemaID attempts to retrieve a collection matching the given schema version ID. + // GetCollectionsByVersionID attempts to retrieve all collections using the given schema version ID. // - // If no matching collection is found an error will be returned. - GetCollectionByVersionID(context.Context, string) (Collection, error) + // If no matching collections are found an empty set will be returned. + GetCollectionsByVersionID(context.Context, string) ([]Collection, error) // GetAllCollections returns all the collections and their descriptions that currently exist within // this [Store]. diff --git a/client/errors.go b/client/errors.go index ad1ad0027a..9454c13768 100644 --- a/client/errors.go +++ b/client/errors.go @@ -23,6 +23,7 @@ const ( errUninitializeProperty string = "invalid state, required property is uninitialized" errMaxTxnRetries string = "reached maximum transaction reties" errRelationOneSided string = "relation must be defined on both schemas" + errCollectionNotFound string = "collection not found" ) // Errors returnable from this package. @@ -45,6 +46,7 @@ var ( ErrInvalidDocKeyVersion = errors.New("invalid DocKey version") ErrMaxTxnRetries = errors.New(errMaxTxnRetries) ErrRelationOneSided = errors.New(errRelationOneSided) + ErrCollectionNotFound = errors.New(errCollectionNotFound) ) // NewErrFieldNotExist returns an error indicating that the given field does not exist. @@ -107,3 +109,17 @@ func NewErrRelationOneSided(fieldName string, typeName string) error { errors.NewKV("Type", typeName), ) } + +func NewErrCollectionNotFoundForSchemaVersion(schemaVersionID string) error { + return errors.New( + errCollectionNotFound, + errors.NewKV("SchemaVersionID", schemaVersionID), + ) +} + +func NewErrCollectionNotFoundForSchema(schemaID string) error { + return errors.New( + errCollectionNotFound, + errors.NewKV("SchemaID", schemaID), + ) +} diff --git a/client/mocks/db.go b/client/mocks/db.go index 42bc303963..bddcd5049f 100644 --- a/client/mocks/db.go +++ b/client/mocks/db.go @@ -493,20 +493,20 @@ func (_c *DB_GetCollectionByName_Call) RunAndReturn(run func(context.Context, st return _c } -// GetCollectionBySchemaID provides a mock function with given fields: _a0, _a1 -func (_m *DB) GetCollectionBySchemaID(_a0 context.Context, _a1 string) (client.Collection, error) { +// GetCollectionsBySchemaID provides a mock function with given fields: _a0, _a1 +func (_m *DB) GetCollectionsBySchemaID(_a0 context.Context, _a1 string) ([]client.Collection, error) { ret := _m.Called(_a0, _a1) - var r0 client.Collection + var r0 []client.Collection var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (client.Collection, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) ([]client.Collection, error)); ok { return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context, string) client.Collection); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) []client.Collection); ok { r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(client.Collection) + r0 = ret.Get(0).([]client.Collection) } } @@ -519,49 +519,49 @@ func (_m *DB) GetCollectionBySchemaID(_a0 context.Context, _a1 string) (client.C return r0, r1 } -// DB_GetCollectionBySchemaID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionBySchemaID' -type DB_GetCollectionBySchemaID_Call struct { +// DB_GetCollectionsBySchemaID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionsBySchemaID' +type DB_GetCollectionsBySchemaID_Call struct { *mock.Call } -// GetCollectionBySchemaID is a helper method to define mock.On call +// GetCollectionsBySchemaID is a helper method to define mock.On call // - _a0 context.Context // - _a1 string -func (_e *DB_Expecter) GetCollectionBySchemaID(_a0 interface{}, _a1 interface{}) *DB_GetCollectionBySchemaID_Call { - return &DB_GetCollectionBySchemaID_Call{Call: _e.mock.On("GetCollectionBySchemaID", _a0, _a1)} +func (_e *DB_Expecter) GetCollectionsBySchemaID(_a0 interface{}, _a1 interface{}) *DB_GetCollectionsBySchemaID_Call { + return &DB_GetCollectionsBySchemaID_Call{Call: _e.mock.On("GetCollectionsBySchemaID", _a0, _a1)} } -func (_c *DB_GetCollectionBySchemaID_Call) Run(run func(_a0 context.Context, _a1 string)) *DB_GetCollectionBySchemaID_Call { +func (_c *DB_GetCollectionsBySchemaID_Call) Run(run func(_a0 context.Context, _a1 string)) *DB_GetCollectionsBySchemaID_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *DB_GetCollectionBySchemaID_Call) Return(_a0 client.Collection, _a1 error) *DB_GetCollectionBySchemaID_Call { +func (_c *DB_GetCollectionsBySchemaID_Call) Return(_a0 []client.Collection, _a1 error) *DB_GetCollectionsBySchemaID_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *DB_GetCollectionBySchemaID_Call) RunAndReturn(run func(context.Context, string) (client.Collection, error)) *DB_GetCollectionBySchemaID_Call { +func (_c *DB_GetCollectionsBySchemaID_Call) RunAndReturn(run func(context.Context, string) ([]client.Collection, error)) *DB_GetCollectionsBySchemaID_Call { _c.Call.Return(run) return _c } -// GetCollectionByVersionID provides a mock function with given fields: _a0, _a1 -func (_m *DB) GetCollectionByVersionID(_a0 context.Context, _a1 string) (client.Collection, error) { +// GetCollectionsByVersionID provides a mock function with given fields: _a0, _a1 +func (_m *DB) GetCollectionsByVersionID(_a0 context.Context, _a1 string) ([]client.Collection, error) { ret := _m.Called(_a0, _a1) - var r0 client.Collection + var r0 []client.Collection var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (client.Collection, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) ([]client.Collection, error)); ok { return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context, string) client.Collection); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) []client.Collection); ok { r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(client.Collection) + r0 = ret.Get(0).([]client.Collection) } } @@ -574,31 +574,31 @@ func (_m *DB) GetCollectionByVersionID(_a0 context.Context, _a1 string) (client. return r0, r1 } -// DB_GetCollectionByVersionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionByVersionID' -type DB_GetCollectionByVersionID_Call struct { +// DB_GetCollectionsByVersionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionsByVersionID' +type DB_GetCollectionsByVersionID_Call struct { *mock.Call } -// GetCollectionByVersionID is a helper method to define mock.On call +// GetCollectionsByVersionID is a helper method to define mock.On call // - _a0 context.Context // - _a1 string -func (_e *DB_Expecter) GetCollectionByVersionID(_a0 interface{}, _a1 interface{}) *DB_GetCollectionByVersionID_Call { - return &DB_GetCollectionByVersionID_Call{Call: _e.mock.On("GetCollectionByVersionID", _a0, _a1)} +func (_e *DB_Expecter) GetCollectionsByVersionID(_a0 interface{}, _a1 interface{}) *DB_GetCollectionsByVersionID_Call { + return &DB_GetCollectionsByVersionID_Call{Call: _e.mock.On("GetCollectionsByVersionID", _a0, _a1)} } -func (_c *DB_GetCollectionByVersionID_Call) Run(run func(_a0 context.Context, _a1 string)) *DB_GetCollectionByVersionID_Call { +func (_c *DB_GetCollectionsByVersionID_Call) Run(run func(_a0 context.Context, _a1 string)) *DB_GetCollectionsByVersionID_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *DB_GetCollectionByVersionID_Call) Return(_a0 client.Collection, _a1 error) *DB_GetCollectionByVersionID_Call { +func (_c *DB_GetCollectionsByVersionID_Call) Return(_a0 []client.Collection, _a1 error) *DB_GetCollectionsByVersionID_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *DB_GetCollectionByVersionID_Call) RunAndReturn(run func(context.Context, string) (client.Collection, error)) *DB_GetCollectionByVersionID_Call { +func (_c *DB_GetCollectionsByVersionID_Call) RunAndReturn(run func(context.Context, string) ([]client.Collection, error)) *DB_GetCollectionsByVersionID_Call { _c.Call.Return(run) return _c } diff --git a/db/collection.go b/db/collection.go index f7da81e59d..feeff441c5 100644 --- a/db/collection.go +++ b/db/collection.go @@ -66,11 +66,11 @@ type collection struct { // CollectionOptions object. // NewCollection returns a pointer to a newly instanciated DB Collection -func (db *db) newCollection(desc client.CollectionDescription, schema client.SchemaDescription) (*collection, error) { +func (db *db) newCollection(desc client.CollectionDescription, schema client.SchemaDescription) *collection { return &collection{ db: db, def: client.CollectionDefinition{Description: desc, Schema: schema}, - }, nil + } } // newFetcher returns a new fetcher instance for this collection. @@ -149,11 +149,7 @@ func (db *db) createCollection( return nil, err } - col, err := db.newCollection(desc, schema) - if err != nil { - return nil, err - } - + col := db.newCollection(desc, schema) for _, index := range desc.Indexes { if _, err := col.createIndex(ctx, txn, index); err != nil { return nil, err @@ -490,10 +486,14 @@ func (db *db) setDefaultSchemaVersion( txn datastore.Txn, schemaVersionID string, ) error { - col, err := db.getCollectionByVersionID(ctx, txn, schemaVersionID) + // This call makes no sense at the moment, but needs to be done due to the bad way we currently store + // collections. + // https://github.com/sourcenetwork/defradb/issues/1964 + collections, err := db.getCollectionsByVersionID(ctx, txn, schemaVersionID) if err != nil { return err } + col := collections[0] desc := col.Description() err = db.setDefaultSchemaVersionExplicit(ctx, txn, desc.Name, col.Schema().SchemaID, schemaVersionID) @@ -531,14 +531,14 @@ func (db *db) setDefaultSchemaVersionExplicit( return txn.Systemstore().Put(ctx, collectionKey.ToDS(), []byte(schemaVersionID)) } -// getCollectionByVersionId returns the [*collection] at the given [schemaVersionId] version. +// getCollectionsByVersionId returns the [*collection]s at the given [schemaVersionId] version. // -// Will return an error if the given key is empty, or not found. -func (db *db) getCollectionByVersionID( +// Will return an error if the given key is empty, or if none are found. +func (db *db) getCollectionsByVersionID( ctx context.Context, txn datastore.Txn, schemaVersionId string, -) (*collection, error) { +) ([]*collection, error) { if schemaVersionId == "" { return nil, ErrSchemaVersionIDEmpty } @@ -560,20 +560,14 @@ func (db *db) getCollectionByVersionID( return nil, err } - col := &collection{ - db: db, - def: client.CollectionDefinition{ - Description: desc, - Schema: schema, - }, - } + col := db.newCollection(desc, schema) err = col.loadIndexes(ctx, txn) if err != nil { return nil, err } - return col, nil + return []*collection{col}, nil } // getCollectionByName returns an existing collection within the database. @@ -589,15 +583,26 @@ func (db *db) getCollectionByName(ctx context.Context, txn datastore.Txn, name s } schemaVersionId := string(buf) - return db.getCollectionByVersionID(ctx, txn, schemaVersionId) + // This call makes no sense at the moment, but needs to be done due to the bad way we currently store + // collections. + // https://github.com/sourcenetwork/defradb/issues/1964 + cols, err := db.getCollectionsByVersionID(ctx, txn, schemaVersionId) + if err != nil { + return nil, err + } + if len(cols) == 0 { + return nil, NewErrFailedToGetCollection(schemaVersionId, err) + } + + return cols[0], nil } -// getCollectionBySchemaID returns an existing collection using the schema hash ID. -func (db *db) getCollectionBySchemaID( +// getCollectionsBySchemaID returns all existing collections using the schema hash ID. +func (db *db) getCollectionsBySchemaID( ctx context.Context, txn datastore.Txn, schemaID string, -) (client.Collection, error) { +) ([]client.Collection, error) { if schemaID == "" { return nil, ErrSchemaIDEmpty } @@ -609,7 +614,17 @@ func (db *db) getCollectionBySchemaID( } schemaVersionId := string(buf) - return db.getCollectionByVersionID(ctx, txn, schemaVersionId) + cols, err := db.getCollectionsByVersionID(ctx, txn, schemaVersionId) + if err != nil { + return nil, err + } + + collections := make([]client.Collection, len(cols)) + for i, col := range cols { + collections[i] = col + } + + return collections, nil } // getAllCollections gets all the currently defined collections. @@ -635,11 +650,18 @@ func (db *db) getAllCollections(ctx context.Context, txn datastore.Txn) ([]clien } schemaVersionId := string(res.Value) - col, err := db.getCollectionByVersionID(ctx, txn, schemaVersionId) + // This call makes no sense at the moment, but needs to be done due to the bad way we currently store + // collections. + // https://github.com/sourcenetwork/defradb/issues/1964 + collections, err := db.getCollectionsByVersionID(ctx, txn, schemaVersionId) if err != nil { return nil, NewErrFailedToGetCollection(schemaVersionId, err) } - cols = append(cols, col) + if len(collections) == 0 { + return nil, NewErrFailedToGetCollection(schemaVersionId, err) + } + + cols = append(cols, collections[0]) } return cols, nil diff --git a/db/txn_db.go b/db/txn_db.go index 0627f8ebc8..ba8b20582c 100644 --- a/db/txn_db.go +++ b/db/txn_db.go @@ -79,46 +79,84 @@ func (db *explicitTxnDB) GetCollectionByName(ctx context.Context, name string) ( return db.getCollectionByName(ctx, db.txn, name) } -// GetCollectionBySchemaID returns an existing collection using the schema hash ID. -func (db *implicitTxnDB) GetCollectionBySchemaID( +// GetCollectionsBySchemaID attempts to retrieve all collections using the given schema ID. +// +// If no matching collection is found an empty set will be returned. +func (db *implicitTxnDB) GetCollectionsBySchemaID( ctx context.Context, schemaID string, -) (client.Collection, error) { +) ([]client.Collection, error) { txn, err := db.NewTxn(ctx, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return db.getCollectionBySchemaID(ctx, txn, schemaID) + cols, err := db.getCollectionsBySchemaID(ctx, txn, schemaID) + if err != nil { + return nil, err + } + + return cols, nil } -// GetCollectionBySchemaID returns an existing collection using the schema hash ID. -func (db *explicitTxnDB) GetCollectionBySchemaID( +// GetCollectionsBySchemaID attempts to retrieve all collections using the given schema ID. +// +// If no matching collection is found an empty set will be returned. +func (db *explicitTxnDB) GetCollectionsBySchemaID( ctx context.Context, schemaID string, -) (client.Collection, error) { - return db.getCollectionBySchemaID(ctx, db.txn, schemaID) +) ([]client.Collection, error) { + cols, err := db.getCollectionsBySchemaID(ctx, db.txn, schemaID) + if err != nil { + return nil, err + } + + return cols, nil } -// GetCollectionByVersionID returns an existing collection using the schema version hash ID. -func (db *implicitTxnDB) GetCollectionByVersionID( +// GetCollectionsByVersionID attempts to retrieve all collections using the given schema version ID. +// +// If no matching collections are found an empty set will be returned. +func (db *implicitTxnDB) GetCollectionsByVersionID( ctx context.Context, schemaVersionID string, -) (client.Collection, error) { +) ([]client.Collection, error) { txn, err := db.NewTxn(ctx, true) if err != nil { return nil, err } defer txn.Discard(ctx) - return db.getCollectionByVersionID(ctx, txn, schemaVersionID) + cols, err := db.getCollectionsByVersionID(ctx, txn, schemaVersionID) + if err != nil { + return nil, err + } + + collections := make([]client.Collection, len(cols)) + for i, col := range cols { + collections[i] = col + } + + return collections, nil } -// GetCollectionByVersionID returns an existing collection using the schema version hash ID. -func (db *explicitTxnDB) GetCollectionByVersionID( +// GetCollectionsByVersionID attempts to retrieve all collections using the given schema version ID. +// +// If no matching collections are found an empty set will be returned. +func (db *explicitTxnDB) GetCollectionsByVersionID( ctx context.Context, schemaVersionID string, -) (client.Collection, error) { - return db.getCollectionByVersionID(ctx, db.txn, schemaVersionID) +) ([]client.Collection, error) { + cols, err := db.getCollectionsByVersionID(ctx, db.txn, schemaVersionID) + if err != nil { + return nil, err + } + + collections := make([]client.Collection, len(cols)) + for i, col := range cols { + collections[i] = col + } + + return collections, nil } // GetAllCollections gets all the currently defined collections. diff --git a/http/client.go b/http/client.go index 4ef4c92119..bc468d1d96 100644 --- a/http/client.go +++ b/http/client.go @@ -186,7 +186,7 @@ func (c *Client) GetCollectionByName(ctx context.Context, name client.Collection return &Collection{c.http, definition}, nil } -func (c *Client) GetCollectionBySchemaID(ctx context.Context, schemaId string) (client.Collection, error) { +func (c *Client) GetCollectionsBySchemaID(ctx context.Context, schemaId string) ([]client.Collection, error) { methodURL := c.http.baseURL.JoinPath("collections") methodURL.RawQuery = url.Values{"schema_id": []string{schemaId}}.Encode() @@ -194,14 +194,18 @@ func (c *Client) GetCollectionBySchemaID(ctx context.Context, schemaId string) ( if err != nil { return nil, err } - var definition client.CollectionDefinition - if err := c.http.requestJson(req, &definition); err != nil { + var descriptions []client.CollectionDefinition + if err := c.http.requestJson(req, &descriptions); err != nil { return nil, err } - return &Collection{c.http, definition}, nil + collections := make([]client.Collection, len(descriptions)) + for i, d := range descriptions { + collections[i] = &Collection{c.http, d} + } + return collections, nil } -func (c *Client) GetCollectionByVersionID(ctx context.Context, versionId string) (client.Collection, error) { +func (c *Client) GetCollectionsByVersionID(ctx context.Context, versionId string) ([]client.Collection, error) { methodURL := c.http.baseURL.JoinPath("collections") methodURL.RawQuery = url.Values{"version_id": []string{versionId}}.Encode() @@ -209,11 +213,15 @@ func (c *Client) GetCollectionByVersionID(ctx context.Context, versionId string) if err != nil { return nil, err } - var definition client.CollectionDefinition - if err := c.http.requestJson(req, &definition); err != nil { + var descriptions []client.CollectionDefinition + if err := c.http.requestJson(req, &descriptions); err != nil { return nil, err } - return &Collection{c.http, definition}, nil + collections := make([]client.Collection, len(descriptions)) + for i, d := range descriptions { + collections[i] = &Collection{c.http, d} + } + return collections, nil } func (c *Client) GetAllCollections(ctx context.Context) ([]client.Collection, error) { diff --git a/http/handler_store.go b/http/handler_store.go index 0b47069afc..91f1b3cbd4 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -118,19 +118,27 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) } responseJSON(rw, http.StatusOK, col.Definition()) case req.URL.Query().Has("schema_id"): - col, err := store.GetCollectionBySchemaID(req.Context(), req.URL.Query().Get("schema_id")) + cols, err := store.GetCollectionsBySchemaID(req.Context(), req.URL.Query().Get("schema_id")) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - responseJSON(rw, http.StatusOK, col.Definition()) + colDesc := make([]client.CollectionDefinition, len(cols)) + for i, col := range cols { + colDesc[i] = col.Definition() + } + responseJSON(rw, http.StatusOK, colDesc) case req.URL.Query().Has("version_id"): - col, err := store.GetCollectionByVersionID(req.Context(), req.URL.Query().Get("version_id")) + cols, err := store.GetCollectionsByVersionID(req.Context(), req.URL.Query().Get("version_id")) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return } - responseJSON(rw, http.StatusOK, col.Definition()) + colDesc := make([]client.CollectionDefinition, len(cols)) + for i, col := range cols { + colDesc[i] = col.Definition() + } + responseJSON(rw, http.StatusOK, colDesc) default: cols, err := store.GetAllCollections(req.Context()) if err != nil { diff --git a/net/peer_collection.go b/net/peer_collection.go index 91e3f66154..86b5d9b483 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -31,11 +31,11 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollectionBySchemaID(p.ctx, col) + storeCol, err := p.db.WithTxn(txn).GetCollectionsBySchemaID(p.ctx, col) if err != nil { return err } - storeCollections = append(storeCollections, storeCol) + storeCollections = append(storeCollections, storeCol...) } // Ensure we can add all the collections to the store on the transaction @@ -93,11 +93,11 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollectionBySchemaID(p.ctx, col) + storeCol, err := p.db.WithTxn(txn).GetCollectionsBySchemaID(p.ctx, col) if err != nil { return err } - storeCollections = append(storeCollections, storeCol) + storeCollections = append(storeCollections, storeCol...) } // Ensure we can remove all the collections to the store on the transaction diff --git a/net/server.go b/net/server.go index 7322d845ad..c1bcc7b983 100644 --- a/net/server.go +++ b/net/server.go @@ -260,10 +260,16 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL defer txn.Discard(ctx) store := s.db.WithTxn(txn) - col, err := store.GetCollectionBySchemaID(ctx, schemaID) + // Currently a schema is the best way we have to link a push log request to a collection, + // this will change with https://github.com/sourcenetwork/defradb/issues/1085 + cols, err := store.GetCollectionsBySchemaID(ctx, schemaID) if err != nil { return nil, errors.Wrap(fmt.Sprintf("Failed to get collection from schemaID %s", schemaID), err) } + if len(cols) == 0 { + return nil, client.NewErrCollectionNotFoundForSchema(schemaID) + } + col := cols[0] // Create a new DAG service with the current transaction var getter format.NodeGetter = s.peer.newDAGSyncerTxn(txn) diff --git a/planner/commit.go b/planner/commit.go index c2cff28c30..b4fd3ed3c1 100644 --- a/planner/commit.go +++ b/planner/commit.go @@ -328,12 +328,17 @@ func (n *dagScanNode) dagBlockToNodeDoc(block blocks.Block) (core.Doc, []*ipld.L fieldName = nil default: - c, err := n.planner.db.GetCollectionByVersionID(n.planner.ctx, schemaVersionId) + cols, err := n.planner.db.GetCollectionsByVersionID(n.planner.ctx, schemaVersionId) if err != nil { return core.Doc{}, nil, err } + if len(cols) == 0 { + return core.Doc{}, nil, client.NewErrCollectionNotFoundForSchemaVersion(schemaVersionId) + } - field, ok := c.Schema().GetField(fieldName.(string)) + // Because we only care about the schema, we can safely take the first - the schema is the same + // for all in the set. + field, ok := cols[0].Schema().GetField(fieldName.(string)) if !ok { return core.Doc{}, nil, client.NewErrFieldNotExist(fieldName.(string)) } @@ -353,13 +358,19 @@ func (n *dagScanNode) dagBlockToNodeDoc(block blocks.Block) (core.Doc, []*ipld.L n.commitSelect.DocumentMapping.SetFirstOfName(&commit, request.DockeyFieldName, string(dockey)) - collection, err := n.planner.db.GetCollectionByVersionID(n.planner.ctx, schemaVersionId) + cols, err := n.planner.db.GetCollectionsByVersionID(n.planner.ctx, schemaVersionId) if err != nil { return core.Doc{}, nil, err } + if len(cols) == 0 { + return core.Doc{}, nil, client.NewErrCollectionNotFoundForSchemaVersion(schemaVersionId) + } + // WARNING: This will become incorrect once we allow multiple collections to share the same schema, + // we should by then instead fetch the collection be global collection ID: + // https://github.com/sourcenetwork/defradb/issues/1085 n.commitSelect.DocumentMapping.SetFirstOfName(&commit, - request.CollectionIDFieldName, int64(collection.ID())) + request.CollectionIDFieldName, int64(cols[0].ID())) heads := make([]*ipld.Link, 0) diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index 18ea1b2e4f..35caa3a591 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -225,7 +225,7 @@ func (w *Wrapper) GetCollectionByName(ctx context.Context, name client.Collectio return &Collection{w.cmd, definition}, nil } -func (w *Wrapper) GetCollectionBySchemaID(ctx context.Context, schemaId string) (client.Collection, error) { +func (w *Wrapper) GetCollectionsBySchemaID(ctx context.Context, schemaId string) ([]client.Collection, error) { args := []string{"client", "collection", "describe"} args = append(args, "--schema", schemaId) @@ -233,14 +233,18 @@ func (w *Wrapper) GetCollectionBySchemaID(ctx context.Context, schemaId string) if err != nil { return nil, err } - var definition client.CollectionDefinition - if err := json.Unmarshal(data, &definition); err != nil { + var colDesc []client.CollectionDefinition + if err := json.Unmarshal(data, &colDesc); err != nil { return nil, err } - return &Collection{w.cmd, definition}, nil + cols := make([]client.Collection, len(colDesc)) + for i, v := range colDesc { + cols[i] = &Collection{w.cmd, v} + } + return cols, err } -func (w *Wrapper) GetCollectionByVersionID(ctx context.Context, versionId string) (client.Collection, error) { +func (w *Wrapper) GetCollectionsByVersionID(ctx context.Context, versionId string) ([]client.Collection, error) { args := []string{"client", "collection", "describe"} args = append(args, "--version", versionId) @@ -248,11 +252,15 @@ func (w *Wrapper) GetCollectionByVersionID(ctx context.Context, versionId string if err != nil { return nil, err } - var definition client.CollectionDefinition - if err := json.Unmarshal(data, &definition); err != nil { + var colDesc []client.CollectionDefinition + if err := json.Unmarshal(data, &colDesc); err != nil { return nil, err } - return &Collection{w.cmd, definition}, nil + cols := make([]client.Collection, len(colDesc)) + for i, v := range colDesc { + cols[i] = &Collection{w.cmd, v} + } + return cols, err } func (w *Wrapper) GetAllCollections(ctx context.Context) ([]client.Collection, error) { diff --git a/tests/clients/http/wrapper.go b/tests/clients/http/wrapper.go index 8c6b42bc68..b61058525a 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -115,12 +115,12 @@ func (w *Wrapper) GetCollectionByName(ctx context.Context, name client.Collectio return w.client.GetCollectionByName(ctx, name) } -func (w *Wrapper) GetCollectionBySchemaID(ctx context.Context, schemaId string) (client.Collection, error) { - return w.client.GetCollectionBySchemaID(ctx, schemaId) +func (w *Wrapper) GetCollectionsBySchemaID(ctx context.Context, schemaId string) ([]client.Collection, error) { + return w.client.GetCollectionsBySchemaID(ctx, schemaId) } -func (w *Wrapper) GetCollectionByVersionID(ctx context.Context, versionId string) (client.Collection, error) { - return w.client.GetCollectionByVersionID(ctx, versionId) +func (w *Wrapper) GetCollectionsByVersionID(ctx context.Context, versionId string) ([]client.Collection, error) { + return w.client.GetCollectionsByVersionID(ctx, versionId) } func (w *Wrapper) GetAllCollections(ctx context.Context) ([]client.Collection, error) {