Skip to content

Commit

Permalink
Replace tokens with sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
theandrew168 committed Sep 1, 2024
1 parent 3043306 commit 71f7c35
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 208 deletions.
102 changes: 102 additions & 0 deletions backend/model/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package model

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"time"

"github.com/google/uuid"
"github.com/theandrew168/bloggulus/backend/timeutil"
)

type Session struct {
id uuid.UUID
accountID uuid.UUID
hash string
expiresAt time.Time

createdAt time.Time
updatedAt time.Time
}

// Generate a random, crypto-safe session ID.
func GenerateSessionID() (string, error) {
b := make([]byte, 32)

_, err := rand.Read(b)
if err != nil {
return "", err
}

return base64.RawURLEncoding.EncodeToString(b), nil
}

func NewSession(account *Account, ttl time.Duration) (*Session, string, error) {
now := timeutil.Now()

sessionID, err := GenerateSessionID()
if err != nil {
return nil, "", err
}

// Generate a SHA-256 hash of the plaintext session ID. This will be the value
// that we store in the `hash` field of our database table. Note that the
// sha256.Sum256() function returns an array of length 32, so to make it easier to
// work with we convert it to a slice using the [:] operator before storing it.
hashBytes := sha256.Sum256([]byte(sessionID))
hash := hex.EncodeToString(hashBytes[:])

session := Session{
id: uuid.New(),
accountID: account.ID(),
hash: hash,
expiresAt: now.Add(ttl),

createdAt: now,
updatedAt: now,
}
return &session, sessionID, nil
}

func LoadSession(id, accountID uuid.UUID, hash string, expiresAt, createdAt, updatedAt time.Time) *Session {
session := Session{
id: id,
accountID: accountID,
hash: hash,
expiresAt: expiresAt,

createdAt: createdAt,
updatedAt: updatedAt,
}
return &session
}

func (s *Session) ID() uuid.UUID {
return s.id
}

func (s *Session) AccountID() uuid.UUID {
return s.accountID
}

func (s *Session) Hash() string {
return s.hash
}

func (s *Session) ExpiresAt() time.Time {
return s.expiresAt
}

func (s *Session) CreatedAt() time.Time {
return s.createdAt
}

func (s *Session) UpdatedAt() time.Time {
return s.updatedAt
}

func (s *Session) CheckDelete() error {
return nil
}
109 changes: 0 additions & 109 deletions backend/model/token.go

This file was deleted.

10 changes: 5 additions & 5 deletions backend/storage/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (s *AccountStorage) ReadByUsername(username string) (*model.Account, error)
return row.unmarshal()
}

func (s *AccountStorage) ReadByToken(token string) (*model.Account, error) {
func (s *AccountStorage) ReadBySessionID(sessionID string) (*model.Account, error) {
stmt := `
SELECT
account.id,
Expand All @@ -154,14 +154,14 @@ func (s *AccountStorage) ReadByToken(token string) (*model.Account, error) {
account.created_at,
account.updated_at
FROM account
INNER JOIN token
ON token.account_id = account.id
WHERE token.hash = $1`
INNER JOIN session
ON session.account_id = account.id
WHERE session.hash = $1`

ctx, cancel := context.WithTimeout(context.Background(), postgres.Timeout)
defer cancel()

hashBytes := sha256.Sum256([]byte(token))
hashBytes := sha256.Sum256([]byte(sessionID))
hash := hex.EncodeToString(hashBytes[:])

rows, err := s.conn.Query(ctx, stmt, hash)
Expand Down
6 changes: 3 additions & 3 deletions backend/storage/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ func TestAccountReadByUsername(t *testing.T) {
test.AssertEqual(t, got.ID(), account.ID())
}

func TestAccountReadByToken(t *testing.T) {
func TestAccountReadBySessionIDn(t *testing.T) {
t.Parallel()

store, closer := test.NewStorage(t)
defer closer()

account, _ := test.CreateAccount(t, store)
_, token := test.CreateToken(t, store, account)
_, sessionID := test.CreateSession(t, store, account)

got, err := store.Account().ReadByToken(token)
got, err := store.Account().ReadBySessionID(sessionID)
test.AssertNilError(t, err)

test.AssertEqual(t, got.ID(), account.ID())
Expand Down
Loading

0 comments on commit 71f7c35

Please sign in to comment.