From 322ba569d6e03bd1b9895857f53f3ddf04f9726a Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Fri, 12 Apr 2024 13:23:12 -0700 Subject: [PATCH 1/2] remove txn params from private db methods --- db/backup.go | 13 ++--- db/backup_test.go | 40 +++++++++---- db/collection.go | 125 +++++++++++++++++++--------------------- db/collection_delete.go | 24 +++----- db/collection_get.go | 7 +-- db/collection_index.go | 76 ++++++++++++------------ db/collection_update.go | 24 +++----- db/context.go | 5 ++ db/db.go | 6 +- db/index_test.go | 24 +++++--- db/indexed_docs_test.go | 8 ++- db/lens.go | 11 ++-- db/request.go | 3 +- db/schema.go | 24 ++++---- db/sequence.go | 21 ++++--- db/store.go | 28 ++++----- db/subscriptions.go | 5 +- db/view.go | 8 +-- 18 files changed, 234 insertions(+), 218 deletions(-) diff --git a/db/backup.go b/db/backup.go index 17110bec05..4c72797b0e 100644 --- a/db/backup.go +++ b/db/backup.go @@ -20,10 +20,9 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" ) -func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath string) (err error) { +func (db *db) basicImport(ctx context.Context, filepath string) (err error) { f, err := os.Open(filepath) if err != nil { return NewErrOpenFile(err, filepath) @@ -50,7 +49,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return err } colName := t.(string) - col, err := db.getCollectionByName(ctx, txn, colName) + col, err := db.getCollectionByName(ctx, colName) if err != nil { return NewErrFailedToGetCollection(colName, err) } @@ -119,19 +118,19 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return nil } -func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client.BackupConfig) (err error) { +func (db *db) basicExport(ctx context.Context, config *client.BackupConfig) (err error) { // old key -> new Key keyChangeCache := map[string]string{} cols := []client.Collection{} if len(config.Collections) == 0 { - cols, err = db.getCollections(ctx, txn, client.CollectionFetchOptions{}) + cols, err = db.getCollections(ctx, client.CollectionFetchOptions{}) if err != nil { return NewErrFailedToGetAllCollections(err) } } else { for _, colName := range config.Collections { - col, err := db.getCollectionByName(ctx, txn, colName) + col, err := db.getCollectionByName(ctx, colName) if err != nil { return NewErrFailedToGetCollection(colName, err) } @@ -233,7 +232,7 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client refFieldName = field.Name + request.RelatedObjectID } } else { - foreignCol, err := db.getCollectionByName(ctx, txn, field.Kind.Underlying()) + foreignCol, err := db.getCollectionByName(ctx, field.Kind.Underlying()) if err != nil { return NewErrFailedToGetCollection(field.Kind.Underlying(), err) } diff --git a/db/backup_test.go b/db/backup_test.go index 6a9eab3cc9..968415f3b3 100644 --- a/db/backup_test.go +++ b/db/backup_test.go @@ -64,10 +64,12 @@ func TestBasicExport_WithNormalFormatting_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, true) require.NoError(t, err) + + ctx = SetContextTxn(ctx, txn) defer txn.Discard(ctx) filepath := t.TempDir() + "/test.json" - err = db.basicExport(ctx, txn, &client.BackupConfig{Filepath: filepath}) + err = db.basicExport(ctx, &client.BackupConfig{Filepath: filepath}) require.NoError(t, err) b, err := os.ReadFile(filepath) @@ -126,10 +128,12 @@ func TestBasicExport_WithPrettyFormatting_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, true) require.NoError(t, err) + + ctx = SetContextTxn(ctx, txn) defer txn.Discard(ctx) filepath := t.TempDir() + "/test.json" - err = db.basicExport(ctx, txn, &client.BackupConfig{Filepath: filepath, Pretty: true}) + err = db.basicExport(ctx, &client.BackupConfig{Filepath: filepath, Pretty: true}) require.NoError(t, err) b, err := os.ReadFile(filepath) @@ -188,10 +192,12 @@ func TestBasicExport_WithSingleCollection_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, true) require.NoError(t, err) + + ctx = SetContextTxn(ctx, txn) defer txn.Discard(ctx) filepath := t.TempDir() + "/test.json" - err = db.basicExport(ctx, txn, &client.BackupConfig{Filepath: filepath, Collections: []string{"Address"}}) + err = db.basicExport(ctx, &client.BackupConfig{Filepath: filepath, Collections: []string{"Address"}}) require.NoError(t, err) b, err := os.ReadFile(filepath) @@ -262,10 +268,12 @@ func TestBasicExport_WithMultipleCollectionsAndUpdate_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, true) require.NoError(t, err) + + ctx = SetContextTxn(ctx, txn) defer txn.Discard(ctx) filepath := t.TempDir() + "/test.json" - err = db.basicExport(ctx, txn, &client.BackupConfig{Filepath: filepath}) + err = db.basicExport(ctx, &client.BackupConfig{Filepath: filepath}) require.NoError(t, err) b, err := os.ReadFile(filepath) @@ -324,6 +332,8 @@ func TestBasicExport_EnsureFileOverwrite_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, true) require.NoError(t, err) + + ctx = SetContextTxn(ctx, txn) defer txn.Discard(ctx) filepath := t.TempDir() + "/test.json" @@ -335,7 +345,7 @@ func TestBasicExport_EnsureFileOverwrite_NoError(t *testing.T) { ) require.NoError(t, err) - err = db.basicExport(ctx, txn, &client.BackupConfig{Filepath: filepath, Collections: []string{"Address"}}) + err = db.basicExport(ctx, &client.BackupConfig{Filepath: filepath, Collections: []string{"Address"}}) require.NoError(t, err) b, err := os.ReadFile(filepath) @@ -370,6 +380,7 @@ func TestBasicImport_WithMultipleCollectionsAndObjects_NoError(t *testing.T) { txn, err := db.NewTxn(ctx, false) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) filepath := t.TempDir() + "/test.json" @@ -380,15 +391,16 @@ func TestBasicImport_WithMultipleCollectionsAndObjects_NoError(t *testing.T) { ) require.NoError(t, err) - err = db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, filepath) require.NoError(t, err) err = txn.Commit(ctx) require.NoError(t, err) txn, err = db.NewTxn(ctx, true) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) - col1, err := db.getCollectionByName(ctx, txn, "Address") + col1, err := db.getCollectionByName(ctx, "Address") require.NoError(t, err) key1, err := client.NewDocIDFromString("bae-8096f2c1-ea4c-5226-8ba5-17fc4b68ac1f") @@ -396,7 +408,7 @@ func TestBasicImport_WithMultipleCollectionsAndObjects_NoError(t *testing.T) { _, err = col1.Get(ctx, acpIdentity.NoIdentity, key1, false) require.NoError(t, err) - col2, err := db.getCollectionByName(ctx, txn, "User") + col2, err := db.getCollectionByName(ctx, "User") require.NoError(t, err) key2, err := client.NewDocIDFromString("bae-b94880d1-e6d2-542f-b9e0-5a369fafd0df") @@ -429,6 +441,7 @@ func TestBasicImport_WithJSONArray_ReturnError(t *testing.T) { txn, err := db.NewTxn(ctx, false) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) filepath := t.TempDir() + "/test.json" @@ -439,7 +452,7 @@ func TestBasicImport_WithJSONArray_ReturnError(t *testing.T) { ) require.NoError(t, err) - err = db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, filepath) require.ErrorIs(t, err, ErrExpectedJSONObject) err = txn.Commit(ctx) require.NoError(t, err) @@ -464,6 +477,7 @@ func TestBasicImport_WithObjectCollection_ReturnError(t *testing.T) { txn, err := db.NewTxn(ctx, false) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) filepath := t.TempDir() + "/test.json" @@ -474,7 +488,7 @@ func TestBasicImport_WithObjectCollection_ReturnError(t *testing.T) { ) require.NoError(t, err) - err = db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, filepath) require.ErrorIs(t, err, ErrExpectedJSONArray) err = txn.Commit(ctx) require.NoError(t, err) @@ -499,6 +513,7 @@ func TestBasicImport_WithInvalidFilepath_ReturnError(t *testing.T) { txn, err := db.NewTxn(ctx, false) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) filepath := t.TempDir() + "/test.json" @@ -510,7 +525,7 @@ func TestBasicImport_WithInvalidFilepath_ReturnError(t *testing.T) { require.NoError(t, err) wrongFilepath := t.TempDir() + "/some/test.json" - err = db.basicImport(ctx, txn, wrongFilepath) + err = db.basicImport(ctx, wrongFilepath) require.ErrorIs(t, err, os.ErrNotExist) err = txn.Commit(ctx) require.NoError(t, err) @@ -535,6 +550,7 @@ func TestBasicImport_WithInvalidCollection_ReturnError(t *testing.T) { txn, err := db.NewTxn(ctx, false) require.NoError(t, err) + ctx = SetContextTxn(ctx, txn) filepath := t.TempDir() + "/test.json" @@ -545,7 +561,7 @@ func TestBasicImport_WithInvalidCollection_ReturnError(t *testing.T) { ) require.NoError(t, err) - err = db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, filepath) require.ErrorIs(t, err, ErrFailedToGetCollection) err = txn.Commit(ctx) require.NoError(t, err) diff --git a/db/collection.go b/db/collection.go index 1afa1c775a..faae1bbda7 100644 --- a/db/collection.go +++ b/db/collection.go @@ -31,7 +31,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db/base" "github.com/sourcenetwork/defradb/db/description" "github.com/sourcenetwork/defradb/db/fetcher" @@ -85,11 +84,11 @@ func (c *collection) newFetcher() fetcher.Fetcher { // Note: Collection.ID is an auto-incrementing value that is generated by the database. func (db *db) createCollection( ctx context.Context, - txn datastore.Txn, def client.CollectionDefinition, ) (client.Collection, error) { schema := def.Schema desc := def.Description + txn := mustGetContextTxn(ctx) if desc.Name.HasValue() { exists, err := description.HasCollectionByName(ctx, txn, desc.Name.Value()) @@ -101,16 +100,16 @@ func (db *db) createCollection( } } - colSeq, err := db.getSequence(ctx, txn, core.CollectionIDSequenceKey{}) + colSeq, err := db.getSequence(ctx, core.CollectionIDSequenceKey{}) if err != nil { return nil, err } - colID, err := colSeq.next(ctx, txn) + colID, err := colSeq.next(ctx) if err != nil { return nil, err } - fieldSeq, err := db.getSequence(ctx, txn, core.NewFieldIDSequenceKey(uint32(colID))) + fieldSeq, err := db.getSequence(ctx, core.NewFieldIDSequenceKey(uint32(colID))) if err != nil { return nil, err } @@ -131,7 +130,7 @@ func (db *db) createCollection( // queries too. fieldID = 0 } else { - fieldID, err = fieldSeq.next(ctx, txn) + fieldID, err = fieldSeq.next(ctx) if err != nil { return nil, err } @@ -154,12 +153,12 @@ func (db *db) createCollection( col := db.newCollection(desc, schema) for _, index := range desc.Indexes { - if _, err := col.createIndex(ctx, txn, index); err != nil { + if _, err := col.createIndex(ctx, index); err != nil { return nil, err } } - return db.getCollectionByID(ctx, txn, desc.ID) + return db.getCollectionByID(ctx, desc.ID) } // validateCollectionDefinitionPolicyDesc validates that the policy definition is valid, beyond syntax. @@ -203,7 +202,6 @@ func (db *db) validateCollectionDefinitionPolicyDesc( // applied. func (db *db) updateSchema( ctx context.Context, - txn datastore.Txn, existingSchemaByName map[string]client.SchemaDescription, proposedDescriptionsByName map[string]client.SchemaDescription, schema client.SchemaDescription, @@ -244,6 +242,7 @@ func (db *db) updateSchema( } } + txn := mustGetContextTxn(ctx) previousVersionID := schema.VersionID schema, err = description.CreateSchemaVersion(ctx, txn, schema) if err != nil { @@ -259,7 +258,7 @@ func (db *db) updateSchema( return err } - colSeq, err := db.getSequence(ctx, txn, core.CollectionIDSequenceKey{}) + colSeq, err := db.getSequence(ctx, core.CollectionIDSequenceKey{}) if err != nil { return err } @@ -289,7 +288,7 @@ func (db *db) updateSchema( existingCol.RootID = col.RootID } - fieldSeq, err := db.getSequence(ctx, txn, core.NewFieldIDSequenceKey(existingCol.RootID)) + fieldSeq, err := db.getSequence(ctx, core.NewFieldIDSequenceKey(existingCol.RootID)) if err != nil { return err } @@ -302,7 +301,7 @@ func (db *db) updateSchema( if ok { fieldID = existingField.ID } else { - nextFieldID, err := fieldSeq.next(ctx, txn) + nextFieldID, err := fieldSeq.next(ctx) if err != nil { return err } @@ -328,12 +327,12 @@ func (db *db) updateSchema( } if !isExistingCol { - colID, err := colSeq.next(ctx, txn) + colID, err := colSeq.next(ctx) if err != nil { return err } - fieldSeq, err := db.getSequence(ctx, txn, core.NewFieldIDSequenceKey(col.RootID)) + fieldSeq, err := db.getSequence(ctx, core.NewFieldIDSequenceKey(col.RootID)) if err != nil { return err } @@ -353,7 +352,7 @@ func (db *db) updateSchema( for _, globalField := range schema.Fields { _, exists := col.GetFieldByName(globalField.Name) if !exists { - fieldID, err := fieldSeq.next(ctx, txn) + fieldID, err := fieldSeq.next(ctx) if err != nil { return err } @@ -385,7 +384,7 @@ func (db *db) updateSchema( if setAsActiveVersion { // activate collection versions using the new schema ID. This call must be made after // all new collection versions have been saved. - err = db.setActiveSchemaVersion(ctx, txn, schema.VersionID) + err = db.setActiveSchemaVersion(ctx, schema.VersionID) if err != nil { return err } @@ -549,14 +548,13 @@ func validateUpdateSchemaFields( func (db *db) patchCollection( ctx context.Context, - txn datastore.Txn, patchString string, ) error { patch, err := jsonpatch.DecodePatch([]byte(patchString)) if err != nil { return err } - + txn := mustGetContextTxn(ctx) cols, err := description.GetCollections(ctx, txn) if err != nil { return err @@ -638,7 +636,7 @@ func (db *db) patchCollection( } } - return db.loadSchema(ctx, txn) + return db.loadSchema(ctx) } var patchCollectionValidators = []func( @@ -917,13 +915,12 @@ oldLoop: // It will return an error if the provided schema version ID does not exist. func (db *db) setActiveSchemaVersion( ctx context.Context, - txn datastore.Txn, schemaVersionID string, ) error { if schemaVersionID == "" { return ErrSchemaVersionIDEmpty } - + txn := mustGetContextTxn(ctx) cols, err := description.GetCollectionsBySchemaVersionID(ctx, txn, schemaVersionID) if err != nil { return err @@ -967,11 +964,11 @@ func (db *db) setActiveSchemaVersion( if len(sources) > 0 { // For now, we assume that each collection can only have a single source. This will likely need // to change later. - activeCol, rootCol, isActiveFound = db.getActiveCollectionDown(ctx, txn, colsByID, sources[0].SourceCollectionID) + activeCol, rootCol, isActiveFound = db.getActiveCollectionDown(ctx, colsByID, sources[0].SourceCollectionID) } if !isActiveFound { // We need to look both down and up for the active version - the most recent is not necessarily the active one. - activeCol, isActiveFound = db.getActiveCollectionUp(ctx, txn, colsBySourceID, rootCol.ID) + activeCol, isActiveFound = db.getActiveCollectionUp(ctx, colsBySourceID, rootCol.ID) } var newName string @@ -1000,12 +997,11 @@ func (db *db) setActiveSchemaVersion( } // Load the schema into the clients (e.g. GQL) - return db.loadSchema(ctx, txn) + return db.loadSchema(ctx) } func (db *db) getActiveCollectionDown( ctx context.Context, - txn datastore.Txn, colsByID map[uint32]client.CollectionDescription, id uint32, ) (client.CollectionDescription, client.CollectionDescription, bool) { @@ -1028,12 +1024,11 @@ func (db *db) getActiveCollectionDown( // For now, we assume that each collection can only have a single source. This will likely need // to change later. - return db.getActiveCollectionDown(ctx, txn, colsByID, sources[0].SourceCollectionID) + return db.getActiveCollectionDown(ctx, colsByID, sources[0].SourceCollectionID) } func (db *db) getActiveCollectionUp( ctx context.Context, - txn datastore.Txn, colsBySourceID map[uint32][]client.CollectionDescription, id uint32, ) (client.CollectionDescription, bool) { @@ -1047,7 +1042,7 @@ func (db *db) getActiveCollectionUp( if col.Name.HasValue() { return col, true } - activeCol, isFound := db.getActiveCollectionUp(ctx, txn, colsBySourceID, col.ID) + activeCol, isFound := db.getActiveCollectionUp(ctx, colsBySourceID, col.ID) if isFound { return activeCol, isFound } @@ -1056,7 +1051,9 @@ func (db *db) getActiveCollectionUp( return client.CollectionDescription{}, false } -func (db *db) getCollectionByID(ctx context.Context, txn datastore.Txn, id uint32) (client.Collection, error) { +func (db *db) getCollectionByID(ctx context.Context, id uint32) (client.Collection, error) { + txn := mustGetContextTxn(ctx) + col, err := description.GetCollectionByID(ctx, txn, id) if err != nil { return nil, err @@ -1069,7 +1066,7 @@ func (db *db) getCollectionByID(ctx context.Context, txn datastore.Txn, id uint3 collection := db.newCollection(col, schema) - err = collection.loadIndexes(ctx, txn) + err = collection.loadIndexes(ctx) if err != nil { return nil, err } @@ -1078,12 +1075,12 @@ func (db *db) getCollectionByID(ctx context.Context, txn datastore.Txn, id uint3 } // getCollectionByName returns an existing collection within the database. -func (db *db) getCollectionByName(ctx context.Context, txn datastore.Txn, name string) (client.Collection, error) { +func (db *db) getCollectionByName(ctx context.Context, name string) (client.Collection, error) { if name == "" { return nil, ErrCollectionNameEmpty } - cols, err := db.getCollections(ctx, txn, client.CollectionFetchOptions{Name: immutable.Some(name)}) + cols, err := db.getCollections(ctx, client.CollectionFetchOptions{Name: immutable.Some(name)}) if err != nil { return nil, err } @@ -1099,11 +1096,11 @@ func (db *db) getCollectionByName(ctx context.Context, txn datastore.Txn, name s // is provided. func (db *db) getCollections( ctx context.Context, - txn datastore.Txn, options client.CollectionFetchOptions, ) ([]client.Collection, error) { - var cols []client.CollectionDescription + txn := mustGetContextTxn(ctx) + var cols []client.CollectionDescription switch { case options.Name.HasValue(): col, err := description.GetCollectionByName(ctx, txn, options.Name.Value()) @@ -1172,7 +1169,7 @@ func (db *db) getCollections( collection := db.newCollection(col, schema) collections = append(collections, collection) - err = collection.loadIndexes(ctx, txn) + err = collection.loadIndexes(ctx) if err != nil { return nil, err } @@ -1182,7 +1179,9 @@ func (db *db) getCollections( } // getAllActiveDefinitions returns all queryable collection/views and any embedded schema used by them. -func (db *db) getAllActiveDefinitions(ctx context.Context, txn datastore.Txn) ([]client.CollectionDefinition, error) { +func (db *db) getAllActiveDefinitions(ctx context.Context) ([]client.CollectionDefinition, error) { + txn := mustGetContextTxn(ctx) + cols, err := description.GetActiveCollections(ctx, txn) if err != nil { return nil, err @@ -1197,7 +1196,7 @@ func (db *db) getAllActiveDefinitions(ctx context.Context, txn datastore.Txn) ([ collection := db.newCollection(col, schema) - err = collection.loadIndexes(ctx, txn) + err = collection.loadIndexes(ctx) if err != nil { return nil, err } @@ -1230,18 +1229,18 @@ func (c *collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], ) (<-chan client.DocIDResult, error) { - ctx, txn, err := ensureContextTxn(ctx, c.db, true) + ctx, _, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - return c.getAllDocIDsChan(ctx, identity, txn) + return c.getAllDocIDsChan(ctx, identity) } func (c *collection) getAllDocIDsChan( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, ) (<-chan client.DocIDResult, error) { + txn := mustGetContextTxn(ctx) prefix := core.PrimaryDataStoreKey{ // empty path for all keys prefix CollectionRootID: c.Description().RootID, } @@ -1353,7 +1352,7 @@ func (c *collection) Create( } defer txn.Discard(ctx) - err = c.create(ctx, identity, txn, doc) + err = c.create(ctx, identity, doc) if err != nil { return err } @@ -1375,7 +1374,7 @@ func (c *collection) CreateMany( defer txn.Discard(ctx) for _, doc := range docs { - err = c.create(ctx, identity, txn, doc) + err = c.create(ctx, identity, doc) if err != nil { return err } @@ -1402,7 +1401,6 @@ func (c *collection) getDocIDAndPrimaryKeyFromDoc( func (c *collection) create( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, doc *client.Document, ) error { docID, primaryKey, err := c.getDocIDAndPrimaryKeyFromDoc(doc) @@ -1411,7 +1409,7 @@ func (c *collection) create( } // check if doc already exists - exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + exists, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil { return err } @@ -1424,6 +1422,7 @@ func (c *collection) create( // write value object marker if we have an empty doc if len(doc.Values()) == 0 { + txn := mustGetContextTxn(ctx) valueKey := c.getDataStoreKeyFromDocID(docID) err = txn.Datastore().Put(ctx, valueKey.ToDS(), []byte{base.ObjectMarker}) if err != nil { @@ -1432,12 +1431,12 @@ func (c *collection) create( } // write data to DB via MerkleClock/CRDT - _, err = c.save(ctx, identity, txn, doc, true) + _, err = c.save(ctx, identity, doc, true) if err != nil { return err } - err = c.indexNewDoc(ctx, txn, doc) + err = c.indexNewDoc(ctx, doc) if err != nil { return err } @@ -1460,7 +1459,7 @@ func (c *collection) Update( defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) - exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + exists, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil { return err } @@ -1471,7 +1470,7 @@ func (c *collection) Update( return NewErrDocumentDeleted(primaryKey.DocID) } - err = c.update(ctx, identity, txn, doc) + err = c.update(ctx, identity, doc) if err != nil { return err } @@ -1487,7 +1486,6 @@ func (c *collection) Update( func (c *collection) update( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, doc *client.Document, ) error { // Stop the update if the correct permissions aren't there. @@ -1504,7 +1502,7 @@ func (c *collection) update( return client.ErrDocumentNotFoundOrNotAuthorized } - _, err = c.save(ctx, identity, txn, doc, false) + _, err = c.save(ctx, identity, doc, false) if err != nil { return err } @@ -1526,7 +1524,7 @@ func (c *collection) Save( // Check if document already exists with primary DS key. primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) - exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + exists, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil { return err } @@ -1536,9 +1534,9 @@ func (c *collection) Save( } if exists { - err = c.update(ctx, identity, txn, doc) + err = c.update(ctx, identity, doc) } else { - err = c.create(ctx, identity, txn, doc) + err = c.create(ctx, identity, doc) } if err != nil { return err @@ -1553,16 +1551,17 @@ func (c *collection) Save( func (c *collection) save( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, doc *client.Document, isCreate bool, ) (cid.Cid, error) { if !isCreate { - err := c.updateIndexedDoc(ctx, txn, doc) + err := c.updateIndexedDoc(ctx, doc) if err != nil { return cid.Undef, err } } + txn := mustGetContextTxn(ctx) + // NOTE: We delay the final Clean() call until we know // the commit on the transaction is successful. If we didn't // wait, and just did it here, then *if* the commit fails down @@ -1608,7 +1607,6 @@ func (c *collection) save( err = c.patchPrimaryDoc( ctx, identity, - txn, c.Name().Value(), relationFieldDescription, primaryKey.DocID, @@ -1626,7 +1624,6 @@ func (c *collection) save( err = c.validateOneToOneLinkDoesntAlreadyExist( ctx, identity, - txn, doc.ID().String(), fieldDescription, val.Value(), @@ -1662,7 +1659,6 @@ func (c *collection) save( headNode, priority, err := c.saveCompositeToMerkleCRDT( ctx, - txn, primaryKey.ToDataStoreKey(), links, client.Active, @@ -1697,7 +1693,6 @@ func (c *collection) save( func (c *collection) validateOneToOneLinkDoesntAlreadyExist( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, docID string, fieldDescription client.FieldDefinition, value any, @@ -1720,7 +1715,7 @@ func (c *collection) validateOneToOneLinkDoesntAlreadyExist( return nil } - otherCol, err := c.db.getCollectionByName(ctx, txn, objFieldDescription.Kind.Underlying()) + otherCol, err := c.db.getCollectionByName(ctx, objFieldDescription.Kind.Underlying()) if err != nil { return err } @@ -1743,7 +1738,7 @@ func (c *collection) validateOneToOneLinkDoesntAlreadyExist( fieldDescription.Name, value, ) - selectionPlan, err := c.makeSelectionPlan(ctx, identity, txn, filter) + selectionPlan, err := c.makeSelectionPlan(ctx, identity, filter) if err != nil { return err } @@ -1808,7 +1803,7 @@ func (c *collection) Delete( primaryKey := c.getPrimaryKeyFromDocID(docID) - err = c.applyDelete(ctx, identity, txn, primaryKey) + err = c.applyDelete(ctx, identity, primaryKey) if err != nil { return false, err } @@ -1828,7 +1823,7 @@ func (c *collection) Exists( defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) - exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + exists, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil && !errors.Is(err, ds.ErrNotFound) { return false, err } @@ -1839,7 +1834,6 @@ func (c *collection) Exists( func (c *collection) exists( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, primaryKey core.PrimaryDataStoreKey, ) (exists bool, isDeleted bool, err error) { canRead, err := c.checkAccessOfDocWithACP( @@ -1854,6 +1848,7 @@ func (c *collection) exists( return false, false, nil } + txn := mustGetContextTxn(ctx) val, err := txn.Datastore().Get(ctx, primaryKey.ToDS()) if err != nil && errors.Is(err, ds.ErrNotFound) { return false, false, nil @@ -1873,11 +1868,11 @@ func (c *collection) exists( // Calling it elsewhere could cause the omission of acp checks. func (c *collection) saveCompositeToMerkleCRDT( ctx context.Context, - txn datastore.Txn, dsKey core.DataStoreKey, links []core.DAGLink, status client.DocumentStatus, ) (ipld.Node, uint64, error) { + txn := mustGetContextTxn(ctx) dsKey = dsKey.WithFieldId(core.COMPOSITE_NAMESPACE) merkleCRDT := merklecrdt.NewMerkleCompositeDAG( txn, diff --git a/db/collection_delete.go b/db/collection_delete.go index 8d5bf3f2bb..a6d12399ce 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -19,7 +19,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/events" "github.com/sourcenetwork/defradb/merkle/clock" ) @@ -61,7 +60,7 @@ func (c *collection) DeleteWithDocID( defer txn.Discard(ctx) dsKey := c.getPrimaryKeyFromDocID(docID) - res, err := c.deleteWithKey(ctx, identity, txn, dsKey) + res, err := c.deleteWithKey(ctx, identity, dsKey) if err != nil { return nil, err } @@ -81,7 +80,7 @@ func (c *collection) DeleteWithDocIDs( } defer txn.Discard(ctx) - res, err := c.deleteWithIDs(ctx, identity, txn, docIDs, client.Deleted) + res, err := c.deleteWithIDs(ctx, identity, docIDs, client.Deleted) if err != nil { return nil, err } @@ -101,7 +100,7 @@ func (c *collection) DeleteWithFilter( } defer txn.Discard(ctx) - res, err := c.deleteWithFilter(ctx, identity, txn, filter, client.Deleted) + res, err := c.deleteWithFilter(ctx, identity, filter, client.Deleted) if err != nil { return nil, err } @@ -112,12 +111,11 @@ func (c *collection) DeleteWithFilter( func (c *collection) deleteWithKey( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, key core.PrimaryDataStoreKey, ) (*client.DeleteResult, error) { // Check the key we have been given to delete with actually has a corresponding // document (i.e. document actually exists in the collection). - err := c.applyDelete(ctx, identity, txn, key) + err := c.applyDelete(ctx, identity, key) if err != nil { return nil, err } @@ -134,7 +132,6 @@ func (c *collection) deleteWithKey( func (c *collection) deleteWithIDs( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, docIDs []client.DocID, _ client.DocumentStatus, ) (*client.DeleteResult, error) { @@ -146,7 +143,7 @@ func (c *collection) deleteWithIDs( primaryKey := c.getPrimaryKeyFromDocID(docID) // Apply the function that will perform the full deletion of this document. - err := c.applyDelete(ctx, identity, txn, primaryKey) + err := c.applyDelete(ctx, identity, primaryKey) if err != nil { return nil, err } @@ -164,12 +161,11 @@ func (c *collection) deleteWithIDs( func (c *collection) deleteWithFilter( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, filter any, _ client.DocumentStatus, ) (*client.DeleteResult, error) { // Make a selection plan that will scan through only the documents with matching filter. - selectionPlan, err := c.makeSelectionPlan(ctx, identity, txn, filter) + selectionPlan, err := c.makeSelectionPlan(ctx, identity, filter) if err != nil { return nil, err } @@ -217,7 +213,7 @@ func (c *collection) deleteWithFilter( } // Delete the document that is associated with this DS key we got from the filter. - err = c.applyDelete(ctx, identity, txn, primaryKey) + err = c.applyDelete(ctx, identity, primaryKey) if err != nil { return nil, err } @@ -234,11 +230,10 @@ func (c *collection) deleteWithFilter( func (c *collection) applyDelete( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, primaryKey core.PrimaryDataStoreKey, ) error { // Must also have read permission to delete, inorder to check if document exists. - found, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + found, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil { return err } @@ -264,8 +259,8 @@ func (c *collection) applyDelete( return client.ErrDocumentNotFoundOrNotAuthorized } + txn := mustGetContextTxn(ctx) dsKey := primaryKey.ToDataStoreKey() - headset := clock.NewHeadSet( txn.Headstore(), dsKey.WithFieldId(core.COMPOSITE_NAMESPACE).ToHeadStoreKey(), @@ -285,7 +280,6 @@ func (c *collection) applyDelete( headNode, priority, err := c.saveCompositeToMerkleCRDT( ctx, - txn, dsKey, dagLinks, client.Deleted, diff --git a/db/collection_get.go b/db/collection_get.go index 8ae0dcae75..968e6ca761 100644 --- a/db/collection_get.go +++ b/db/collection_get.go @@ -17,7 +17,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db/base" "github.com/sourcenetwork/defradb/db/fetcher" ) @@ -36,7 +35,7 @@ func (c *collection) Get( defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) - found, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) + found, isDeleted, err := c.exists(ctx, identity, primaryKey) if err != nil { return nil, err } @@ -44,7 +43,7 @@ func (c *collection) Get( return nil, client.ErrDocumentNotFoundOrNotAuthorized } - doc, err := c.get(ctx, identity, txn, primaryKey, nil, showDeleted) + doc, err := c.get(ctx, identity, primaryKey, nil, showDeleted) if err != nil { return nil, err } @@ -59,11 +58,11 @@ func (c *collection) Get( func (c *collection) get( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, primaryKey core.PrimaryDataStoreKey, fields []client.FieldDefinition, showDeleted bool, ) (*client.Document, error) { + txn := mustGetContextTxn(ctx) // create a new document fetcher df := c.newFetcher() // initialize it with the primary index diff --git a/db/collection_index.go b/db/collection_index.go index 3e33c94709..0c1921dd62 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -33,36 +33,33 @@ import ( // createCollectionIndex creates a new collection index and saves it to the database in its system store. func (db *db) createCollectionIndex( ctx context.Context, - txn datastore.Txn, collectionName string, desc client.IndexDescription, ) (client.IndexDescription, error) { - col, err := db.getCollectionByName(ctx, txn, collectionName) + col, err := db.getCollectionByName(ctx, collectionName) if err != nil { return client.IndexDescription{}, NewErrCanNotReadCollection(collectionName, err) } - ctx = SetContextTxn(ctx, txn) return col.CreateIndex(ctx, desc) } func (db *db) dropCollectionIndex( ctx context.Context, - txn datastore.Txn, collectionName, indexName string, ) error { - col, err := db.getCollectionByName(ctx, txn, collectionName) + col, err := db.getCollectionByName(ctx, collectionName) if err != nil { return NewErrCanNotReadCollection(collectionName, err) } - ctx = SetContextTxn(ctx, txn) return col.DropIndex(ctx, indexName) } // getAllIndexDescriptions returns all the index descriptions in the database. func (db *db) getAllIndexDescriptions( ctx context.Context, - txn datastore.Txn, ) (map[client.CollectionName][]client.IndexDescription, error) { + // callers of this function must set a context transaction + txn := mustGetContextTxn(ctx) prefix := core.NewCollectionIndexKey(immutable.None[uint32](), "") keys, indexDescriptions, err := datastore.DeserializePrefix[client.IndexDescription](ctx, @@ -96,9 +93,10 @@ func (db *db) getAllIndexDescriptions( func (db *db) fetchCollectionIndexDescriptions( ctx context.Context, - txn datastore.Txn, colID uint32, ) ([]client.IndexDescription, error) { + // callers of this function must set a context transaction + txn := mustGetContextTxn(ctx) prefix := core.NewCollectionIndexKey(immutable.Some(colID), "") _, indexDescriptions, err := datastore.DeserializePrefix[client.IndexDescription]( ctx, @@ -118,7 +116,7 @@ func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) e } defer txn.Discard(ctx) - err = c.indexNewDoc(ctx, txn, doc) + err = c.indexNewDoc(ctx, doc) if err != nil { return err } @@ -133,11 +131,11 @@ func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client. } defer txn.Discard(ctx) - err = c.deleteIndexedDoc(ctx, txn, oldDoc) + err = c.deleteIndexedDoc(ctx, oldDoc) if err != nil { return err } - err = c.indexNewDoc(ctx, txn, newDoc) + err = c.indexNewDoc(ctx, newDoc) if err != nil { return err } @@ -152,7 +150,7 @@ func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) e } defer txn.Discard(ctx) - err = c.deleteIndexedDoc(ctx, txn, doc) + err = c.deleteIndexedDoc(ctx, doc) if err != nil { return err } @@ -160,11 +158,13 @@ func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) e return txn.Commit(ctx) } -func (c *collection) indexNewDoc(ctx context.Context, txn datastore.Txn, doc *client.Document) error { - err := c.loadIndexes(ctx, txn) +func (c *collection) indexNewDoc(ctx context.Context, doc *client.Document) error { + err := c.loadIndexes(ctx) if err != nil { return err } + // callers of this function must set a context transaction + txn := mustGetContextTxn(ctx) for _, index := range c.indexes { err = index.Save(ctx, txn, doc) if err != nil { @@ -176,10 +176,9 @@ func (c *collection) indexNewDoc(ctx context.Context, txn datastore.Txn, doc *cl func (c *collection) updateIndexedDoc( ctx context.Context, - txn datastore.Txn, doc *client.Document, ) error { - err := c.loadIndexes(ctx, txn) + err := c.loadIndexes(ctx) if err != nil { return err } @@ -188,7 +187,6 @@ func (c *collection) updateIndexedDoc( oldDoc, err := c.get( ctx, acpIdentity.NoIdentity, - txn, c.getPrimaryKeyFromDocID(doc.ID()), c.Definition().CollectIndexedFields(), false, @@ -196,6 +194,7 @@ func (c *collection) updateIndexedDoc( if err != nil { return err } + txn := mustGetContextTxn(ctx) for _, index := range c.indexes { err = index.Update(ctx, txn, oldDoc, doc) if err != nil { @@ -207,13 +206,13 @@ func (c *collection) updateIndexedDoc( func (c *collection) deleteIndexedDoc( ctx context.Context, - txn datastore.Txn, doc *client.Document, ) error { - err := c.loadIndexes(ctx, txn) + err := c.loadIndexes(ctx) if err != nil { return err } + txn := mustGetContextTxn(ctx) for _, index := range c.indexes { err = index.Delete(ctx, txn, doc) if err != nil { @@ -248,7 +247,7 @@ func (c *collection) CreateIndex( } defer txn.Discard(ctx) - index, err := c.createIndex(ctx, txn, desc) + index, err := c.createIndex(ctx, desc) if err != nil { return client.IndexDescription{}, err } @@ -257,7 +256,6 @@ func (c *collection) CreateIndex( func (c *collection) createIndex( ctx context.Context, - txn datastore.Txn, desc client.IndexDescription, ) (CollectionIndex, error) { // Don't allow creating index on a permissioned collection, until following is implemented. @@ -279,20 +277,19 @@ func (c *collection) createIndex( return nil, err } - indexKey, err := c.generateIndexNameIfNeededAndCreateKey(ctx, txn, &desc) + indexKey, err := c.generateIndexNameIfNeededAndCreateKey(ctx, &desc) if err != nil { return nil, err } colSeq, err := c.db.getSequence( ctx, - txn, core.NewIndexIDSequenceKey(c.ID()), ) if err != nil { return nil, err } - colID, err := colSeq.next(ctx, txn) + colID, err := colSeq.next(ctx) if err != nil { return nil, err } @@ -303,6 +300,7 @@ func (c *collection) createIndex( return nil, err } + txn := mustGetContextTxn(ctx) err = txn.Systemstore().Put(ctx, indexKey.ToDS(), buf) if err != nil { return nil, err @@ -313,7 +311,7 @@ func (c *collection) createIndex( } c.def.Description.Indexes = append(c.def.Description.Indexes, colIndex.Description()) c.indexes = append(c.indexes, colIndex) - err = c.indexExistingDocs(ctx, txn, colIndex) + err = c.indexExistingDocs(ctx, colIndex) if err != nil { removeErr := colIndex.RemoveAll(ctx, txn) return nil, errors.Join(err, removeErr) @@ -323,10 +321,10 @@ func (c *collection) createIndex( func (c *collection) iterateAllDocs( ctx context.Context, - txn datastore.Txn, fields []client.FieldDefinition, exec func(doc *client.Document) error, ) error { + txn := mustGetContextTxn(ctx) df := c.newFetcher() err := df.Init( ctx, @@ -376,7 +374,6 @@ func (c *collection) iterateAllDocs( func (c *collection) indexExistingDocs( ctx context.Context, - txn datastore.Txn, index CollectionIndex, ) error { fields := make([]client.FieldDefinition, 0, 1) @@ -386,8 +383,8 @@ func (c *collection) indexExistingDocs( fields = append(fields, colField) } } - - return c.iterateAllDocs(ctx, txn, fields, func(doc *client.Document) error { + txn := mustGetContextTxn(ctx) + return c.iterateAllDocs(ctx, fields, func(doc *client.Document) error { return index.Save(ctx, txn, doc) }) } @@ -404,18 +401,19 @@ func (c *collection) DropIndex(ctx context.Context, indexName string) error { } defer txn.Discard(ctx) - err = c.dropIndex(ctx, txn, indexName) + err = c.dropIndex(ctx, indexName) if err != nil { return err } return txn.Commit(ctx) } -func (c *collection) dropIndex(ctx context.Context, txn datastore.Txn, indexName string) error { - err := c.loadIndexes(ctx, txn) +func (c *collection) dropIndex(ctx context.Context, indexName string) error { + err := c.loadIndexes(ctx) if err != nil { return err } + txn := mustGetContextTxn(ctx) var didFind bool for i := range c.indexes { @@ -448,7 +446,9 @@ func (c *collection) dropIndex(ctx context.Context, txn datastore.Txn, indexName return nil } -func (c *collection) dropAllIndexes(ctx context.Context, txn datastore.Txn) error { +func (c *collection) dropAllIndexes(ctx context.Context) error { + // callers of this function must set a context transaction + txn := mustGetContextTxn(ctx) prefix := core.NewCollectionIndexKey(immutable.Some(c.ID()), "") keys, err := datastore.FetchKeysForPrefix(ctx, prefix.ToString(), txn.Systemstore()) @@ -466,8 +466,8 @@ func (c *collection) dropAllIndexes(ctx context.Context, txn datastore.Txn) erro return err } -func (c *collection) loadIndexes(ctx context.Context, txn datastore.Txn) error { - indexDescriptions, err := c.db.fetchCollectionIndexDescriptions(ctx, txn, c.ID()) +func (c *collection) loadIndexes(ctx context.Context) error { + indexDescriptions, err := c.db.fetchCollectionIndexDescriptions(ctx, c.ID()) if err != nil { return err } @@ -492,7 +492,7 @@ func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, } defer txn.Discard(ctx) - err = c.loadIndexes(ctx, txn) + err = c.loadIndexes(ctx) if err != nil { return nil, err } @@ -520,9 +520,11 @@ func (c *collection) checkExistingFields( func (c *collection) generateIndexNameIfNeededAndCreateKey( ctx context.Context, - txn datastore.Txn, desc *client.IndexDescription, ) (core.CollectionIndexKey, error) { + // callers of this function must set a context transaction + txn := mustGetContextTxn(ctx) + var indexKey core.CollectionIndexKey if desc.Name == "" { nameIncrement := 1 diff --git a/db/collection_update.go b/db/collection_update.go index e9ab2e7fa1..9a8e2bc552 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/planner" ) @@ -63,7 +62,7 @@ func (c *collection) UpdateWithFilter( } defer txn.Discard(ctx) - res, err := c.updateWithFilter(ctx, identity, txn, filter, updater) + res, err := c.updateWithFilter(ctx, identity, filter, updater) if err != nil { return nil, err } @@ -85,7 +84,7 @@ func (c *collection) UpdateWithDocID( } defer txn.Discard(ctx) - res, err := c.updateWithDocID(ctx, identity, txn, docID, updater) + res, err := c.updateWithDocID(ctx, identity, docID, updater) if err != nil { return nil, err } @@ -108,7 +107,7 @@ func (c *collection) UpdateWithDocIDs( } defer txn.Discard(ctx) - res, err := c.updateWithIDs(ctx, identity, txn, docIDs, updater) + res, err := c.updateWithIDs(ctx, identity, docIDs, updater) if err != nil { return nil, err } @@ -119,7 +118,6 @@ func (c *collection) UpdateWithDocIDs( func (c *collection) updateWithDocID( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, docID client.DocID, updater string, ) (*client.UpdateResult, error) { @@ -149,7 +147,7 @@ func (c *collection) updateWithDocID( return nil, err } - err = c.update(ctx, identity, txn, doc) + err = c.update(ctx, identity, doc) if err != nil { return nil, err } @@ -164,7 +162,6 @@ func (c *collection) updateWithDocID( func (c *collection) updateWithIDs( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, docIDs []client.DocID, updater string, ) (*client.UpdateResult, error) { @@ -198,7 +195,7 @@ func (c *collection) updateWithIDs( return nil, err } - err = c.update(ctx, identity, txn, doc) + err = c.update(ctx, identity, doc) if err != nil { return nil, err } @@ -212,7 +209,6 @@ func (c *collection) updateWithIDs( func (c *collection) updateWithFilter( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, filter any, updater string, ) (*client.UpdateResult, error) { @@ -233,7 +229,7 @@ func (c *collection) updateWithFilter( } // Make a selection plan that will scan through only the documents with matching filter. - selectionPlan, err := c.makeSelectionPlan(ctx, identity, txn, filter) + selectionPlan, err := c.makeSelectionPlan(ctx, identity, filter) if err != nil { return nil, err } @@ -287,7 +283,7 @@ func (c *collection) updateWithFilter( } } - err = c.update(ctx, identity, txn, doc) + err = c.update(ctx, identity, doc) if err != nil { return nil, err } @@ -321,7 +317,6 @@ func (c *collection) isSecondaryIDField(fieldDesc client.FieldDefinition) (clien func (c *collection) patchPrimaryDoc( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, secondaryCollectionName string, relationFieldDescription client.FieldDefinition, docID string, @@ -332,7 +327,7 @@ func (c *collection) patchPrimaryDoc( return err } - primaryCol, err := c.db.getCollectionByName(ctx, txn, relationFieldDescription.Kind.Underlying()) + primaryCol, err := c.db.getCollectionByName(ctx, relationFieldDescription.Kind.Underlying()) if err != nil { return err } @@ -373,7 +368,6 @@ func (c *collection) patchPrimaryDoc( err = pc.validateOneToOneLinkDoesntAlreadyExist( ctx, identity, - txn, primaryDocID.String(), primaryIDField, docID, @@ -411,7 +405,6 @@ func (c *collection) patchPrimaryDoc( func (c *collection) makeSelectionPlan( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, filter any, ) (planner.RequestPlan, error) { var f immutable.Option[request.Filter] @@ -437,6 +430,7 @@ func (c *collection) makeSelectionPlan( return nil, err } + txn := mustGetContextTxn(ctx) planner := planner.New( ctx, identity, diff --git a/db/context.go b/db/context.go index d39472ea5a..8c8b941284 100644 --- a/db/context.go +++ b/db/context.go @@ -53,6 +53,11 @@ func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (con return SetContextTxn(ctx, txn), txn, nil } +// mustGetContextTxn returns the transaction from the context or panics. +func mustGetContextTxn(ctx context.Context) datastore.Txn { + return ctx.Value(txnContextKey{}).(datastore.Txn) +} + // TryGetContextTxn returns a transaction and a bool indicating if the // txn was retrieved from the given context. func TryGetContextTxn(ctx context.Context) (datastore.Txn, bool) { diff --git a/db/db.go b/db/db.go index e7a6fa8d09..327f8e9c9e 100644 --- a/db/db.go +++ b/db/db.go @@ -181,7 +181,7 @@ func (db *db) initialize(ctx context.Context) error { db.glock.Lock() defer db.glock.Unlock() - txn, err := db.NewTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, db, false) if err != nil { return err } @@ -202,7 +202,7 @@ func (db *db) initialize(ctx context.Context) error { // if we're loading an existing database, just load the schema // and migrations and finish initialization if exists { - err = db.loadSchema(ctx, txn) + err = db.loadSchema(ctx) if err != nil { return err } @@ -220,7 +220,7 @@ func (db *db) initialize(ctx context.Context) error { // init meta data // collection sequence - _, err = db.getSequence(ctx, txn, core.CollectionIDSequenceKey{}) + _, err = db.getSequence(ctx, core.CollectionIDSequenceKey{}) if err != nil { return err } diff --git a/db/index_test.go b/db/index_test.go index aeda2bdd6d..5409b6c20e 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -219,7 +219,8 @@ func (f *indexTestFixture) createUserCollectionIndexOnAge() client.IndexDescript } func (f *indexTestFixture) dropIndex(colName, indexName string) error { - return f.db.dropCollectionIndex(f.ctx, f.txn, colName, indexName) + ctx := SetContextTxn(f.ctx, f.txn) + return f.db.dropCollectionIndex(ctx, colName, indexName) } func (f *indexTestFixture) countIndexPrefixes(indexName string) int { @@ -255,7 +256,8 @@ func (f *indexTestFixture) createCollectionIndexFor( collectionName string, desc client.IndexDescription, ) (client.IndexDescription, error) { - index, err := f.db.createCollectionIndex(f.ctx, f.txn, collectionName, desc) + ctx := SetContextTxn(f.ctx, f.txn) + index, err := f.db.createCollectionIndex(ctx, collectionName, desc) if err == nil { f.commitTxn() } @@ -263,11 +265,13 @@ func (f *indexTestFixture) createCollectionIndexFor( } func (f *indexTestFixture) getAllIndexes() (map[client.CollectionName][]client.IndexDescription, error) { - return f.db.getAllIndexDescriptions(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) + return f.db.getAllIndexDescriptions(ctx) } func (f *indexTestFixture) getCollectionIndexes(colID uint32) ([]client.IndexDescription, error) { - return f.db.fetchCollectionIndexDescriptions(f.ctx, f.txn, colID) + ctx := SetContextTxn(f.ctx, f.txn) + return f.db.fetchCollectionIndexDescriptions(ctx, colID) } func TestCreateIndex_IfFieldsIsEmpty_ReturnError(t *testing.T) { @@ -1172,7 +1176,8 @@ func TestDropAllIndexes_ShouldDeleteAllIndexes(t *testing.T) { assert.Equal(t, 2, f.countIndexPrefixes("")) - err = f.users.(*collection).dropAllIndexes(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) + err = f.users.(*collection).dropAllIndexes(ctx) assert.NoError(t, err) assert.Equal(t, 0, f.countIndexPrefixes("")) @@ -1184,7 +1189,8 @@ func TestDropAllIndexes_IfStorageFails_ReturnError(t *testing.T) { f.createUserCollectionIndexOnName() f.db.Close() - err := f.users.(*collection).dropAllIndexes(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) + err := f.users.(*collection).dropAllIndexes(ctx) assert.Error(t, err) } @@ -1240,7 +1246,8 @@ func TestDropAllIndexes_IfSystemStorageFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - err := f.users.(*collection).dropAllIndexes(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) + err := f.users.(*collection).dropAllIndexes(ctx) assert.ErrorIs(t, err, testErr, testCase.Name) } } @@ -1261,7 +1268,8 @@ func TestDropAllIndexes_ShouldCloseQueryIterator(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - _ = f.users.(*collection).dropAllIndexes(f.ctx, f.txn) + ctx := SetContextTxn(f.ctx, f.txn) + _ = f.users.(*collection).dropAllIndexes(ctx) } func TestNewCollectionIndex_IfDescriptionHasNoFields_ReturnError(t *testing.T) { diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index 70604fdc1f..99a4c9ee56 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -131,7 +131,8 @@ func (b *indexKeyBuilder) Build() core.IndexDataStoreKey { return key } - cols, err := b.f.db.getCollections(b.f.ctx, b.f.txn, client.CollectionFetchOptions{}) + ctx := SetContextTxn(b.f.ctx, b.f.txn) + cols, err := b.f.db.getCollections(ctx, client.CollectionFetchOptions{}) require.NoError(b.f.t, err) var collection client.Collection for _, col := range cols { @@ -793,7 +794,8 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) require.NoError(t, err) // retrieve the collection without index cached - usersCol, err := f.db.getCollectionByName(f.ctx, f.txn, usersColName) + ctx := SetContextTxn(f.ctx, f.txn) + usersCol, err := f.db.getCollectionByName(ctx, usersColName) require.NoError(t, err) testErr := errors.New("test error") @@ -809,7 +811,7 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) usersCol.(*collection).fetcherFactory = func() fetcher.Fetcher { return fetcherMocks.NewStubbedFetcher(t) } - ctx := SetContextTxn(f.ctx, mockedTxn) + ctx = SetContextTxn(f.ctx, mockedTxn) err = usersCol.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } diff --git a/db/lens.go b/db/lens.go index d5240dad83..f21d084f88 100644 --- a/db/lens.go +++ b/db/lens.go @@ -18,12 +18,13 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db/description" "github.com/sourcenetwork/defradb/errors" ) -func (db *db) setMigration(ctx context.Context, txn datastore.Txn, cfg client.LensConfig) error { +func (db *db) setMigration(ctx context.Context, cfg client.LensConfig) error { + txn := mustGetContextTxn(ctx) + dstCols, err := description.GetCollectionsBySchemaVersionID(ctx, txn, cfg.DestinationSchemaVersionID) if err != nil { return err @@ -34,7 +35,7 @@ func (db *db) setMigration(ctx context.Context, txn datastore.Txn, cfg client.Le return err } - colSeq, err := db.getSequence(ctx, txn, core.CollectionIDSequenceKey{}) + colSeq, err := db.getSequence(ctx, core.CollectionIDSequenceKey{}) if err != nil { return err } @@ -42,7 +43,7 @@ func (db *db) setMigration(ctx context.Context, txn datastore.Txn, cfg client.Le if len(sourceCols) == 0 { // If no collections are found with the given [SourceSchemaVersionID], this migration must be from // a collection/schema version that does not yet exist locally. We must now create it. - colID, err := colSeq.next(ctx, txn) + colID, err := colSeq.next(ctx) if err != nil { return err } @@ -86,7 +87,7 @@ func (db *db) setMigration(ctx context.Context, txn datastore.Txn, cfg client.Le if !isDstCollectionFound { // If the destination collection was not found, we must create it. This can happen when setting a migration // to a schema version that does not yet exist locally. - colID, err := colSeq.next(ctx, txn) + colID, err := colSeq.next(ctx) if err != nil { return err } diff --git a/db/request.go b/db/request.go index 69b300f482..099f8852ed 100644 --- a/db/request.go +++ b/db/request.go @@ -16,7 +16,6 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/planner" ) @@ -25,7 +24,6 @@ func (db *db) execRequest( ctx context.Context, identity immutable.Option[string], request string, - txn datastore.Txn, ) *client.RequestResult { res := &client.RequestResult{} ast, err := db.parser.BuildRequestAST(request) @@ -55,6 +53,7 @@ func (db *db) execRequest( return res } + txn := mustGetContextTxn(ctx) planner := planner.New( ctx, identity, diff --git a/db/schema.go b/db/schema.go index 5b10df9906..6d52a92aee 100644 --- a/db/schema.go +++ b/db/schema.go @@ -23,7 +23,6 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db/description" ) @@ -37,7 +36,6 @@ const ( // and creates the necessary collections, request types, etc. func (db *db) addSchema( ctx context.Context, - txn datastore.Txn, schemaString string, ) ([]client.CollectionDescription, error) { newDefinitions, err := db.parser.ParseSDL(ctx, schemaString) @@ -53,14 +51,14 @@ func (db *db) addSchema( return nil, err } - col, err := db.createCollection(ctx, txn, definition) + col, err := db.createCollection(ctx, definition) if err != nil { return nil, err } returnDescriptions[i] = col.Description() } - err = db.loadSchema(ctx, txn) + err = db.loadSchema(ctx) if err != nil { return nil, err } @@ -68,8 +66,10 @@ func (db *db) addSchema( return returnDescriptions, nil } -func (db *db) loadSchema(ctx context.Context, txn datastore.Txn) error { - definitions, err := db.getAllActiveDefinitions(ctx, txn) +func (db *db) loadSchema(ctx context.Context) error { + txn := mustGetContextTxn(ctx) + + definitions, err := db.getAllActiveDefinitions(ctx) if err != nil { return err } @@ -90,11 +90,12 @@ func (db *db) loadSchema(ctx context.Context, txn datastore.Txn) error { // will be applied. func (db *db) patchSchema( ctx context.Context, - txn datastore.Txn, patchString string, migration immutable.Option[model.Lens], setAsDefaultVersion bool, ) error { + txn := mustGetContextTxn(ctx) + patch, err := jsonpatch.DecodePatch([]byte(patchString)) if err != nil { return err @@ -137,7 +138,6 @@ func (db *db) patchSchema( for _, schema := range newSchemaByName { err := db.updateSchema( ctx, - txn, existingSchemaByName, newSchemaByName, schema, @@ -149,7 +149,7 @@ func (db *db) patchSchema( } } - return db.loadSchema(ctx, txn) + return db.loadSchema(ctx) } // substituteSchemaPatch handles any substitution of values that may be required before @@ -246,10 +246,9 @@ func substituteSchemaPatch( func (db *db) getSchemaByVersionID( ctx context.Context, - txn datastore.Txn, versionID string, ) (client.SchemaDescription, error) { - schemas, err := db.getSchemas(ctx, txn, client.SchemaFetchOptions{ID: immutable.Some(versionID)}) + schemas, err := db.getSchemas(ctx, client.SchemaFetchOptions{ID: immutable.Some(versionID)}) if err != nil { return client.SchemaDescription{}, err } @@ -260,9 +259,10 @@ func (db *db) getSchemaByVersionID( func (db *db) getSchemas( ctx context.Context, - txn datastore.Txn, options client.SchemaFetchOptions, ) ([]client.SchemaDescription, error) { + txn := mustGetContextTxn(ctx) + schemas := []client.SchemaDescription{} switch { diff --git a/db/sequence.go b/db/sequence.go index 3c510ec78c..f39bdcfb65 100644 --- a/db/sequence.go +++ b/db/sequence.go @@ -17,7 +17,6 @@ import ( ds "github.com/ipfs/go-datastore" "github.com/sourcenetwork/defradb/core" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/errors" ) @@ -26,15 +25,15 @@ type sequence struct { val uint64 } -func (db *db) getSequence(ctx context.Context, txn datastore.Txn, key core.Key) (*sequence, error) { +func (db *db) getSequence(ctx context.Context, key core.Key) (*sequence, error) { seq := &sequence{ key: key, val: uint64(0), } - _, err := seq.get(ctx, txn) + _, err := seq.get(ctx) if errors.Is(err, ds.ErrNotFound) { - err = seq.update(ctx, txn) + err = seq.update(ctx) if err != nil { return nil, err } @@ -45,7 +44,9 @@ func (db *db) getSequence(ctx context.Context, txn datastore.Txn, key core.Key) return seq, nil } -func (seq *sequence) get(ctx context.Context, txn datastore.Txn) (uint64, error) { +func (seq *sequence) get(ctx context.Context) (uint64, error) { + txn := mustGetContextTxn(ctx) + val, err := txn.Systemstore().Get(ctx, seq.key.ToDS()) if err != nil { return 0, err @@ -55,7 +56,9 @@ func (seq *sequence) get(ctx context.Context, txn datastore.Txn) (uint64, error) return seq.val, nil } -func (seq *sequence) update(ctx context.Context, txn datastore.Txn) error { +func (seq *sequence) update(ctx context.Context) error { + txn := mustGetContextTxn(ctx) + var buf [8]byte binary.BigEndian.PutUint64(buf[:], seq.val) if err := txn.Systemstore().Put(ctx, seq.key.ToDS(), buf[:]); err != nil { @@ -65,12 +68,12 @@ func (seq *sequence) update(ctx context.Context, txn datastore.Txn) error { return nil } -func (seq *sequence) next(ctx context.Context, txn datastore.Txn) (uint64, error) { - _, err := seq.get(ctx, txn) +func (seq *sequence) next(ctx context.Context) (uint64, error) { + _, err := seq.get(ctx) if err != nil { return 0, err } seq.val++ - return seq.val, seq.update(ctx, txn) + return seq.val, seq.update(ctx) } diff --git a/db/store.go b/db/store.go index aff11f851d..5a3f3f7ad6 100644 --- a/db/store.go +++ b/db/store.go @@ -34,7 +34,7 @@ func (db *db) ExecRequest( } defer txn.Discard(ctx) - res := db.execRequest(ctx, identity, request, txn) + res := db.execRequest(ctx, identity, request) if len(res.GQL.Errors) > 0 { return res } @@ -55,7 +55,7 @@ func (db *db) GetCollectionByName(ctx context.Context, name string) (client.Coll } defer txn.Discard(ctx) - return db.getCollectionByName(ctx, txn, name) + return db.getCollectionByName(ctx, name) } // GetCollections gets all the currently defined collections. @@ -69,7 +69,7 @@ func (db *db) GetCollections( } defer txn.Discard(ctx) - return db.getCollections(ctx, txn, options) + return db.getCollections(ctx, options) } // GetSchemaByVersionID returns the schema description for the schema version of the @@ -83,7 +83,7 @@ func (db *db) GetSchemaByVersionID(ctx context.Context, versionID string) (clien } defer txn.Discard(ctx) - return db.getSchemaByVersionID(ctx, txn, versionID) + return db.getSchemaByVersionID(ctx, versionID) } // GetSchemas returns all schema versions that currently exist within @@ -98,7 +98,7 @@ func (db *db) GetSchemas( } defer txn.Discard(ctx) - return db.getSchemas(ctx, txn, options) + return db.getSchemas(ctx, options) } // GetAllIndexes gets all the indexes in the database. @@ -111,7 +111,7 @@ func (db *db) GetAllIndexes( } defer txn.Discard(ctx) - return db.getAllIndexDescriptions(ctx, txn) + return db.getAllIndexDescriptions(ctx) } // AddSchema takes the provided GQL schema in SDL format, and applies it to the database, @@ -126,7 +126,7 @@ func (db *db) AddSchema(ctx context.Context, schemaString string) ([]client.Coll } defer txn.Discard(ctx) - cols, err := db.addSchema(ctx, txn, schemaString) + cols, err := db.addSchema(ctx, schemaString) if err != nil { return nil, err } @@ -160,7 +160,7 @@ func (db *db) PatchSchema( } defer txn.Discard(ctx) - err = db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) + err = db.patchSchema(ctx, patchString, migration, setAsDefaultVersion) if err != nil { return err } @@ -178,7 +178,7 @@ func (db *db) PatchCollection( } defer txn.Discard(ctx) - err = db.patchCollection(ctx, txn, patchString) + err = db.patchCollection(ctx, patchString) if err != nil { return err } @@ -193,7 +193,7 @@ func (db *db) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string } defer txn.Discard(ctx) - err = db.setActiveSchemaVersion(ctx, txn, schemaVersionID) + err = db.setActiveSchemaVersion(ctx, schemaVersionID) if err != nil { return err } @@ -208,7 +208,7 @@ func (db *db) SetMigration(ctx context.Context, cfg client.LensConfig) error { } defer txn.Discard(ctx) - err = db.setMigration(ctx, txn, cfg) + err = db.setMigration(ctx, cfg) if err != nil { return err } @@ -228,7 +228,7 @@ func (db *db) AddView( } defer txn.Discard(ctx) - defs, err := db.addView(ctx, txn, query, sdl, transform) + defs, err := db.addView(ctx, query, sdl, transform) if err != nil { return nil, err } @@ -250,7 +250,7 @@ func (db *db) BasicImport(ctx context.Context, filepath string) error { } defer txn.Discard(ctx) - err = db.basicImport(ctx, txn, filepath) + err = db.basicImport(ctx, filepath) if err != nil { return err } @@ -266,7 +266,7 @@ func (db *db) BasicExport(ctx context.Context, config *client.BackupConfig) erro } defer txn.Discard(ctx) - err = db.basicExport(ctx, txn, config) + err = db.basicExport(ctx, config) if err != nil { return err } diff --git a/db/subscriptions.go b/db/subscriptions.go index e649769c18..a8c8f5bb42 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -17,7 +17,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/events" "github.com/sourcenetwork/defradb/planner" ) @@ -63,7 +62,7 @@ func (db *db) handleSubscription( } ctx := SetContextTxn(ctx, txn) - db.handleEvent(ctx, identity, txn, pub, evt, r) + db.handleEvent(ctx, identity, pub, evt, r) txn.Discard(ctx) } } @@ -71,11 +70,11 @@ func (db *db) handleSubscription( func (db *db) handleEvent( ctx context.Context, identity immutable.Option[string], - txn datastore.Txn, pub *events.Publisher[events.Update], evt events.Update, r *request.ObjectSubscription, ) { + txn := mustGetContextTxn(ctx) p := planner.New( ctx, identity, diff --git a/db/view.go b/db/view.go index ea57f94541..5a778efd53 100644 --- a/db/view.go +++ b/db/view.go @@ -20,17 +20,17 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/db/description" ) func (db *db) addView( ctx context.Context, - txn datastore.Txn, inputQuery string, sdl string, transform immutable.Option[model.Lens], ) ([]client.CollectionDefinition, error) { + txn := mustGetContextTxn(ctx) + // Wrap the given query as part of the GQL query object - this simplifies the syntax for users // and ensures that we can't be given mutations. In the future this line should disappear along // with the all calls to the parser appart from `ParseSDL` when we implement the DQL stuff. @@ -80,7 +80,7 @@ func (db *db) addView( Schema: schema, } } else { - col, err := db.createCollection(ctx, txn, definition) + col, err := db.createCollection(ctx, definition) if err != nil { return nil, err } @@ -97,7 +97,7 @@ func (db *db) addView( } } - err = db.loadSchema(ctx, txn) + err = db.loadSchema(ctx) if err != nil { return nil, err } From d32c1419cd4b253489477a1c41965761c3d3999c Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Mon, 15 Apr 2024 10:02:16 -0700 Subject: [PATCH 2/2] update db context docs --- db/context.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/db/context.go b/db/context.go index 8c8b941284..f235475d24 100644 --- a/db/context.go +++ b/db/context.go @@ -41,6 +41,9 @@ type transactionDB interface { // // If a transactions exists on the context it will be made explicit, // otherwise a new implicit transaction will be created. +// +// The returned context will contain the transaction +// along with the copied values from the input context. func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, datastore.Txn, error) { txn, ok := TryGetContextTxn(ctx) if ok { @@ -54,6 +57,9 @@ func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (con } // mustGetContextTxn returns the transaction from the context or panics. +// +// This should only be called from private functions within the db package +// where we ensure an implicit or explicit transaction always exists. func mustGetContextTxn(ctx context.Context) datastore.Txn { return ctx.Value(txnContextKey{}).(datastore.Txn) }