Skip to content

Commit

Permalink
Add Prune() for pruning TTL'd sessions in Postgres store.
Browse files Browse the repository at this point in the history
  • Loading branch information
knadh committed May 15, 2024
1 parent c5ca812 commit ad5265b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
16 changes: 16 additions & 0 deletions stores/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type queries struct {
update *sql.Stmt
delete *sql.Stmt
clear *sql.Stmt
prune *sql.Stmt
}

// Store represents redis session store for simple sessions.
Expand Down Expand Up @@ -128,6 +129,9 @@ func (s *Store) Get(id, key string) (interface{}, error) {
// preserving the types.
var b []byte
if err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b); err != nil {
if err == sql.ErrNoRows {
return nil, ErrInvalidSession
}
return nil, err
}

Expand Down Expand Up @@ -324,6 +328,13 @@ func (s *Store) Clear(id string) error {
return nil
}

// Prune deletes rows that have exceeded the TTL. This should be run externally periodically (ideally as a separate goroutine)
// at desired intervals, hourly/daily etc. based on the expected volume of sessions.
func (s *Store) Prune() error {
_, err := s.q.prune.Exec(s.opt.TTL.Seconds())
return err
}

func (s *Store) prepareQueries() (*queries, error) {
var (
q = &queries{}
Expand Down Expand Up @@ -355,6 +366,11 @@ func (s *Store) prepareQueries() (*queries, error) {
return nil, err
}

q.prune, err = s.db.Prepare(fmt.Sprintf("DELETE FROM %s WHERE created_at <= NOW() - INTERVAL '1 second' * $1", s.opt.Table))
if err != nil {
return nil, err
}

return q, err
}

Expand Down
58 changes: 54 additions & 4 deletions stores/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ import (
"log"
"os"
"testing"
"time"

_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
)

const testTable = "sessions"

var (
st *Store
db *sql.DB
randID, _ = generateID(sessionIDLen)
)

Expand All @@ -26,18 +30,20 @@ func init() {

p := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
os.Getenv("PG_HOST"), os.Getenv("PG_PORT"), os.Getenv("PG_USER"), os.Getenv("PG_PASSWORD"), os.Getenv("PG_DB"))
db, err := sql.Open("postgres", p)
if err != nil {
if d, err := sql.Open("postgres", p); err != nil {
log.Fatal(err)
} else {
db = d
}

if err := db.Ping(); err != nil {
log.Fatal(err)
}

st, err = New(Opt{}, db)
if err != nil {
if s, err := New(Opt{TTL: time.Second * 2, Table: testTable}, db); err != nil {
log.Fatal(err)
} else {
st = s
}
}

Expand Down Expand Up @@ -119,3 +125,47 @@ func TestSet(t *testing.T) {
v, err = st.Get(id, "str")
assert.Error(t, err, ErrFieldNotFound)
}

func TestPrune(t *testing.T) {
// Create a new session.
id, err := st.Create()
assert.NoError(t, err)
assert.NotEmpty(t, id)

// Set value.
assert.NoError(t, st.Set(id, "str", "hello 123"))
assert.NoError(t, st.Commit(id))

// Get value and verify.
v, err := st.Get(id, "str")
assert.NoError(t, err)
assert.Equal(t, v, "hello 123")

// Wait until the 2 sec TTL expires and run prune.
time.Sleep(time.Second * 3)

// Session shouldn't be returned.
_, err = st.Get(id, "str")
assert.ErrorIs(t, err, ErrInvalidSession)

// Create one more session and immediately run prune. Except for this,
// all previous sessions should be gone.
id, err = st.Create()
assert.NoError(t, err)
assert.NoError(t, st.Set(id, "str", "hello 123"))
assert.NoError(t, st.Commit(id))

// Run prune. All previously created sessions should be gone.
assert.NoError(t, st.Prune())

var num int
err = db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", testTable)).Scan(&num)
assert.NoError(t, err)
assert.Equal(t, num, 1)

// The last created session shouldn't have been pruned.
v, err = st.Get(id, "str")
assert.NoError(t, err)
assert.Equal(t, v, "hello 123")

}

0 comments on commit ad5265b

Please sign in to comment.