Skip to content

Commit

Permalink
Merge pull request #986 from ellemouton/sql16Sessions8
Browse files Browse the repository at this point in the history
[sql-16] sessions: update Store interface methods to take a context
  • Loading branch information
ellemouton authored Feb 28, 2025
2 parents bc4439f + 01c19c5 commit 66b0f15
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 120 deletions.
4 changes: 2 additions & 2 deletions firewall/privacy_mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
uri string, req proto.Message, sessionID session.ID) (proto.Message,
error) {

session, err := p.sessionDB.GetSessionByID(sessionID)
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string,
resp proto.Message, sessionID session.ID) (proto.Message, error) {

session, err := p.sessionDB.GetSessionByID(sessionID)
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion firewall/rule_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
return nil, err
}

session, err := r.sessionDB.GetSessionByID(sessionID)
session, err := r.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions firewalldb/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
// pass the filterFn requirements.
//
// TODO: update to allow for pagination.
func (db *DB) ListGroupActions(groupID session.ID,
func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID,
filterFn ListActionsFilterFn) ([]*Action, error) {

if filterFn == nil {
Expand All @@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(groupID session.ID,
}
}

sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -629,11 +629,11 @@ type groupActionsReadDB struct {
var _ ActionsDB = (*groupActionsReadDB)(nil)

// ListActions will return all the Actions for a particular group.
func (s *groupActionsReadDB) ListActions(_ context.Context) ([]*RuleAction,
func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction,
error) {

sessionActions, err := s.db.ListGroupActions(
s.groupID, func(a *Action, _ bool) (bool, bool) {
ctx, s.groupID, func(a *Action, _ bool) (bool, bool) {
return a.State == ActionStateDone, true
},
)
Expand All @@ -660,11 +660,11 @@ var _ ActionsDB = (*groupFeatureActionsReadDB)(nil)

// ListActions will return all the Actions for a particular group that were
// executed by a particular feature.
func (a *groupFeatureActionsReadDB) ListActions(_ context.Context) (
func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) (
[]*RuleAction, error) {

featureActions, err := a.db.ListGroupActions(
a.groupID, func(action *Action, _ bool) (bool, bool) {
ctx, a.groupID, func(action *Action, _ bool) (bool, bool) {
return action.State == ActionStateDone &&
action.FeatureName == a.featureName, true
},
Expand Down
10 changes: 7 additions & 3 deletions firewalldb/actions_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package firewalldb

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -342,6 +343,9 @@ func TestListActions(t *testing.T) {
// TestListGroupActions tests that the ListGroupActions correctly returns all
// actions in a particular session group.
func TestListGroupActions(t *testing.T) {
t.Parallel()
ctx := context.Background()

group1 := intToSessionID(0)

// Link session 1 and session 2 to group 1.
Expand All @@ -356,7 +360,7 @@ func TestListGroupActions(t *testing.T) {
})

// There should not be any actions in group 1 yet.
al, err := db.ListGroupActions(group1, nil)
al, err := db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Empty(t, al)

Expand All @@ -365,7 +369,7 @@ func TestListGroupActions(t *testing.T) {
require.NoError(t, err)

// There should now be one action in the group.
al, err = db.ListGroupActions(group1, nil)
al, err = db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Len(t, al, 1)
require.Equal(t, sessionID1, al[0].SessionID)
Expand All @@ -375,7 +379,7 @@ func TestListGroupActions(t *testing.T) {
require.NoError(t, err)

// There should now be actions in the group.
al, err = db.ListGroupActions(group1, nil)
al, err = db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Len(t, al, 2)
require.Equal(t, sessionID1, al[0].SessionID)
Expand Down
8 changes: 6 additions & 2 deletions firewalldb/interface.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package firewalldb

import "github.com/lightninglabs/lightning-terminal/session"
import (
"context"

"github.com/lightninglabs/lightning-terminal/session"
)

// SessionDB is an interface that abstracts the database operations needed for
// the privacy mapper to function.
type SessionDB interface {
session.IDToGroupIndex

// GetSessionByID returns the session for a specific id.
GetSessionByID(session.ID) (*session.Session, error)
GetSessionByID(context.Context, session.ID) (*session.Session, error)
}
13 changes: 9 additions & 4 deletions firewalldb/mock.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package firewalldb

import (
"context"
"fmt"

"github.com/lightninglabs/lightning-terminal/session"
Expand Down Expand Up @@ -33,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) {
}

// GetGroupID returns the group ID for the given session ID.
func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) (
session.ID, error) {

id, ok := m.sessionToGroupID[sessionID]
if !ok {
return session.ID{}, fmt.Errorf("no group ID found for " +
Expand All @@ -44,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
}

// GetSessionIDs returns the set of session IDs that are in the group
func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) {
func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
[]session.ID, error) {

ids, ok := m.groupToSessionIDs[groupID]
if !ok {
return nil, fmt.Errorf("no session IDs found for group ID")
Expand All @@ -54,8 +59,8 @@ func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error)
}

// GetSessionByID returns the session for a specific id.
func (m *mockSessionDB) GetSessionByID(sessionID session.ID) (*session.Session,
error) {
func (m *mockSessionDB) GetSessionByID(_ context.Context,
sessionID session.ID) (*session.Session, error) {

s, ok := m.sessionToGroupID[sessionID]
if !ok {
Expand Down
27 changes: 15 additions & 12 deletions session/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package session

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -260,11 +261,11 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option {
// IDToGroupIndex defines an interface for the session ID to group ID index.
type IDToGroupIndex interface {
// GetGroupID will return the group ID for the given session ID.
GetGroupID(sessionID ID) (ID, error)
GetGroupID(ctx context.Context, sessionID ID) (ID, error)

// GetSessionIDs will return the set of session IDs that are in the
// group with the given ID.
GetSessionIDs(groupID ID) ([]ID, error)
GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error)
}

// Store is the interface a persistent storage must implement for storing and
Expand All @@ -273,37 +274,39 @@ type Store interface {
// NewSession creates a new session with the given user-defined
// parameters. The session will remain in the StateReserved state until
// ShiftState is called to update the state.
NewSession(label string, typ Type, expiry time.Time, serverAddr string,
opts ...Option) (*Session, error)
NewSession(ctx context.Context, label string, typ Type,
expiry time.Time, serverAddr string, opts ...Option) (*Session,
error)

// GetSession fetches the session with the given key.
GetSession(key *btcec.PublicKey) (*Session, error)
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)

// ListAllSessions returns all sessions currently known to the store.
ListAllSessions() ([]*Session, error)
ListAllSessions(ctx context.Context) ([]*Session, error)

// ListSessionsByType returns all sessions of the given type.
ListSessionsByType(t Type) ([]*Session, error)
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)

// ListSessionsByState returns all sessions currently known to the store
// that are in the given states.
ListSessionsByState(...State) ([]*Session, error)
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
error)

// UpdateSessionRemotePubKey can be used to add the given remote pub key
// to the session with the given local pub key.
UpdateSessionRemotePubKey(localPubKey,
UpdateSessionRemotePubKey(ctx context.Context, localPubKey,
remotePubKey *btcec.PublicKey) error

// GetSessionByID fetches the session with the given ID.
GetSessionByID(id ID) (*Session, error)
GetSessionByID(ctx context.Context, id ID) (*Session, error)

// DeleteReservedSessions deletes all sessions that are in the
// StateReserved state.
DeleteReservedSessions() error
DeleteReservedSessions(ctx context.Context) error

// ShiftState updates the state of the session with the given ID to the
// "dest" state.
ShiftState(id ID, dest State) error
ShiftState(ctx context.Context, id ID, dest State) error

IDToGroupIndex
}
35 changes: 23 additions & 12 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package session

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -185,8 +186,8 @@ func getSessionKey(session *Session) []byte {
// ShiftState is called with StateCreated.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
serverAddr string, opts ...Option) (*Session, error) {
func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type,
expiry time.Time, serverAddr string, opts ...Option) (*Session, error) {

var session *Session
err := db.Update(func(tx *bbolt.Tx) error {
Expand Down Expand Up @@ -285,7 +286,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
// to the session with the given local pub key.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey,
remotePubKey *btcec.PublicKey) error {

key := localPubKey.SerializeCompressed()
Expand Down Expand Up @@ -318,7 +319,9 @@ func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
// GetSession fetches the session with the given key.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
*Session, error) {

var session *Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -348,7 +351,7 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
// ListAllSessions returns all sessions currently known to the store.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
func (db *BoltStore) ListAllSessions(_ context.Context) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return true
})
Expand All @@ -358,7 +361,9 @@ func (db *BoltStore) ListAllSessions() ([]*Session, error) {
// have the given type.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
error) {

return db.listSessions(func(s *Session) bool {
return s.Type == t
})
Expand All @@ -368,7 +373,9 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
// are in the given states.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
[]*Session, error) {

return db.listSessions(func(s *Session) bool {
for _, state := range states {
if s.State == state {
Expand Down Expand Up @@ -429,7 +436,7 @@ func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
// state.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) DeleteReservedSessions() error {
func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
Expand Down Expand Up @@ -522,7 +529,7 @@ func (db *BoltStore) DeleteReservedSessions() error {
// state.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ShiftState(id ID, dest State) error {
func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
Expand Down Expand Up @@ -562,7 +569,9 @@ func (db *BoltStore) ShiftState(id ID, dest State) error {
// GetSessionByID fetches the session with the given ID.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session,
error) {

var session *Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -615,7 +624,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
// GetGroupID will return the group ID for the given session ID.
//
// NOTE: this is part of the IDToGroupIndex interface.
func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) {
var groupID ID
err := db.View(func(tx *bbolt.Tx) error {
sessionBkt, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -655,7 +664,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
// group with the given ID.
//
// NOTE: this is part of the IDToGroupIndex interface.
func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID,
error) {

var (
sessionIDs []ID
err error
Expand Down
Loading

0 comments on commit 66b0f15

Please sign in to comment.