Skip to content

Commit

Permalink
Merge pull request #988 from ellemouton/sql17Sessions9
Browse files Browse the repository at this point in the history
[sql-17] sessions: test preparation
  • Loading branch information
ellemouton authored Feb 28, 2025
2 parents ee46094 + 44625c3 commit bc4439f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 37 deletions.
4 changes: 3 additions & 1 deletion itest/litd_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,9 @@ func testSessionLinking(net *NetworkHarness, t *harnessTest) {
LinkedGroupId: sessResp.Session.GroupId,
},
)
require.ErrorContains(t.t, err, "is still active")
require.ErrorContains(
t.t, err, session.ErrSessionsInGroupStillActive.Error(),
)

// Revoke the previous one and repeat.
_, err = litAutopilotClient.RevokeAutopilotSession(
Expand Down
12 changes: 12 additions & 0 deletions session/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@ var (
// ErrSessionNotFound is an error returned when we attempt to retrieve
// information about a session but it is not found.
ErrSessionNotFound = errors.New("session not found")

// ErrUnknownGroup is returned when an attempt is made to insert a
// session and link it to an existing group where the group is not
// known.
ErrUnknownGroup = errors.New("unknown group")

// ErrSessionsInGroupStillActive is returned when an attempt is made to
// insert a session and link it to a group that still has other active
// sessions.
ErrSessionsInGroupStillActive = errors.New(
"group has active sessions",
)
)
29 changes: 17 additions & 12 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
if session.ID != session.GroupID {
_, err = getKeyForID(sessionBucket, session.GroupID)
if err != nil {
return fmt.Errorf("unknown linked session "+
"%x: %w", session.GroupID, err)
return fmt.Errorf("%w: unknown linked "+
"session %x: %w", ErrUnknownGroup,
session.GroupID, err)
}

// Fetch all the session IDs for this group. This will
Expand All @@ -237,18 +238,22 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
return err
}

// Ensure that the all the linked sessions are no longer
// active.
for _, id := range sessionIDs {
sess, err := getSessionByID(sessionBucket, id)
if err != nil {
return err
}

// Ensure that the session is no longer active.
if !sess.State.Terminal() {
return fmt.Errorf("session (id=%x) "+
"in group %x is still active",
sess.ID, sess.GroupID)
if sess.State.Terminal() {
continue
}

return fmt.Errorf("%w: session (id=%x) in "+
"group %x is still active",
ErrSessionsInGroupStillActive, sess.ID,
sess.GroupID)
}
}

Expand Down Expand Up @@ -625,14 +630,14 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {

sessionIDBkt := idIndex.Bucket(sessionID[:])
if sessionIDBkt == nil {
return fmt.Errorf("no index entry for session ID: %x",
sessionID)
return fmt.Errorf("%w: no index entry for session "+
"ID: %x", ErrUnknownGroup, sessionID)
}

groupIDBytes := sessionIDBkt.Get(groupIDKey)
if len(groupIDBytes) == 0 {
return fmt.Errorf("group ID not found for session "+
"ID %x", sessionID)
return fmt.Errorf("%w: group ID not found for "+
"session ID %x", ErrUnknownGroup, sessionID)
}

copy(groupID[:], groupIDBytes)
Expand Down Expand Up @@ -801,7 +806,7 @@ func addIDToGroupIDPair(sessionBkt *bbolt.Bucket, id, groupID ID) error {
func getSessionByID(bucket *bbolt.Bucket, id ID) (*Session, error) {
keyBytes, err := getKeyForID(bucket, id)
if err != nil {
return nil, err
return nil, fmt.Errorf("%w: %w", ErrSessionNotFound, err)
}

v := bucket.Get(keyBytes)
Expand Down
36 changes: 12 additions & 24 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
func TestBasicSessionStore(t *testing.T) {
// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Try fetch a session that doesn't exist yet.
_, err := db.GetSessionByID(ID{1, 3, 4, 4})
require.ErrorIs(t, err, ErrSessionNotFound)

// Reserve a session. This should succeed.
s1, err := reserveSession(db, "session 1")
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestBasicSessionStore(t *testing.T) {
require.Empty(t, sessions)

_, err = db.GetGroupID(s4.ID)
require.ErrorContains(t, err, "no index entry")
require.ErrorIs(t, err, ErrUnknownGroup)

// Only session 1 should remain in this group.
sessIDs, err = db.GetSessionIDs(s4.GroupID)
Expand All @@ -197,11 +197,7 @@ func TestLinkingSessions(t *testing.T) {

// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

groupID, err := IDFromBytes([]byte{1, 2, 3, 4})
require.NoError(t, err)
Expand All @@ -211,7 +207,7 @@ func TestLinkingSessions(t *testing.T) {
_, err = reserveSession(
db, "session 2", withLinkedGroupID(&groupID),
)
require.ErrorContains(t, err, "unknown linked session")
require.ErrorIs(t, err, ErrUnknownGroup)

// Create a new session with no previous link.
s1 := createSession(t, db, "session 1")
Expand All @@ -220,7 +216,7 @@ func TestLinkingSessions(t *testing.T) {
// session. This should fail due to the first session still being
// active.
_, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID))
require.ErrorContains(t, err, "is still active")
require.ErrorIs(t, err, ErrSessionsInGroupStillActive)

// Revoke the first session.
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
Expand All @@ -238,11 +234,7 @@ func TestLinkedSessions(t *testing.T) {

// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Create a few sessions. The first one is a new session and the two
// after are all linked to the prior one. All these sessions belong to
Expand Down Expand Up @@ -294,18 +286,14 @@ func TestLinkedSessions(t *testing.T) {
func TestStateShift(t *testing.T) {
// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
db := NewTestDB(t, clock)

// Add a new session to the DB.
s1 := createSession(t, db, "label 1")

// Check that the session is in the StateCreated state. Also check that
// the "RevokedAt" time has not yet been set.
s1, err = db.GetSession(s1.LocalPublicKey)
s1, err := db.GetSession(s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, StateCreated, s1.State)
require.Equal(t, time.Time{}, s1.RevokedAt)
Expand Down
28 changes: 28 additions & 0 deletions session/test_kvdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package session

import (
"testing"

"github.com/lightningnetwork/lnd/clock"
"github.com/stretchr/testify/require"
)

// NewTestDB is a helper function that creates an BBolt database for testing.
func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore {
return NewTestDBFromPath(t, t.TempDir(), clock)
}

// NewTestDBFromPath is a helper function that creates a new BoltStore with a
// connection to an existing BBolt database for testing.
func NewTestDBFromPath(t *testing.T, dbPath string,
clock clock.Clock) *BoltStore {

store, err := NewDB(dbPath, DBFilename, clock)
require.NoError(t, err)

t.Cleanup(func() {
require.NoError(t, store.DB.Close())
})

return store
}

0 comments on commit bc4439f

Please sign in to comment.