From ca00f8b5d37010f3a1ff25e81bb4b70976165e3f Mon Sep 17 00:00:00 2001 From: Drewry Pope Date: Thu, 4 Jan 2024 20:56:09 -0600 Subject: [PATCH 1/4] Add DB Versioning (#1) * add versions * align --- cmd/serve_migrate.go | 35 +- .../db/sqlite/migration/0001_create_tables.go | 80 ++++ ...1_foreign_keys.go => 0002_foreign_keys.go} | 8 +- server/db/sqlite/migration/migration.go | 68 +++- server/db/sqlite/sql.go | 121 ++---- server/db/sqlite/storage.go | 377 ++++++++++++++---- server/http.go | 8 +- server/link.go | 4 +- server/ssh.go | 4 +- server/stats/prometheus/prometheus.go | 2 +- 10 files changed, 496 insertions(+), 211 deletions(-) create mode 100644 server/db/sqlite/migration/0001_create_tables.go rename server/db/sqlite/migration/{0001_foreign_keys.go => 0002_foreign_keys.go} (93%) diff --git a/cmd/serve_migrate.go b/cmd/serve_migrate.go index 59588352..98c5b24f 100644 --- a/cmd/serve_migrate.go +++ b/cmd/serve_migrate.go @@ -1,16 +1,12 @@ package cmd import ( - "database/sql" - "fmt" - "os" "path/filepath" - "github.com/charmbracelet/log" - "github.com/charmbracelet/charm/server" "github.com/charmbracelet/charm/server/db/sqlite" - "github.com/charmbracelet/charm/server/db/sqlite/migration" + "github.com/charmbracelet/charm/server/storage" + "github.com/charmbracelet/log" "github.com/spf13/cobra" _ "modernc.org/sqlite" // sqlite driver @@ -26,30 +22,11 @@ var ServeMigrationCmd = &cobra.Command{ Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { cfg := server.DefaultConfig() - dp := filepath.Join(cfg.DataDir, "db", sqlite.DbName) - _, err := os.Stat(dp) - if err != nil { - return fmt.Errorf("database does not exist: %s", err) - } - db := sqlite.NewDB(dp) - for _, m := range []migration.Migration{ - migration.Migration0001, - } { - log.Print("Running migration", "id", fmt.Sprintf("%04d", m.ID), "name", m.Name) - err = db.WrapTransaction(func(tx *sql.Tx) error { - _, err := tx.Exec(m.SQL) - if err != nil { - return err - } - return nil - }) - if err != nil { - break - } - } + dp := filepath.Join(cfg.DataDir, "db") + err := storage.EnsureDir(dp, 0o700) if err != nil { - return err + log.Fatal("could not init sqlite path", "err", err) } - return nil + return sqlite.NewDB(filepath.Join(dp, sqlite.DbName)).Migrate() }, } diff --git a/server/db/sqlite/migration/0001_create_tables.go b/server/db/sqlite/migration/0001_create_tables.go new file mode 100644 index 00000000..9edcb3ef --- /dev/null +++ b/server/db/sqlite/migration/0001_create_tables.go @@ -0,0 +1,80 @@ +package migration + +// Migration0001 is the initial migration. +var Migration0001 = Migration{ + Version: 1, + Name: "create tables", + SQL: ` +CREATE TABLE IF NOT EXISTS charm_user( + id INTEGER NOT NULL PRIMARY KEY, + charm_id uuid UNIQUE NOT NULL, + name varchar(50) UNIQUE, + email varchar(254), + bio varchar(1000), + created_at timestamp default current_timestamp +); + +CREATE TABLE IF NOT EXISTS public_key( + id INTEGER NOT NULL PRIMARY KEY, + user_id integer NOT NULL, + public_key varchar(2048) NOT NULL, + created_at timestamp default current_timestamp, + UNIQUE (user_id, public_key), + CONSTRAINT user_id_fk + FOREIGN KEY (user_id) + REFERENCES charm_user (id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS encrypt_key( + id INTEGER NOT NULL PRIMARY KEY, + public_key_id integer NOT NULL, + global_id uuid NOT NULL, + created_at timestamp default current_timestamp, + encrypted_key varchar(2048) NOT NULL, + CONSTRAINT public_key_id_fk + FOREIGN KEY (public_key_id) + REFERENCES public_key (id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS named_seq( + id INTEGER NOT NULL PRIMARY KEY, + user_id integer NOT NULL, + seq integer NOT NULL DEFAULT 0, + name varchar(1024) NOT NULL, + UNIQUE (user_id, name), + CONSTRAINT user_id_fk + FOREIGN KEY (user_id) + REFERENCES charm_user (id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS news( + id INTEGER NOT NULL PRIMARY KEY, + subject text, + body text, + created_at timestamp default current_timestamp +); + +CREATE TABLE IF NOT EXISTS news_tag( + id INTEGER NOT NULL PRIMARY KEY, + tag varchar(250), + news_id integer NOT NULL, + CONSTRAINT news_id_fk + FOREIGN KEY (news_id) + REFERENCES news (id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS token( + id INTEGER NOT NULL PRIMARY KEY, + pin text UNIQUE NOT NULL, + created_at timestamp default current_timestamp +); +`, +} diff --git a/server/db/sqlite/migration/0001_foreign_keys.go b/server/db/sqlite/migration/0002_foreign_keys.go similarity index 93% rename from server/db/sqlite/migration/0001_foreign_keys.go rename to server/db/sqlite/migration/0002_foreign_keys.go index fd4f88df..63982f72 100644 --- a/server/db/sqlite/migration/0001_foreign_keys.go +++ b/server/db/sqlite/migration/0002_foreign_keys.go @@ -1,9 +1,9 @@ package migration -// Migration0001 is the initial migration. -var Migration0001 = Migration{ - ID: 1, - Name: "foreign keys", +// Migration0002 is the initial inclusion of foreign keys. +var Migration0002 = Migration{ + Version: 2, + Name: "foreign keys", SQL: ` PRAGMA foreign_keys=off; diff --git a/server/db/sqlite/migration/migration.go b/server/db/sqlite/migration/migration.go index 5def2692..80fe60ae 100644 --- a/server/db/sqlite/migration/migration.go +++ b/server/db/sqlite/migration/migration.go @@ -1,8 +1,70 @@ package migration +import ( + "fmt" + "time" + + "github.com/charmbracelet/log" +) + // Migration is a db migration script. type Migration struct { - ID int - Name string - SQL string + Version int + Name string + SQL string +} + +type Version struct { + Version int + Name *string + CompletedAt *time.Time + ErrorAt *time.Time + Comment *string + CreatedAt *time.Time + UpdatedAt *time.Time +} + +func (v Version) String() string { + return fmt.Sprintf( + "Version: %d, Name: %s, CompletedAt: %s, ErrorAt: %s, Comment: %s, CreatedAt: %s, UpdatedAt: %s", + v.Version, + safeString(v.Name), + safeTime(v.CompletedAt), + safeTime(v.ErrorAt), + safeString(v.Comment), + safeTime(v.CreatedAt), + safeTime(v.UpdatedAt), + ) +} +func safeString(s *string) string { + if s != nil { + return *s + } + return "nil" +} +func safeTime(t *time.Time) string { + if t != nil { + return t.Format(time.RFC3339) + } + return "nil" +} + +var Migrations = []Migration{ + Migration0001, + Migration0002, +} + +// Validate validates the migration sequence. +// It returns an error if the sequence is not valid. +// Each migration must have a unique version number and +// the version numbers must be in sequence starting from 1. +func Validate() error { + log.Info("validating migrations") + for i, m := range Migrations { + if i+1 != m.Version { + log.Error("migration is not in sequence", "expected", i+1, "actual", m.Version, "migration", m) + return fmt.Errorf("migration %d is not in sequence, expected %d, name %s", m.Version, i+1, m.Name) + } + } + return nil } diff --git a/server/db/sqlite/sql.go b/server/db/sqlite/sql.go index 0f36606e..8d56374f 100644 --- a/server/db/sqlite/sql.go +++ b/server/db/sqlite/sql.go @@ -1,87 +1,35 @@ package sqlite const ( - sqlCreateUserTable = `CREATE TABLE IF NOT EXISTS charm_user( - id INTEGER NOT NULL PRIMARY KEY, - charm_id uuid UNIQUE NOT NULL, - name varchar(50) UNIQUE, - email varchar(254), - bio varchar(1000), - created_at timestamp default current_timestamp - )` - - sqlCreatePublicKeyTable = `CREATE TABLE IF NOT EXISTS public_key( - id INTEGER NOT NULL PRIMARY KEY, - user_id integer NOT NULL, - public_key varchar(2048) NOT NULL, - created_at timestamp default current_timestamp, - UNIQUE (user_id, public_key), - CONSTRAINT user_id_fk - FOREIGN KEY (user_id) - REFERENCES charm_user (id) - ON DELETE CASCADE - ON UPDATE CASCADE - )` - - sqlCreateEncryptKeyTable = `CREATE TABLE IF NOT EXISTS encrypt_key( - id INTEGER NOT NULL PRIMARY KEY, - public_key_id integer NOT NULL, - global_id uuid NOT NULL, - created_at timestamp default current_timestamp, - encrypted_key varchar(2048) NOT NULL, - CONSTRAINT public_key_id_fk - FOREIGN KEY (public_key_id) - REFERENCES public_key (id) - ON DELETE CASCADE - ON UPDATE CASCADE - )` - - sqlCreateNamedSeqTable = `CREATE TABLE IF NOT EXISTS named_seq( - id INTEGER NOT NULL PRIMARY KEY, - user_id integer NOT NULL, - seq integer NOT NULL DEFAULT 0, - name varchar(1024) NOT NULL, - UNIQUE (user_id, name), - CONSTRAINT user_id_fk - FOREIGN KEY (user_id) - REFERENCES charm_user (id) - ON DELETE CASCADE - ON UPDATE CASCADE - )` - - sqlCreateNewsTable = `CREATE TABLE IF NOT EXISTS news( - id INTEGER NOT NULL PRIMARY KEY, - subject text, - body text, - created_at timestamp default current_timestamp - )` - - sqlCreateNewsTagTable = `CREATE TABLE IF NOT EXISTS news_tag( - id INTEGER NOT NULL PRIMARY KEY, - tag varchar(250), - news_id integer NOT NULL, - CONSTRAINT news_id_fk - FOREIGN KEY (news_id) - REFERENCES news (id) - ON DELETE CASCADE - ON UPDATE CASCADE - )` - - sqlCreateTokenTable = `CREATE TABLE IF NOT EXISTS token( - id INTEGER NOT NULL PRIMARY KEY, - pin text UNIQUE NOT NULL, - created_at timestamp default current_timestamp - )` - - sqlSelectUserWithName = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE name like ?` - sqlSelectUserWithCharmID = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE charm_id = ?` - sqlSelectUserWithID = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE id = ?` - sqlSelectUserPublicKeys = `SELECT id, public_key, created_at FROM public_key WHERE user_id = ?` - sqlSelectNumberUserPublicKeys = `SELECT count(*) FROM public_key WHERE user_id = ?` - sqlSelectPublicKey = `SELECT id, user_id, public_key FROM public_key WHERE public_key = ?` - sqlSelectEncryptKey = `SELECT global_id, encrypted_key, created_at FROM encrypt_key WHERE public_key_id = ? AND global_id = ?` - sqlSelectEncryptKeys = `SELECT global_id, encrypted_key, created_at FROM encrypt_key WHERE public_key_id = ? ORDER BY created_at ASC` - sqlSelectNamedSeq = `SELECT seq FROM named_seq WHERE user_id = ? AND name = ?` + sqlCreateVersionTable = ` + CREATE TABLE IF NOT EXISTS version ( + id INTEGER PRIMARY KEY, + version INTEGER NOT NULL, + name TEXT NOT NULL, + completed_at DATETIME, + error_at DATETIME, + comment TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(version) + );` + sqlDropVersionTable = `DROP TABLE IF EXISTS version;` + sqlSelectVersionTableCount = `SELECT count(*) FROM sqlite_master WHERE type='table' AND name='version';` + sqlSelectVersionCount = `SELECT count(*) FROM version;` + sqlSelectLatestVersion = `SELECT version, name, completed_at, error_at, comment, created_at, updated_at FROM version ORDER BY version DESC LIMIT 1;` + sqlSelectIncompleteVersionCount = `SELECT count(*) FROM version WHERE completed_at IS NULL;` + sqlInsertVersion = `INSERT INTO version (version, name, comment) VALUES (?, ?, ?);` + sqlUpdateCompleteVersion = `UPDATE version SET completed_at = CURRENT_TIMESTAMP WHERE version = ?;` + sqlUpdateErrorVersion = `UPDATE version SET error_at = CURRENT_TIMESTAMP, comment = ? WHERE version = ?;` + sqlSelectUserWithName = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE name like ?` + sqlSelectUserWithCharmID = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE charm_id = ?` + sqlSelectUserWithID = `SELECT id, charm_id, name, email, bio, created_at FROM charm_user WHERE id = ?` + sqlSelectUserPublicKeys = `SELECT id, public_key, created_at FROM public_key WHERE user_id = ?` + sqlSelectNumberUserPublicKeys = `SELECT count(*) FROM public_key WHERE user_id = ?` + sqlSelectPublicKey = `SELECT id, user_id, public_key FROM public_key WHERE public_key = ?` + sqlSelectEncryptKey = `SELECT global_id, encrypted_key, created_at FROM encrypt_key WHERE public_key_id = ? AND global_id = ?` + sqlSelectEncryptKeys = `SELECT global_id, encrypted_key, created_at FROM encrypt_key WHERE public_key_id = ? ORDER BY created_at ASC` + sqlSelectNamedSeq = `SELECT seq FROM named_seq WHERE user_id = ? AND name = ?` sqlInsertUser = `INSERT INTO charm_user (charm_id) VALUES (?)` @@ -116,9 +64,10 @@ const ( sqlCountUserNames = `SELECT COUNT(*) FROM charm_user WHERE name <> ''` sqlSelectNews = `SELECT id, subject, body, created_at FROM news WHERE id = ?` - sqlSelectNewsList = `SELECT n.id, n.subject, n.created_at FROM news AS n - INNER JOIN news_tag AS t ON t.news_id = n.id - WHERE t.tag = ? - ORDER BY n.created_at desc - LIMIT 50 OFFSET ?` + sqlSelectNewsList = ` + SELECT n.id, n.subject, n.created_at FROM news AS n + INNER JOIN news_tag AS t ON t.news_id = n.id + WHERE t.tag = ? + ORDER BY n.created_at desc + LIMIT 50 OFFSET ?` ) diff --git a/server/db/sqlite/storage.go b/server/db/sqlite/storage.go index 2d53f723..7e5fad89 100644 --- a/server/db/sqlite/storage.go +++ b/server/db/sqlite/storage.go @@ -4,12 +4,14 @@ import ( "context" "database/sql" "fmt" + "net/url" "strconv" "time" "github.com/charmbracelet/log" charm "github.com/charmbracelet/charm/proto" + "github.com/charmbracelet/charm/server/db/sqlite/migration" "github.com/google/uuid" "modernc.org/sqlite" sqlitelib "modernc.org/sqlite/lib" @@ -20,6 +22,7 @@ const ( DbName = "charm_sqlite.db" // The DB default connection options. DbOptions = "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)" + Redact = "___REDACTED___" ) // DB is the database struct. @@ -27,22 +30,305 @@ type DB struct { db *sql.DB } +type Tx struct { + tx *sql.Tx +} + +// sanitizePath redacts sensitive information from a database connection string. +func sanitizePath(path string) string { + u, err := url.Parse(path) + if err != nil { + // If the URL is not parseable, return it as is (or handle the error as needed) + return path + } + + // List of query parameters to redact + redactParams := []string{"authToken", "password", "secret"} + + // Replace the values of sensitive query parameters + q := u.Query() + for _, param := range redactParams { + if q.Has(param) { + q.Set(param, Redact) + } + } + + // Reassemble the URL + u.RawQuery = q.Encode() + return u.String() +} + // NewDB creates a new DB in the given path. func NewDB(path string) *DB { var err error log.Debug("Opening SQLite db", "path", path) db, err := sql.Open("sqlite", path+DbOptions) + if err != nil { panic(err) } d := &DB{db: db} - err = d.CreateDB() + + exists, err := d.VersionTableExists() if err != nil { panic(err) } + if !exists { + err = d.CreateDB() + if err != nil { + panic(err) + } + } else { + latest, err := d.LatestVersion() + if err != nil { + log.Error("Error getting latest version. Did the initial migration fail?", "err", err) + log.Error("Dropping version table if it exists and is empty.") + derr := d.DropVersionTableIfEmpty() + if derr != nil { + log.Error("Error dropping version table", "err", derr) + } + panic(err) + } + log.Info("Latest version", "version", latest.Version, "name", *latest.Name, "completed_at", latest.CompletedAt, "error_at", latest.ErrorAt, "comment", latest.Comment, "created_at", latest.CreatedAt, "updated_at", latest.UpdatedAt) + if latest.Version != migration.Migrations[len(migration.Migrations)-1].Version { + log.Info("The database may be out of date.", "latest_db_version", latest.Version, "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_db", latest) + log.Info("Latest Code version", "latest_code", migration.Migrations[len(migration.Migrations)-1]) + } + incomplete, err := d.IncompleteVersionExists() + if err != nil { + panic(err) + } + if incomplete { + if !latest.ErrorAt.IsZero() { + log.Error("The latest version has an error. Please manually ensure all version migrations are complete, then try again.", "latest_db_version", latest.Version, "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_db", latest, "latest_code", migration.Migrations[len(migration.Migrations)-1]) + panic("The database is in an incomplete state. The latest version has an error Please manually ensure all version migrations are complete, then try again.") + } else if latest.CompletedAt.IsZero() { + log.Error("The latest version is incomplete. Please wait & ensure all version migrations are complete, then try again.", "latest_db_version", latest.Version, "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_db", latest, "latest_code", migration.Migrations[len(migration.Migrations)-1]) + panic("The database is in an incomplete state. The latest version is incomplete. Please wait & ensure all version migrations are complete, then try again.") + } else { + log.Error("The database is in an unknown state. The latest version is complete, but there are incomplete versions. Please manually ensure all version migrations are complete, then try again.", "latest_db_version", latest.Version, "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_db", latest, "latest_code", migration.Migrations[len(migration.Migrations)-1]) + panic("The database is in an unknown state. The latest version is complete, but there are incomplete versions. Please manually ensure all version migrations are complete, then try again.") + } + } + } return d } +// VersionTableExists returns true if the version table exists. +func (me *DB) VersionTableExists() (bool, error) { + var c int + r := me.db.QueryRow(sqlSelectVersionTableCount) + if err := r.Scan(&c); err != nil { + return false, err + } + return c > 0, nil +} + +// IncompleteVersionExists returns true if there are incomplete versions. +func (me *DB) IncompleteVersionExists() (bool, error) { + var c int + r := me.db.QueryRow(sqlSelectIncompleteVersionCount) + if err := r.Scan(&c); err != nil { + return false, err + } + return c > 0, nil +} + +// LatestVersion returns the latest version. +func (me *DB) LatestVersion() (*migration.Version, error) { + log.Debug("Getting latest version") + v := &migration.Version{} + r := me.db.QueryRow(sqlSelectLatestVersion) + log.Debug("Scanning latest version row") + err := r.Scan(&v.Version, &v.Name, &v.CompletedAt, &v.ErrorAt, &v.Comment, &v.CreatedAt, &v.UpdatedAt) + if err != nil { + log.Error("Error getting latest version", "err", err) + return nil, err + } + log.Debug("Got latest version", "version", v.Version, "name", *v.Name, "completed_at", v.CompletedAt, "error_at", v.ErrorAt, "comment", v.Comment, "created_at", v.CreatedAt, "updated_at", v.UpdatedAt) + return v, nil +} + +// Migrate runs the migrations. +func (me *DB) Migrate() error { + log.Info("Running migrations") + err := migration.Validate() + if err != nil { + return err + } + latest, err := me.LatestVersion() + if err != nil && err != sql.ErrNoRows { + return err + } + if err == sql.ErrNoRows { + latest = &migration.Version{} + log.Info("No previous migrations found") + } + log.Info("Latest version", "version", latest.Version, "name", *latest.Name, "completed_at", latest.CompletedAt, "error_at", latest.ErrorAt, "comment", latest.Comment, "created_at", latest.CreatedAt, "updated_at", latest.UpdatedAt) + + executedMigrations := 0 + skippedMigrations := 0 + + for _, m := range migration.Migrations { + if m.Version <= latest.Version { + log.Info("Skipping migration", "version", m.Version, "name", m.Name) + skippedMigrations++ + continue + } + log.Info("Running migration", "version", m.Version, "name", m.Name) + err := me.InsertVersion(m.Version, m.Name, nil) + if err != nil { + log.Error("Error inserting version", "version", m.Version, "name", m.Name, "err", err) + return err + } + err = me.WrapTransaction(func(tx *sql.Tx) error { + transaction := Tx{tx: tx} + _, err := transaction.tx.Exec(m.SQL) + if err != nil { + log.Error("Error executing migration", "version", m.Version, "name", m.Name, "err", err) + return err + } + executedMigrations++ + err = transaction.UpdateCompleteVersion(m.Version) + if err != nil { + log.Error("Error updating version", "version", m.Version, "name", m.Name, "err", err) + return err + } + return nil + }) + if err != nil { + err2 := me.UpdateErrorVersion(m.Version, err.Error()) + if err2 != nil { + log.Error("Error updating version", "version", m.Version, "err", err2) + } + return err + } + log.Info("Migration complete", "version", m.Version, "name", m.Name, "committed", "true") + } + log.Info("Migrations complete", "version", migration.Migrations[len(migration.Migrations)-1].Version, "name", migration.Migrations[len(migration.Migrations)-1].Name, "executed", executedMigrations, "skipped", skippedMigrations, "total", len(migration.Migrations)) + return nil +} + +// UpdateCompleteVersion updates the version table with the given version. +func (me Tx) UpdateCompleteVersion(version int) error { + log.Info("Updating version to complete", "version", version) + _, err := me.tx.Exec(sqlUpdateCompleteVersion, version) + if err != nil { + log.Error("Error updating version to complete", "version", version, "err", err) + } else { + log.Info("Updated version to complete", "version", version) + } + return err +} + +// UpdateErrorVersion updates the version table with the given version. +func (me DB) UpdateErrorVersion(version int, comment string) error { + log.Info("Updating version with error", "version", version, "comment", comment) + _, err := me.db.Exec(sqlUpdateErrorVersion, comment, version) + if err != nil { + log.Error("Error updating version", "version", version, "comment", comment, "err", err) + } else { + log.Info("Updated version with error", "version", version, "comment", comment) + } + return err +} + +// InsertVersion inserts a version into the version table. +func (me DB) InsertVersion(version int, name string, comment *string) error { + log.Info("Inserting version", "version", version, "name", name, "comment", comment) + _, err := me.db.Exec(sqlInsertVersion, version, name, comment) + + if err != nil { + log.Error("Error inserting version", "version", version, "name", name, "comment", comment, "err", err) + } else { + log.Info("Inserted version", "version", version, "name", name, "comment", comment) + } + return err +} + +// CreateVersionTable creates the version table. +func (me *DB) CreateVersionTable() error { + log.Info("Creating version table") + _, err := me.db.Exec(sqlCreateVersionTable) + if err != nil { + return err + } + return nil +} + +// CreateDB creates the database. +func (me *DB) CreateDB() error { + log.Info("Creating SQLite db") + err := me.CreateVersionTable() + if err != nil { + log.Error("Error creating version table", "err", err) + return err + } + log.Info("Running migrations") + err = me.Migrate() + if err != nil { + log.Error("Error migrating database", "err", err) + versionCount, verr := me.VersionCount() + if verr != nil { + log.Error("Error getting version count", "err", verr) + return verr + } + log.Error("Error migrating database", "version_count", versionCount, "err", err) + if versionCount == 0 { + log.Error("No versions found, dropping version table") + err = me.DropVersionTableIfEmpty() + if err != nil { + log.Error("Error dropping version table", "err", err) + } + } + return err + } + return nil +} + +// DropVersionTableIfEmpty drops the version table. +func (me *DB) DropVersionTableIfEmpty() error { + log.Info("Dropping version table if empty") + exists, err := me.VersionTableExists() + if err != nil { + log.Error("Error checking if version table exists", "err", err) + return err + } + if !exists { + log.Info("Version table does not exist") + return nil + } + log.Info("Version table exists", "exists", exists) + versionCount, err := me.VersionCount() + if err != nil { + log.Error("Error getting version count", "err", err) + return err + } + log.Error("Version count", "count", versionCount) + if versionCount != 0 { + log.Error("Version table is not empty", "count", versionCount) + return fmt.Errorf("version table is not empty") + } + log.Info("Dropping version table because it is empty") + _, err = me.db.Exec(sqlDropVersionTable) + if err != nil { + log.Error("Error dropping version table", "err", err) + return err + } + log.Info("Dropped version table") + return nil +} + +// VersionCount returns the number of versions. +func (me *DB) VersionCount() (int, error) { + var c int + r := me.db.QueryRow(sqlSelectVersionCount) + if err := r.Scan(&c); err != nil { + return 0, err + } + return c, nil +} + // UserCount returns the number of users. func (me *DB) UserCount() (int, error) { var c int @@ -86,7 +372,7 @@ func (me *DB) GetUserWithName(name string) (*charm.User, error) { // SetUserName sets a user name for the given user id. func (me *DB) SetUserName(charmID string, name string) (*charm.User, error) { var u *charm.User - log.Debug("Setting name for user", "name", name, "id", charmID) + log.Info("Setting name for user", "name", name, "id", charmID) err := me.WrapTransaction(func(tx *sql.Tx) error { // nolint: godox // TODO: this should be handled with unique constraints in the database instead. @@ -143,7 +429,7 @@ func (me *DB) UserForKey(key string, create bool) (*charm.User, error) { return charm.ErrMissingUser } if err == sql.ErrNoRows { - log.Debug("Creating user for key", "key", charm.PublicKeySha(key)) + log.Info("Creating user for key", "key", charm.PublicKeySha(key)) err = me.createUser(tx, key) if err != nil { return err @@ -174,7 +460,7 @@ func (me *DB) UserForKey(key string, create bool) (*charm.User, error) { // AddEncryptKeyForPublicKey adds an ecrypted key to the user. func (me *DB) AddEncryptKeyForPublicKey(u *charm.User, pk string, gid string, ek string, ca *time.Time) error { - log.Debug("Adding encrypted key for user", "key", gid, "time", ca, "id", u.CharmID) + log.Info("Adding encrypted key for user", "key", gid, "time", ca, "id", u.CharmID) return me.WrapTransaction(func(tx *sql.Tx) error { u2, err := me.UserForKey(pk, false) if err != nil { @@ -193,7 +479,7 @@ func (me *DB) AddEncryptKeyForPublicKey(u *charm.User, pk string, gid string, ek if err == sql.ErrNoRows { return me.insertEncryptKey(tx, ek, gid, u2.PublicKey.ID, ca) } - log.Debug("Encrypt key already exists for public key, skipping", "key", gid) + log.Info("Encrypt key already exists for public key, skipping", "key", gid) return nil }) } @@ -229,7 +515,7 @@ func (me *DB) EncryptKeysForPublicKey(pk *charm.PublicKey) ([]*charm.EncryptKey, // LinkUserKey links a user to a key. func (me *DB) LinkUserKey(user *charm.User, key string) error { ks := charm.PublicKeySha(key) - log.Debug("Linking user and key", "id", user.CharmID, "key", ks) + log.Info("Linking user and key", "id", user.CharmID, "key", ks) return me.WrapTransaction(func(tx *sql.Tx) error { return me.insertPublicKey(tx, user.ID, key) }) @@ -238,7 +524,7 @@ func (me *DB) LinkUserKey(user *charm.User, key string) error { // UnlinkUserKey unlinks the key from the user. func (me *DB) UnlinkUserKey(user *charm.User, key string) error { ks := charm.PublicKeySha(key) - log.Debug("Unlinking user key", "id", user.CharmID, "key", ks) + log.Info("Unlinking user key", "id", user.CharmID, "key", ks) return me.WrapTransaction(func(tx *sql.Tx) error { err := me.deleteUserPublicKey(tx, user.ID, key) if err != nil { @@ -251,7 +537,7 @@ func (me *DB) UnlinkUserKey(user *charm.User, key string) error { return err } if count == 0 { - log.Debug("Removing last key for account, deleting", "id", user.CharmID) + log.Info("Removing last key for account, deleting", "id", user.CharmID) // nolint: godox // TODO: Where to put glow stuff // err := me.deleteUserStashMarkdown(tx, user.ID) @@ -267,7 +553,7 @@ func (me *DB) UnlinkUserKey(user *charm.User, key string) error { // KeysForUser returns all user's public keys. func (me *DB) KeysForUser(user *charm.User) ([]*charm.PublicKey, error) { var keys []*charm.PublicKey - log.Debug("Getting keys for user", "id", user.CharmID) + log.Info("Getting keys for user", "id", user.CharmID) err := me.WrapTransaction(func(tx *sql.Tx) error { rs, err := me.selectUserPublicKeys(tx, user.ID) if err != nil { @@ -412,44 +698,9 @@ func (me *DB) DeleteToken(token charm.Token) error { }) } -// CreateDB creates the database. -func (me *DB) CreateDB() error { - return me.WrapTransaction(func(tx *sql.Tx) error { - err := me.createUserTable(tx) - if err != nil { - return err - } - err = me.createPublicKeyTable(tx) - if err != nil { - return err - } - err = me.createNamedSeqTable(tx) - if err != nil { - return err - } - err = me.createEncryptKeyTable(tx) - if err != nil { - return err - } - err = me.createNewsTable(tx) - if err != nil { - return err - } - err = me.createNewsTagTable(tx) - if err != nil { - return err - } - err = me.createTokenTable(tx) - if err != nil { - return err - } - return nil - }) -} - // Close the db. func (me *DB) Close() error { - log.Debug("Closing db") + log.Info("Closing db") return me.db.Close() } @@ -592,41 +843,6 @@ func (me *DB) updateMergePublicKeys(tx *sql.Tx, userID1 int, userID2 int) error return err } -func (me *DB) createUserTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateUserTable) - return err -} - -func (me *DB) createPublicKeyTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreatePublicKeyTable) - return err -} - -func (me *DB) createEncryptKeyTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateEncryptKeyTable) - return err -} - -func (me *DB) createNamedSeqTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateNamedSeqTable) - return err -} - -func (me *DB) createNewsTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateNewsTable) - return err -} - -func (me *DB) createNewsTagTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateNewsTagTable) - return err -} - -func (me *DB) createTokenTable(tx *sql.Tx) error { - _, err := tx.Exec(sqlCreateTokenTable) - return err -} - func (me *DB) scanUser(r *sql.Row) (*charm.User, error) { u := &charm.User{} var un, ue, ub sql.NullString @@ -652,6 +868,7 @@ func (me *DB) scanUser(r *sql.Row) (*charm.User, error) { // WrapTransaction runs the given function within a transaction. func (me *DB) WrapTransaction(f func(tx *sql.Tx) error) error { + me.db.Driver() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() tx, err := me.db.BeginTx(ctx, nil) diff --git a/server/http.go b/server/http.go index e308737f..ea7b8fd8 100644 --- a/server/http.go +++ b/server/http.go @@ -116,7 +116,7 @@ func (s *HTTPServer) Start() error { scheme := strings.ToUpper(s.httpScheme) errg, _ := errgroup.WithContext(context.Background()) errg.Go(func() error { - log.Print("Starting health server", "scheme", scheme, "addr", s.health.Addr) + log.Info("Starting health server", "scheme", scheme, "addr", s.health.Addr) if s.cfg.UseTLS { err := s.health.ListenAndServeTLS(s.cfg.TLSCertFile, s.cfg.TLSKeyFile) if err != http.ErrServerClosed { @@ -131,7 +131,7 @@ func (s *HTTPServer) Start() error { return nil }) errg.Go(func() error { - log.Print("Starting server", "scheme", scheme, "addr", s.server.Addr) + log.Info("Starting server", "scheme", scheme, "addr", s.server.Addr) if s.cfg.UseTLS { err := s.server.ListenAndServeTLS(s.cfg.TLSCertFile, s.cfg.TLSKeyFile) if err != http.ErrServerClosed { @@ -151,8 +151,8 @@ func (s *HTTPServer) Start() error { // Shutdown gracefully shut down the HTTP and health servers. func (s *HTTPServer) Shutdown(ctx context.Context) error { scheme := strings.ToUpper(s.httpScheme) - log.Print("Stopping server", "scheme", scheme, "addr", s.server.Addr) - log.Print("Stopping health server", "scheme", scheme, "addr", s.health.Addr) + log.Info("Stopping server", "scheme", scheme, "addr", s.server.Addr) + log.Info("Stopping health server", "scheme", scheme, "addr", s.health.Addr) if err := s.health.Shutdown(ctx); err != nil { return err } diff --git a/server/link.go b/server/link.go index eff5862c..b055fa1a 100644 --- a/server/link.go +++ b/server/link.go @@ -301,7 +301,7 @@ func (me *SSHServer) handleLinkRequestAPI(s ssh.Session) { _ = me.sendAPIMessage(s, fmt.Sprintf("Missing public key %s", err)) return } - log.Print("API link request") + log.Info("API link request") linker := &SSHLinker{ session: s, server: me, @@ -328,7 +328,7 @@ func (me *SSHServer) handleAPILink(s ssh.Session) { func (me *SSHServer) handleAPIUnlink(s ssh.Session) { key, err := keyText(s) if err != nil { - log.Print(err) + log.Info(err) _ = me.sendAPIMessage(s, "Missing key") return } diff --git a/server/ssh.go b/server/ssh.go index e2137607..8926fc35 100644 --- a/server/ssh.go +++ b/server/ssh.go @@ -86,7 +86,7 @@ func NewSSHServer(cfg *Config) (*SSHServer, error) { // Start serves the SSH protocol on the configured port. func (me *SSHServer) Start() error { - log.Print("Starting SSH server", "addr", me.server.Addr) + log.Info("Starting SSH server", "addr", me.server.Addr) if err := me.server.ListenAndServe(); err != ssh.ErrServerClosed { return err } @@ -95,7 +95,7 @@ func (me *SSHServer) Start() error { // Shutdown gracefully shuts down the SSH server. func (me *SSHServer) Shutdown(ctx context.Context) error { - log.Print("Stopping SSH server", "addr", me.server.Addr) + log.Info("Stopping SSH server", "addr", me.server.Addr) return me.server.Shutdown(ctx) } diff --git a/server/stats/prometheus/prometheus.go b/server/stats/prometheus/prometheus.go index 83fb7e07..8b0408ef 100644 --- a/server/stats/prometheus/prometheus.go +++ b/server/stats/prometheus/prometheus.go @@ -59,7 +59,7 @@ func (ps *Stats) Start() error { time.Sleep(time.Minute) } }() - log.Print("Starting Stats HTTP server", "addr", ps.server.Addr) + log.Info("Starting Stats HTTP server", "addr", ps.server.Addr) err := ps.server.ListenAndServe() if err != http.ErrServerClosed { return err From 6866137cfeacb713378634b35158c72af83aee46 Mon Sep 17 00:00:00 2001 From: Drewry Pope Date: Fri, 5 Jan 2024 21:04:50 -0600 Subject: [PATCH 2/4] Update migration.go --- server/db/sqlite/migration/migration.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/server/db/sqlite/migration/migration.go b/server/db/sqlite/migration/migration.go index 80fe60ae..39b5ce6c 100644 --- a/server/db/sqlite/migration/migration.go +++ b/server/db/sqlite/migration/migration.go @@ -7,6 +7,13 @@ import ( "github.com/charmbracelet/log" ) +// Migrations is a list of all migrations. +// The migrations must be in sequence starting from 1. +var Migrations = []Migration{ + Migration0001, + Migration0002, +} + // Migration is a db migration script. type Migration struct { Version int @@ -49,17 +56,15 @@ func safeTime(t *time.Time) string { return "nil" } -var Migrations = []Migration{ - Migration0001, - Migration0002, -} - // Validate validates the migration sequence. // It returns an error if the sequence is not valid. // Each migration must have a unique version number and // the version numbers must be in sequence starting from 1. func Validate() error { log.Info("validating migrations") + // later, this could be changed to ensure all versions are sequential starting from the first item in the array + // this would remove the requirement to have all versions starting from 1. + // this would allow to 'prune' or 'compact' previous versions in some way while continuing the general version scheme. for i, m := range Migrations { if i+1 != m.Version { log.Error("migration is not in sequence", "expected", i+1, "actual", m.Version, "migration", m) From 114d26d5150bc1f15d24403c5c716b860088beee Mon Sep 17 00:00:00 2001 From: Drewry Pope Date: Fri, 5 Jan 2024 21:25:33 -0600 Subject: [PATCH 3/4] Update storage.go --- server/db/sqlite/storage.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/db/sqlite/storage.go b/server/db/sqlite/storage.go index 7e5fad89..458b1e71 100644 --- a/server/db/sqlite/storage.go +++ b/server/db/sqlite/storage.go @@ -91,8 +91,8 @@ func NewDB(path string) *DB { } log.Info("Latest version", "version", latest.Version, "name", *latest.Name, "completed_at", latest.CompletedAt, "error_at", latest.ErrorAt, "comment", latest.Comment, "created_at", latest.CreatedAt, "updated_at", latest.UpdatedAt) if latest.Version != migration.Migrations[len(migration.Migrations)-1].Version { - log.Info("The database may be out of date.", "latest_db_version", latest.Version, "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_db", latest) - log.Info("Latest Code version", "latest_code", migration.Migrations[len(migration.Migrations)-1]) + log.Info("The database may be out of date.", "latest_code_version", migration.Migrations[len(migration.Migrations)-1].Version, "latest_code_version_name", migration.Migrations[len(migration.Migrations)-1].Name, "latest_db_version", latest.Version, "latest_db_version_name", *latest.Name, "latest_db_version_completed_at", latest.CompletedAt, "latest_db_version_error_at", latest.ErrorAt, "latest_db_version_comment", latest.Comment) + log.Debug("Latest Code version", "latest_code", migration.Migrations[len(migration.Migrations)-1]) } incomplete, err := d.IncompleteVersionExists() if err != nil { From 803bb0b21db1cd1ead86dd65824146f28cd56779 Mon Sep 17 00:00:00 2001 From: Drewry Pope Date: Fri, 5 Jan 2024 21:40:00 -0600 Subject: [PATCH 4/4] handle first-load --- server/db/sqlite/storage.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/db/sqlite/storage.go b/server/db/sqlite/storage.go index 458b1e71..6ddbde60 100644 --- a/server/db/sqlite/storage.go +++ b/server/db/sqlite/storage.go @@ -142,7 +142,7 @@ func (me *DB) LatestVersion() (*migration.Version, error) { log.Debug("Scanning latest version row") err := r.Scan(&v.Version, &v.Name, &v.CompletedAt, &v.ErrorAt, &v.Comment, &v.CreatedAt, &v.UpdatedAt) if err != nil { - log.Error("Error getting latest version", "err", err) + log.Debug("Error getting latest version", "err", err) return nil, err } log.Debug("Got latest version", "version", v.Version, "name", *v.Name, "completed_at", v.CompletedAt, "error_at", v.ErrorAt, "comment", v.Comment, "created_at", v.CreatedAt, "updated_at", v.UpdatedAt) @@ -164,7 +164,9 @@ func (me *DB) Migrate() error { latest = &migration.Version{} log.Info("No previous migrations found") } - log.Info("Latest version", "version", latest.Version, "name", *latest.Name, "completed_at", latest.CompletedAt, "error_at", latest.ErrorAt, "comment", latest.Comment, "created_at", latest.CreatedAt, "updated_at", latest.UpdatedAt) + if err == nil { + log.Info("Latest version", "version", latest.Version, "name", *latest.Name, "completed_at", latest.CompletedAt, "error_at", latest.ErrorAt, "comment", latest.Comment, "created_at", latest.CreatedAt, "updated_at", latest.UpdatedAt) + } executedMigrations := 0 skippedMigrations := 0