From ba1bfb9ca32b20aed499bd767691465d6fb89873 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Tue, 14 May 2024 17:04:15 +0530 Subject: [PATCH] Add Postgres session store. --- go.work | 3 +- stores/pg/go.mod | 5 + stores/pg/pg.go | 342 +++++++++++++++++++++++++++++++++++++++++++ stores/pg/pg_test.go | 118 +++++++++++++++ 4 files changed, 467 insertions(+), 1 deletion(-) create mode 100644 stores/pg/go.mod create mode 100644 stores/pg/pg.go create mode 100644 stores/pg/pg_test.go diff --git a/go.work b/go.work index c30543b..f9f0f47 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.21 +go 1.21.6 use ( . @@ -7,4 +7,5 @@ use ( ./stores/memory ./stores/redis ./stores/securecookie + ./stores/pg ) diff --git a/stores/pg/go.mod b/stores/pg/go.mod new file mode 100644 index 0000000..6480a78 --- /dev/null +++ b/stores/pg/go.mod @@ -0,0 +1,5 @@ +module github.com/vividvilla/simplesessions/stores/pg + +go 1.21.6 + +require github.com/lib/pq v1.10.9 diff --git a/stores/pg/pg.go b/stores/pg/pg.go new file mode 100644 index 0000000..d0be4ec --- /dev/null +++ b/stores/pg/pg.go @@ -0,0 +1,342 @@ +package pg + +/* +CREATE TABLE sessions ( + id TEXT NOT NULL PRIMARY KEY, + data jsonb DEFAULT '{}'::jsonb NOT NULL, + created_at timestamp without time zone DEFAULT now() NOT NULL +); +CREATE INDEX idx_sessions ON sessions (id, created_at); +*/ + +import ( + "crypto/rand" + "database/sql" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + "unicode" + + _ "github.com/lib/pq" +) + +var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + ErrInvalidSession = &Err{code: 1, msg: "invalid session"} + ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrAssertType = &Err{code: 3, msg: "assertion failed"} + ErrNil = &Err{code: 4, msg: "nil returned"} +) + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + +// Store represents redis session store for simple sessions. +// Each session is stored as redis hashmap. +type Store struct { + db *sql.DB + opt Opt + + commitID string + tx *sql.Tx + stmt *sql.Stmt + sync.Mutex +} + +type Opt struct { + Table string `json:"table"` + TTL time.Duration `json:"ttl"` + + // Delete expired (TTL) rows from the table at this interval. + // This runs concurrently on a separate goroutine. + CleanInterval time.Duration `json:"clean_interval"` +} + +const ( + sessionIDLen = 32 +) + +// New creates a new Postgres store instance. +func New(opt Opt, db *sql.DB) *Store { + if opt.Table == "" { + opt.Table = "sessions" + } + if opt.TTL.Seconds() < 1 { + opt.TTL = time.Hour * 24 + } + if opt.CleanInterval.Seconds() < 1 { + opt.CleanInterval = time.Hour * 1 + } + + return &Store{ + db: db, + opt: opt, + } +} + +// Create creates a new session and returns the ID. +func (s *Store) Create() (string, error) { + id, err := generateID(sessionIDLen) + if err != nil { + return "", err + } + + if _, err := s.db.Exec(fmt.Sprintf("INSERT INTO %s (id, data) VALUES($1, '{}'::JSONB)", s.opt.Table), id); err != nil { + return "", err + } + return id, nil +} + +// Get returns a single session field's value. +func (s *Store) Get(id, key string) (interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession + } + + // Scan the whole JSON map out so that it can be unmarshalled, + // preserving the types. + var b []byte + err := s.db.QueryRow(fmt.Sprintf("SELECT data as val FROM %s WHERE id=$1 AND created_at >= NOW() - INTERVAL '1 second' * $2", s.opt.Table), id, s.opt.TTL.Seconds()).Scan(&b) + if err != nil { + return nil, err + } + + var mp map[string]interface{} + if err := json.Unmarshal(b, &mp); err != nil { + return nil, err + } + + v, ok := mp[key] + if !ok { + return nil, ErrFieldNotFound + } + + return v, nil +} + +// GetMulti gets a map for values for multiple keys. If a key doesn't exist, it returns ErrFieldNotFound. +func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession + } + + vals, err := s.GetAll(id) + if err != nil { + return nil, err + } + + out := make(map[string]interface{}, len(keys)) + for _, k := range keys { + v, ok := vals[k] + if !ok { + return nil, ErrFieldNotFound + } + out[k] = v + } + + return out, err +} + +// GetAll returns the map of all keys in the session. +func (s *Store) GetAll(id string, keys ...string) (map[string]interface{}, error) { + if !validateID(id) { + return nil, ErrInvalidSession + } + + var b []byte + err := s.db.QueryRow(fmt.Sprintf("SELECT data FROM %s WHERE id=$1 AND created_at >= NOW() - INTERVAL '1 second' * $2", s.opt.Table), id, s.opt.TTL.Seconds()).Scan(&b) + if err != nil { + return nil, err + } + + out := make(map[string]interface{}) + if err := json.Unmarshal(b, &out); err != nil { + return nil, err + } + + return out, err +} + +// Set sets a value to given session but is stored only on commit. +func (s *Store) Set(id, key string, val interface{}) (err error) { + if !validateID(id) { + return ErrInvalidSession + } + + b, err := json.Marshal(map[string]interface{}{key: val}) + if err != nil { + return err + } + + s.Lock() + defer func() { + if err == nil { + s.Unlock() + return + } + + if s.tx != nil { + s.tx.Rollback() + s.tx = nil + } + s.stmt = nil + + s.Unlock() + }() + + // If a transaction isn't set, set it. + if s.tx == nil { + tx, err := s.db.Begin() + if err != nil { + return err + } + + // Prepare the statement for executing SQL commands + stmt, err := tx.Prepare(fmt.Sprintf("UPDATE %s SET data = data || $2::JSONB WHERE id = $1", s.opt.Table)) + if err != nil { + return err + } + + s.tx = tx + s.stmt = stmt + } + + // Execute the query in the batch to be committed later. + res, err := s.stmt.Exec(id, json.RawMessage(b)) + if err != nil { + return err + } + num, err := res.RowsAffected() + if err != nil { + return err + } + + // No row was updated. The session didn't exist. + if num == 0 { + return ErrInvalidSession + } + + s.commitID = id + return err +} + +// Commit sets all set values +func (s *Store) Commit(id string) error { + if !validateID(id) { + return ErrInvalidSession + } + + s.Lock() + if s.commitID != id { + s.Unlock() + return ErrInvalidSession + } + + defer func() { + if s.stmt != nil { + s.stmt.Close() + } + s.tx = nil + s.stmt = nil + s.Unlock() + }() + + if s.tx == nil { + return errors.New("nothing to commit") + } + if s.commitID != id { + return ErrInvalidSession + } + + return s.tx.Commit() +} + +// Delete deletes a key from redis session hashmap. +func (s *Store) Delete(id string, key string) error { + if !validateID(id) { + return ErrInvalidSession + } + + res, err := s.db.Exec(fmt.Sprintf("UPDATE %s SET data = data - '%s' WHERE id=$1", s.opt.Table, key), id) + if err != nil { + return err + } + + num, err := res.RowsAffected() + if err != nil { + return err + } + + // No row was updated. The session didn't exist. + if num == 0 { + return ErrInvalidSession + } + + return nil +} + +// Clear clears session in redis. +func (s *Store) Clear(id string) error { + if !validateID(id) { + return ErrInvalidSession + } + + res, err := s.db.Exec(fmt.Sprintf("UPDATE %s SET data = '{}'::JSONB WHERE id=$1", s.opt.Table), id) + if err != nil { + return err + } + + num, err := res.RowsAffected() + if err != nil { + return err + } + + // No row was updated. The session didn't exist. + if num == 0 { + return ErrInvalidSession + } + + return nil +} + +func validateID(id string) bool { + if len(id) != sessionIDLen { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true +} + +// generateID generates a random alpha-num session ID. +func generateID(n int) (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, n) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} diff --git a/stores/pg/pg_test.go b/stores/pg/pg_test.go new file mode 100644 index 0000000..7fda7ed --- /dev/null +++ b/stores/pg/pg_test.go @@ -0,0 +1,118 @@ +package pg + +// For this test to run, set env vars: PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB. + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" +) + +var ( + st *Store + randID, _ = generateID(sessionIDLen) +) + +func init() { + if os.Getenv("PG_HOST") == "" { + fmt.Println("WARNING: Skiping DB test as database config isn't set in env vars.") + os.Exit(0) + } + + p := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + os.Getenv("PG_HOST"), os.Getenv("PG_PORT"), os.Getenv("PG_USER"), os.Getenv("PG_PASSWORD"), os.Getenv("PG_DB")) + db, err := sql.Open("postgres", p) + if err != nil { + log.Fatal(err) + } + + if err := db.Ping(); err != nil { + log.Fatal(err) + } + + st = New(Opt{}, db) +} + +func TestCreate(t *testing.T) { + for n := 0; n < 5; n++ { + id, err := st.Create() + assert.NoError(t, err) + assert.NotEmpty(t, id) + } +} + +func TestSet(t *testing.T) { + assert.NotEmpty(t, randID) + + id, err := st.Create() + assert.NoError(t, err) + assert.NotEmpty(t, id) + + assert.NoError(t, st.Set(id, "num", 123)) + assert.NoError(t, st.Set(id, "str", "hello 123")) + assert.NoError(t, st.Set(id, "bool", true)) + + // Commit invalid session. + assert.Error(t, st.Commit(randID), ErrInvalidSession) + + // Commit valid session. + assert.NoError(t, st.Commit(id)) + + // Commit without setting. + assert.Error(t, st.Commit(id)) + assert.Error(t, st.Commit(randID)) + + // Get different types. + v, err := st.Get(id, "num") + assert.NoError(t, err) + assert.Equal(t, v, float64(123)) + + v, err = st.Get(id, "str") + assert.NoError(t, err) + assert.Equal(t, v, "hello 123") + + v, err = st.Get(id, "bool") + assert.NoError(t, err) + assert.Equal(t, v, true) + + // Non-existent field. + _, err = st.Get(id, "xx") + assert.ErrorIs(t, err, ErrFieldNotFound) + + // Get multiple. + mp, err := st.GetMulti(id, "num", "str", "bool") + assert.NoError(t, err) + assert.Equal(t, mp, map[string]interface{}{ + "str": "hello 123", + "num": float64(123), + "bool": true, + }) + mp, err = st.GetMulti(id, "num", "str", "bool", "blah") + assert.ErrorIs(t, err, ErrFieldNotFound) + + // Add another key in a different commit. + assert.NoError(t, st.Set(id, "num2", 456)) + assert.NoError(t, st.Commit(id)) + + v, err = st.Get(id, "num2") + assert.NoError(t, err) + assert.Equal(t, v, float64(456)) + + // Delete. + assert.ErrorIs(t, st.Delete("blah", "num2"), ErrInvalidSession) + assert.NoError(t, st.Delete(id, "num2")) + v, err = st.Get(id, "num2") + v, err = st.Get(id, "num3") + assert.Error(t, ErrFieldNotFound) + + // Clear. + assert.ErrorIs(t, st.Clear(randID), ErrInvalidSession) + assert.NoError(t, st.Clear(id)) + v, err = st.Get(id, "str") + assert.Error(t, err, ErrFieldNotFound) +}