From 19b7dd5e10a9d4649b384e081c7843d36a72def4 Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Tue, 12 Nov 2024 14:53:36 +0100 Subject: [PATCH 1/7] feat: moved postgres migration files to its own folder --- sql/{ => postgres}/1677078479.sql | 0 sql/{ => postgres}/1677770163.sql | 0 sql/{ => postgres}/1679836136.sql | 0 sql/{ => postgres}/1681821969.sql | 0 sql/{ => postgres}/1684504109.sql | 0 sql/{ => postgres}/1684884904.sql | 0 sql/{ => postgres}/1684918027.sql | 0 sql/{ => postgres}/1684929840.sql | 0 sql/{ => postgres}/1685202737.sql | 0 sql/{ => postgres}/1686048402.sql | 0 sql/{ => postgres}/1686656160.sql | 0 sql/{ => postgres}/1692024707.sql | 0 sql/{ => postgres}/1692105853.sql | 0 sql/{ => postgres}/1692699318.sql | 0 sql/{ => postgres}/1693908172.sql | 0 sql/{ => postgres}/1698074481.sql | 0 sql/{ => postgres}/1698683940.sql | 0 sql/{ => postgres}/1704372039.sql | 0 sql/{ => postgres}/1705562731.sql | 0 sql/{ => postgres}/1705575999.sql | 0 sql/{ => postgres}/1708434555.sql | 0 sql/{ => postgres}/1709568783.sql | 0 sql/{ => postgres}/1709729972.sql | 0 sql/{ => postgres}/1710685343.sql | 0 sql/{ => postgres}/1710763531.sql | 0 sql/{ => postgres}/1715118159.sql | 0 sql/{ => postgres}/1715849591.sql | 0 sql/{ => postgres}/1716983708.sql | 0 sql/{ => postgres}/1717504961.sql | 0 sql/{ => postgres}/1718819653.sql | 0 sql/{ => postgres}/1719521272.sql | 0 sql/{ => postgres}/1719521273.sql | 0 sql/{ => postgres}/1721127636.sql | 0 sql/{ => postgres}/1724236900.sql | 0 sql/{ => postgres}/1724932863.sql | 0 sql/{ => postgres}/1729585620.sql | 0 sql/{ => postgres}/1729709223.sql | 0 sql/{ => postgres}/1729941918.sql | 0 sql/{ => postgres}/1730027360.sql | 0 sql/{ => postgres}/1730719365.sql | 0 40 files changed, 0 insertions(+), 0 deletions(-) rename sql/{ => postgres}/1677078479.sql (100%) rename sql/{ => postgres}/1677770163.sql (100%) rename sql/{ => postgres}/1679836136.sql (100%) rename sql/{ => postgres}/1681821969.sql (100%) rename sql/{ => postgres}/1684504109.sql (100%) rename sql/{ => postgres}/1684884904.sql (100%) rename sql/{ => postgres}/1684918027.sql (100%) rename sql/{ => postgres}/1684929840.sql (100%) rename sql/{ => postgres}/1685202737.sql (100%) rename sql/{ => postgres}/1686048402.sql (100%) rename sql/{ => postgres}/1686656160.sql (100%) rename sql/{ => postgres}/1692024707.sql (100%) rename sql/{ => postgres}/1692105853.sql (100%) rename sql/{ => postgres}/1692699318.sql (100%) rename sql/{ => postgres}/1693908172.sql (100%) rename sql/{ => postgres}/1698074481.sql (100%) rename sql/{ => postgres}/1698683940.sql (100%) rename sql/{ => postgres}/1704372039.sql (100%) rename sql/{ => postgres}/1705562731.sql (100%) rename sql/{ => postgres}/1705575999.sql (100%) rename sql/{ => postgres}/1708434555.sql (100%) rename sql/{ => postgres}/1709568783.sql (100%) rename sql/{ => postgres}/1709729972.sql (100%) rename sql/{ => postgres}/1710685343.sql (100%) rename sql/{ => postgres}/1710763531.sql (100%) rename sql/{ => postgres}/1715118159.sql (100%) rename sql/{ => postgres}/1715849591.sql (100%) rename sql/{ => postgres}/1716983708.sql (100%) rename sql/{ => postgres}/1717504961.sql (100%) rename sql/{ => postgres}/1718819653.sql (100%) rename sql/{ => postgres}/1719521272.sql (100%) rename sql/{ => postgres}/1719521273.sql (100%) rename sql/{ => postgres}/1721127636.sql (100%) rename sql/{ => postgres}/1724236900.sql (100%) rename sql/{ => postgres}/1724932863.sql (100%) rename sql/{ => postgres}/1729585620.sql (100%) rename sql/{ => postgres}/1729709223.sql (100%) 
rename sql/{ => postgres}/1729941918.sql (100%) rename sql/{ => postgres}/1730027360.sql (100%) rename sql/{ => postgres}/1730719365.sql (100%) diff --git a/sql/1677078479.sql b/sql/postgres/1677078479.sql similarity index 100% rename from sql/1677078479.sql rename to sql/postgres/1677078479.sql diff --git a/sql/1677770163.sql b/sql/postgres/1677770163.sql similarity index 100% rename from sql/1677770163.sql rename to sql/postgres/1677770163.sql diff --git a/sql/1679836136.sql b/sql/postgres/1679836136.sql similarity index 100% rename from sql/1679836136.sql rename to sql/postgres/1679836136.sql diff --git a/sql/1681821969.sql b/sql/postgres/1681821969.sql similarity index 100% rename from sql/1681821969.sql rename to sql/postgres/1681821969.sql diff --git a/sql/1684504109.sql b/sql/postgres/1684504109.sql similarity index 100% rename from sql/1684504109.sql rename to sql/postgres/1684504109.sql diff --git a/sql/1684884904.sql b/sql/postgres/1684884904.sql similarity index 100% rename from sql/1684884904.sql rename to sql/postgres/1684884904.sql diff --git a/sql/1684918027.sql b/sql/postgres/1684918027.sql similarity index 100% rename from sql/1684918027.sql rename to sql/postgres/1684918027.sql diff --git a/sql/1684929840.sql b/sql/postgres/1684929840.sql similarity index 100% rename from sql/1684929840.sql rename to sql/postgres/1684929840.sql diff --git a/sql/1685202737.sql b/sql/postgres/1685202737.sql similarity index 100% rename from sql/1685202737.sql rename to sql/postgres/1685202737.sql diff --git a/sql/1686048402.sql b/sql/postgres/1686048402.sql similarity index 100% rename from sql/1686048402.sql rename to sql/postgres/1686048402.sql diff --git a/sql/1686656160.sql b/sql/postgres/1686656160.sql similarity index 100% rename from sql/1686656160.sql rename to sql/postgres/1686656160.sql diff --git a/sql/1692024707.sql b/sql/postgres/1692024707.sql similarity index 100% rename from sql/1692024707.sql rename to sql/postgres/1692024707.sql diff --git a/sql/1692105853.sql b/sql/postgres/1692105853.sql similarity index 100% rename from sql/1692105853.sql rename to sql/postgres/1692105853.sql diff --git a/sql/1692699318.sql b/sql/postgres/1692699318.sql similarity index 100% rename from sql/1692699318.sql rename to sql/postgres/1692699318.sql diff --git a/sql/1693908172.sql b/sql/postgres/1693908172.sql similarity index 100% rename from sql/1693908172.sql rename to sql/postgres/1693908172.sql diff --git a/sql/1698074481.sql b/sql/postgres/1698074481.sql similarity index 100% rename from sql/1698074481.sql rename to sql/postgres/1698074481.sql diff --git a/sql/1698683940.sql b/sql/postgres/1698683940.sql similarity index 100% rename from sql/1698683940.sql rename to sql/postgres/1698683940.sql diff --git a/sql/1704372039.sql b/sql/postgres/1704372039.sql similarity index 100% rename from sql/1704372039.sql rename to sql/postgres/1704372039.sql diff --git a/sql/1705562731.sql b/sql/postgres/1705562731.sql similarity index 100% rename from sql/1705562731.sql rename to sql/postgres/1705562731.sql diff --git a/sql/1705575999.sql b/sql/postgres/1705575999.sql similarity index 100% rename from sql/1705575999.sql rename to sql/postgres/1705575999.sql diff --git a/sql/1708434555.sql b/sql/postgres/1708434555.sql similarity index 100% rename from sql/1708434555.sql rename to sql/postgres/1708434555.sql diff --git a/sql/1709568783.sql b/sql/postgres/1709568783.sql similarity index 100% rename from sql/1709568783.sql rename to sql/postgres/1709568783.sql diff --git a/sql/1709729972.sql 
b/sql/postgres/1709729972.sql
similarity index 100%
rename from sql/1709729972.sql
rename to sql/postgres/1709729972.sql
diff --git a/sql/1710685343.sql b/sql/postgres/1710685343.sql
similarity index 100%
rename from sql/1710685343.sql
rename to sql/postgres/1710685343.sql
diff --git a/sql/1710763531.sql b/sql/postgres/1710763531.sql
similarity index 100%
rename from sql/1710763531.sql
rename to sql/postgres/1710763531.sql
diff --git a/sql/1715118159.sql b/sql/postgres/1715118159.sql
similarity index 100%
rename from sql/1715118159.sql
rename to sql/postgres/1715118159.sql
diff --git a/sql/1715849591.sql b/sql/postgres/1715849591.sql
similarity index 100%
rename from sql/1715849591.sql
rename to sql/postgres/1715849591.sql
diff --git a/sql/1716983708.sql b/sql/postgres/1716983708.sql
similarity index 100%
rename from sql/1716983708.sql
rename to sql/postgres/1716983708.sql
diff --git a/sql/1717504961.sql b/sql/postgres/1717504961.sql
similarity index 100%
rename from sql/1717504961.sql
rename to sql/postgres/1717504961.sql
diff --git a/sql/1718819653.sql b/sql/postgres/1718819653.sql
similarity index 100%
rename from sql/1718819653.sql
rename to sql/postgres/1718819653.sql
diff --git a/sql/1719521272.sql b/sql/postgres/1719521272.sql
similarity index 100%
rename from sql/1719521272.sql
rename to sql/postgres/1719521272.sql
diff --git a/sql/1719521273.sql b/sql/postgres/1719521273.sql
similarity index 100%
rename from sql/1719521273.sql
rename to sql/postgres/1719521273.sql
diff --git a/sql/1721127636.sql b/sql/postgres/1721127636.sql
similarity index 100%
rename from sql/1721127636.sql
rename to sql/postgres/1721127636.sql
diff --git a/sql/1724236900.sql b/sql/postgres/1724236900.sql
similarity index 100%
rename from sql/1724236900.sql
rename to sql/postgres/1724236900.sql
diff --git a/sql/1724932863.sql b/sql/postgres/1724932863.sql
similarity index 100%
rename from sql/1724932863.sql
rename to sql/postgres/1724932863.sql
diff --git a/sql/1729585620.sql b/sql/postgres/1729585620.sql
similarity index 100%
rename from sql/1729585620.sql
rename to sql/postgres/1729585620.sql
diff --git a/sql/1729709223.sql b/sql/postgres/1729709223.sql
similarity index 100%
rename from sql/1729709223.sql
rename to sql/postgres/1729709223.sql
diff --git a/sql/1729941918.sql b/sql/postgres/1729941918.sql
similarity index 100%
rename from sql/1729941918.sql
rename to sql/postgres/1729941918.sql
diff --git a/sql/1730027360.sql b/sql/postgres/1730027360.sql
similarity index 100%
rename from sql/1730027360.sql
rename to sql/postgres/1730027360.sql
diff --git a/sql/1730719365.sql b/sql/postgres/1730719365.sql
similarity index 100%
rename from sql/1730719365.sql
rename to sql/postgres/1730719365.sql

From 83a1c99dab4765bf1bb8834d0e9b88acd3488a71 Mon Sep 17 00:00:00 2001
From: Raymond Tukpe
Date: Wed, 13 Nov 2024 10:52:15 +0100
Subject: [PATCH 2/7] feat: added base schema

---
 cmd/hooks/hooks.go                |   2 +-
 cmd/migrate/migrate.go            |  87 +++-
 config/config.go                  |   2 +
 database/sqlite3/sqlite3.go       |  66 ++-
 internal/pkg/migrator/migrator.go |  26 +-
 sql/sqlite3/1731413863.sql        | 712 ++++++++++++++++++++++++++++++
 type.go                           |  10 +-
 7 files changed, 869 insertions(+), 36 deletions(-)
 create mode 100644 sql/sqlite3/1731413863.sql

diff --git a/cmd/hooks/hooks.go b/cmd/hooks/hooks.go
index 9dee06fc94..ed791dd1bc 100644
--- a/cmd/hooks/hooks.go
+++ b/cmd/hooks/hooks.go
@@ -667,7 +667,7 @@ func checkPendingMigrations(db database.Database) error {
 	}
 
 	counter := map[string]ID{}
-	files, err := convoy.MigrationFiles.ReadDir("sql")
+	files, err := convoy.PostgresMigrationFiles.ReadDir("sql/postgres")
 	if err != nil {
 		return err
 	}
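A note on the hunk above: once the Postgres migrations live under sql/postgres, the pending-migrations check has to read that subdirectory of the embedded FS rather than sql/ itself. The sketch below is a minimal standalone illustration of the same idea, not code from this series; the `applied` set stands in for IDs already recorded in gorp_migrations, and the program assumes it is built inside a module that actually contains sql/postgres/*.sql:

    package main

    import (
    	"embed"
    	"fmt"
    	"strings"
    )

    //go:embed sql/postgres/*.sql
    var postgresMigrations embed.FS

    // pendingMigrations lists embedded migration IDs not yet marked applied.
    func pendingMigrations(applied map[string]bool) ([]string, error) {
    	files, err := postgresMigrations.ReadDir("sql/postgres")
    	if err != nil {
    		return nil, err
    	}

    	var pending []string
    	for _, f := range files {
    		// Migration IDs are the bare unix timestamps in the file names.
    		id := strings.TrimSuffix(f.Name(), ".sql")
    		if !applied[id] {
    			pending = append(pending, id)
    		}
    	}
    	return pending, nil
    }

    func main() {
    	// Hypothetical applied set; in Convoy this comes from gorp_migrations.
    	pending, err := pendingMigrations(map[string]bool{"1677078479": true})
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(pending)
    }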
diff --git a/cmd/migrate/migrate.go b/cmd/migrate/migrate.go
index 97b3525150..efb84984bd 100644
--- a/cmd/migrate/migrate.go
+++ b/cmd/migrate/migrate.go
@@ -2,6 +2,7 @@ package migrate
 
 import (
 	"fmt"
+	"github.com/frain-dev/convoy/database/sqlite3"
 	"os"
 	"time"
 
@@ -13,6 +14,11 @@ import (
 	"github.com/spf13/cobra"
 )
 
+var mapping = map[string]string{
+	"agent":  "postgres",
+	"server": "sqlite3",
+}
+
 func AddMigrateCommand(a *cli.App) *cobra.Command {
 	cmd := &cobra.Command{
 		Use:   "migrate",
@@ -27,6 +33,7 @@ func AddMigrateCommand(a *cli.App) *cobra.Command {
 }
 
 func addUpCommand() *cobra.Command {
+	var component string
 	cmd := &cobra.Command{
 		Use:     "up",
 		Aliases: []string{"migrate-up"},
@@ -36,32 +43,65 @@ func addUpCommand() *cobra.Command {
 			"ShouldBootstrap": "false",
 		},
 		Run: func(cmd *cobra.Command, args []string) {
-			cfg, err := config.Get()
+			t, err := cmd.Flags().GetString("component")
 			if err != nil {
-				log.WithError(err).Fatal("Error fetching the config.")
+				log.Fatal(err)
 			}
 
-			db, err := postgres.NewDB(cfg)
-			if err != nil {
-				log.Fatal(err)
+			if t != "server" && t != "agent" {
+				log.Fatalf("Invalid component %s. Must be one of: server or agent", t)
 			}
 
-			defer db.Close()
+			switch t {
+			case "server":
+				cfg, err := config.Get()
+				if err != nil {
+					log.WithError(err).Fatal("[sqlite3] error fetching the config.")
+				}
 
-			m := migrator.New(db)
-			err = m.Up()
-			if err != nil {
-				log.Fatalf("migration up failed with error: %+v", err)
+				db, err := sqlite3.NewDB(cfg.Database.SqliteDB, log.NewLogger(os.Stdout))
+				if err != nil {
+					log.Fatal(err)
+				}
+
+				defer db.Close()
+
+				m := migrator.New(db, "sqlite3")
+				err = m.Up()
+				if err != nil {
+					log.Fatalf("[sqlite3] migration up failed with error: %+v", err)
+				}
+			case "agent":
+				cfg, err := config.Get()
+				if err != nil {
+					log.WithError(err).Fatal("[postgres] error fetching the config.")
+				}
+
+				db, err := postgres.NewDB(cfg)
+				if err != nil {
+					log.Fatal(err)
+				}
+
+				defer db.Close()
+
+				m := migrator.New(db, "postgres")
+				err = m.Up()
+				if err != nil {
+					log.Fatalf("[postgres] migration up failed with error: %+v", err)
+				}
 			}
+
 			log.Info("migration up succeeded")
 		},
 	}
 
+	cmd.Flags().StringVarP(&component, "component", "c", "server", "The component to run migrations for: (server|agent)")
+
 	return cmd
 }
 
 func addDownCommand() *cobra.Command {
-	var max int
+	var maxDown int
 
 	cmd := &cobra.Command{
 		Use:   "down",
@@ -84,29 +124,40 @@ func addDownCommand() *cobra.Command {
 
 			defer db.Close()
 
-			m := migrator.New(db)
-			err = m.Down(max)
+			m := migrator.New(db, "postgres")
+			err = m.Down(maxDown)
 			if err != nil {
 				log.Fatalf("migration down failed with error: %+v", err)
 			}
 		},
 	}
 
-	cmd.Flags().IntVar(&max, "max", 1, "The maximum number of migrations to rollback")
+	cmd.Flags().IntVar(&maxDown, "max", 1, "The maximum number of migrations to roll back")
 
 	return cmd
 }
 
 func addCreateCommand() *cobra.Command {
+	var component string
 	cmd := &cobra.Command{
-		Use:   "create",
-		Short: "creates a new migration file",
+		Use:     "create",
+		Aliases: []string{"migrate-create"},
+		Short:   "creates a new migration file",
 		Annotations: map[string]string{
 			"CheckMigration":  "false",
 			"ShouldBootstrap": "false",
 		},
 		Run: func(cmd *cobra.Command, args []string) {
-			fileName := fmt.Sprintf("sql/%v.sql", time.Now().Unix())
+			t, err := cmd.Flags().GetString("component")
+			if err != nil {
+				log.Fatal(err)
+			}
+
+			if t != "server" && t != "agent" {
+				log.Fatalf("Invalid component %s. Must be one of: server or agent", t)
+			}
+
+			fileName := fmt.Sprintf("sql/%s/%v.sql", mapping[component], time.Now().Unix())
 			f, err := os.Create(fileName)
 			if err != nil {
 				log.Fatal(err)
@@ -124,5 +175,7 @@ func addCreateCommand() *cobra.Command {
 		},
 	}
 
+	cmd.Flags().StringVarP(&component, "component", "c", "server", "The component to create for: (server|agent)")
+
 	return cmd
 }
diff --git a/config/config.go b/config/config.go
index 4eb8a4efb4..bb1a36bc9c 100644
--- a/config/config.go
+++ b/config/config.go
@@ -120,6 +120,8 @@ var DefaultConfiguration = Configuration{
 
 type DatabaseConfiguration struct {
 	Type DatabaseProvider `json:"type" envconfig:"CONVOY_DB_TYPE"`
+	SqliteDB string `json:"sqlite_db" envconfig:"CONVOY_SQLITE_DB"`
+
 	Scheme   string `json:"scheme" envconfig:"CONVOY_DB_SCHEME"`
 	Host     string `json:"host" envconfig:"CONVOY_DB_HOST"`
 	Username string `json:"username" envconfig:"CONVOY_DB_USERNAME"`
diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go
index 5c0850ed83..85e141926a 100644
--- a/database/sqlite3/sqlite3.go
+++ b/database/sqlite3/sqlite3.go
@@ -1,7 +1,13 @@
 package sqlite3
 
 import (
+	"context"
+	"database/sql"
+	"errors"
 	"fmt"
+	"github.com/frain-dev/convoy/database/hooks"
+	"github.com/frain-dev/convoy/pkg/log"
+	"io"
 
 	"github.com/jmoiron/sqlx"
 	_ "github.com/mattn/go-sqlite3"
@@ -10,17 +16,69 @@ import (
 const pkgName = "sqlite3"
 
 type Sqlite struct {
-	dbx *sqlx.DB
+	dbx    *sqlx.DB
+	hook   *hooks.Hook
+	logger *log.Logger
 }
 
-func NewDB() (*Sqlite, error) {
-	db, err := sqlx.Connect("sqlite3", "convoy.db")
+func (s *Sqlite) BeginTx(ctx context.Context) (*sqlx.Tx, error) {
+	return s.dbx.BeginTxx(ctx, nil)
+}
+
+func (s *Sqlite) GetHook() *hooks.Hook {
+	if s.hook != nil {
+		return s.hook
+	}
+
+	hook, err := hooks.Get()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	s.hook = hook
+	return s.hook
+}
+
+// Rollback commits the transaction when err is nil and rolls it back otherwise.
+func (s *Sqlite) Rollback(tx *sqlx.Tx, err error) {
+	if err != nil {
+		rbErr := tx.Rollback()
+		if rbErr != nil {
+			log.WithError(rbErr).Error("failed to roll back transaction")
+		}
+		return
+	}
+
+	cmErr := tx.Commit()
+	if cmErr != nil && !errors.Is(cmErr, sql.ErrTxDone) {
+		log.WithError(cmErr).Error("failed to commit tx, rolling back transaction")
+		rbErr := tx.Rollback()
+		if rbErr != nil {
+			log.WithError(rbErr).Error("failed to roll back transaction")
+		}
+	}
+}
+
+func (s *Sqlite) Close() error {
+	return s.dbx.Close()
+}
+
+func NewDB(dbName string, logger *log.Logger) (*Sqlite, error) {
+	db, err := sqlx.Connect("sqlite3", dbName)
 	if err != nil {
 		return nil, fmt.Errorf("[%s]: failed to open database - %v", pkgName, err)
 	}
 
-	return &Sqlite{dbx: db}, nil
+	return &Sqlite{dbx: db, logger: logger}, nil
 }
 
 func (s *Sqlite) GetDB() *sqlx.DB {
 	return s.dbx
 }
+
+func rollbackTx(tx *sqlx.Tx) {
+	err := tx.Rollback()
+	if err != nil && !errors.Is(err, sql.ErrTxDone) {
+		log.WithError(err).Error("failed to rollback tx")
+	}
+}
+
+func closeWithError(closer io.Closer) {
+	err := closer.Close()
+	if err != nil {
+		fmt.Printf("an error occurred while closing the client: %v\n", err)
+	}
+}
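The Sqlite handle above mirrors the existing Postgres one: BeginTx opens a sqlx transaction, and Rollback(tx, err) commits when the caller's error is nil and rolls back otherwise. A short caller sketch, assuming a local convoy.db file; the UPDATE statement is purely illustrative and not part of the patch:

    package main

    import (
    	"context"
    	"os"

    	"github.com/frain-dev/convoy/database/sqlite3"
    	"github.com/frain-dev/convoy/pkg/log"
    )

    func main() {
    	db, err := sqlite3.NewDB("convoy.db", log.NewLogger(os.Stdout))
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer db.Close()

    	ctx := context.Background()
    	tx, err := db.BeginTx(ctx)
    	if err != nil {
    		log.Fatal(err)
    	}

    	// Illustrative work inside the transaction.
    	_, opErr := tx.ExecContext(ctx, `update configurations set is_signup_enabled = false`)

    	// Commits when opErr is nil, rolls back otherwise.
    	db.Rollback(tx, opErr)
    }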
diff --git a/internal/pkg/migrator/migrator.go b/internal/pkg/migrator/migrator.go
index a8942a02cd..22b794dfdc 100644
--- a/internal/pkg/migrator/migrator.go
+++ b/internal/pkg/migrator/migrator.go
@@ -12,30 +12,36 @@ var (
 )
 
 type Migrator struct {
-	dbx *sqlx.DB
-	src migrate.MigrationSource
+	dbx     *sqlx.DB
+	dialect string
+	src     migrate.MigrationSource
 }
 
-func New(d database.Database) *Migrator {
+func New(d database.Database, dialect string) *Migrator {
 	migrations := &migrate.EmbedFileSystemMigrationSource{
-		FileSystem: convoy.MigrationFiles,
-		Root:       "sql",
+		FileSystem: convoy.SQLiteMigrationFiles,
+		Root:       "sql/sqlite3",
 	}
 
-	migrate.SetSchema(tableSchema)
-	return &Migrator{dbx: d.GetDB(), src: migrations}
+	if dialect == "postgres" {
+		migrations.FileSystem = convoy.PostgresMigrationFiles
+		migrations.Root = "sql/postgres"
+		migrate.SetSchema(tableSchema)
+	}
+
+	return &Migrator{dbx: d.GetDB(), src: migrations, dialect: dialect}
 }
 
 func (m *Migrator) Up() error {
-	_, err := migrate.Exec(m.dbx.DB, "postgres", m.src, migrate.Up)
+	_, err := migrate.Exec(m.dbx.DB, m.dialect, m.src, migrate.Up)
 	if err != nil {
 		return err
 	}
 	return nil
 }
 
-func (m *Migrator) Down(max int) error {
-	_, err := migrate.ExecMax(m.dbx.DB, "postgres", m.src, migrate.Down, max)
+func (m *Migrator) Down(maxDown int) error {
+	_, err := migrate.ExecMax(m.dbx.DB, m.dialect, m.src, migrate.Down, maxDown)
 	if err != nil {
 		return err
 	}
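Note that migrator.New now defaults to the embedded SQLite migrations and only calls migrate.SetSchema for Postgres; SQLite has no schema namespaces, so the gorp_migrations bookkeeping table simply lives in the main database file. A hedged sketch of how the component flag could map onto a handle-plus-dialect pair; the helper name and error text are assumptions, and the real wiring is the cmd/migrate/migrate.go hunk above:

    package main

    import (
    	"fmt"
    	"os"

    	"github.com/frain-dev/convoy/config"
    	"github.com/frain-dev/convoy/database/postgres"
    	"github.com/frain-dev/convoy/database/sqlite3"
    	"github.com/frain-dev/convoy/internal/pkg/migrator"
    	"github.com/frain-dev/convoy/pkg/log"
    )

    // newMigratorFor is illustrative only and would have to live inside the
    // Convoy module, since internal/pkg/migrator is not importable elsewhere.
    func newMigratorFor(component string, cfg config.Configuration) (*migrator.Migrator, error) {
    	switch component {
    	case "server":
    		db, err := sqlite3.NewDB(cfg.Database.SqliteDB, log.NewLogger(os.Stdout))
    		if err != nil {
    			return nil, err
    		}
    		return migrator.New(db, "sqlite3"), nil
    	case "agent":
    		db, err := postgres.NewDB(cfg)
    		if err != nil {
    			return nil, err
    		}
    		return migrator.New(db, "postgres"), nil
    	default:
    		return nil, fmt.Errorf("unknown component %q", component)
    	}
    }

    func main() {
    	cfg, err := config.Get()
    	if err != nil {
    		log.Fatal(err)
    	}

    	m, err := newMigratorFor("server", cfg)
    	if err != nil {
    		log.Fatal(err)
    	}

    	if err := m.Up(); err != nil {
    		log.Fatal(err)
    	}
    }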
diff --git a/sql/sqlite3/1731413863.sql b/sql/sqlite3/1731413863.sql
new file mode 100644
index 0000000000..7d6ca8eb77
--- /dev/null
+++ b/sql/sqlite3/1731413863.sql
@@ -0,0 +1,712 @@
+-- +migrate Up
+create table if not exists configurations
+(
+    id TEXT not null,
+    is_analytics_enabled BOOLEAN not null,
+    is_signup_enabled BOOLEAN not null,
+    storage_policy_type TEXT not null,
+    on_prem_path TEXT,
+    s3_bucket TEXT,
+    s3_access_key TEXT,
+    s3_secret_key TEXT,
+    s3_region TEXT,
+    s3_session_token TEXT,
+    s3_endpoint TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    s3_prefix TEXT,
+    retention_policy_enabled BOOLEAN default false not null,
+    retention_policy_policy TEXT default '720h' not null,
+    cb_minimum_request_count INTEGER default 10 not null,
+    cb_sample_rate INTEGER default 30 not null,
+    cb_error_timeout INTEGER default 30 not null,
+    cb_failure_threshold INTEGER default 70 not null,
+    cb_success_threshold INTEGER default 1 not null,
+    cb_observability_window INTEGER default 30 not null,
+    cb_consecutive_failure_threshold INTEGER default 10 not null
+);
+
+create table if not exists events_endpoints
+(
+    event_id TEXT not null,
+    endpoint_id TEXT not null,
+    FOREIGN KEY(endpoint_id) REFERENCES endpoints(id),
+    FOREIGN KEY(event_id) REFERENCES events(id)
+);
+
+create index if not exists events_endpoints_temp_endpoint_id_idx
+    on events_endpoints (endpoint_id);
+
+create unique index if not exists events_endpoints_temp_event_id_endpoint_id_idx1
+    on events_endpoints (event_id, endpoint_id);
+
+create index if not exists events_endpoints_temp_event_id_idx
+    on events_endpoints (event_id);
+
+create unique index if not exists idx_uq_constraint_events_endpoints_event_id_endpoint_id
+    on events_endpoints (event_id, endpoint_id);
+
+create table if not exists gorp_migrations
+(
+    id TEXT not null primary key,
+    applied_at TEXT
+);
+
+create table if not exists project_configurations
+(
+    id TEXT not null primary key,
+    max_payload_read_size INTEGER not null,
+    replay_attacks_prevention_enabled BOOLEAN not null,
+    ratelimit_count INTEGER not null,
+    ratelimit_duration INTEGER not null,
+    strategy_type TEXT not null,
+    strategy_duration INTEGER not null,
+    strategy_retry_count INTEGER not null,
+    signature_header TEXT not null,
+    signature_versions TEXT not null,
+    disable_endpoint BOOLEAN default false not null,
+    meta_events_enabled BOOLEAN default false not null,
+    meta_events_type TEXT,
+    meta_events_event_type TEXT,
+    meta_events_url TEXT,
+    meta_events_secret TEXT,
+    meta_events_pub_sub TEXT,
+    search_policy TEXT default '720h',
+    multiple_endpoint_subscriptions BOOLEAN default false not null,
+    ssl_enforce_secure_endpoints BOOLEAN default true,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT
+);
+
+create table if not exists source_verifiers
+(
+    id TEXT not null primary key,
+    type TEXT not null,
+    basic_username TEXT,
+    basic_password TEXT,
+    api_key_header_name TEXT,
+    api_key_header_value TEXT,
+    hmac_hash TEXT,
+    hmac_header TEXT,
+    hmac_secret TEXT,
+    hmac_encoding TEXT,
+    twitter_crc_verified_at TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT
+);
+
+create table if not exists token_bucket
+(
+    key TEXT not null primary key,
+    rate INTEGER not null,
+    tokens INTEGER default 1,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    expires_at TEXT not null
+);
+
+create table if not exists users
+(
+    id TEXT not null primary key,
+    first_name TEXT not null,
+    last_name TEXT not null,
+    email TEXT not null,
+    password TEXT not null,
+    email_verified BOOLEAN not null,
+    reset_password_token TEXT,
+    email_verification_token TEXT,
+    reset_password_expires_at TEXT,
+    email_verification_expires_at TEXT,
+    auth_type TEXT default 'local' not null,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    constraint users_email_key
+        unique (email, deleted_at)
+);
+
+create table if not exists organisations
+(
+    id TEXT not null primary key,
+    name TEXT not null,
+    owner_id TEXT not null,
+    custom_domain TEXT,
+    assigned_domain TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    FOREIGN KEY(owner_id) REFERENCES users(id)
+);
+
+create unique index if not exists organisations_custom_domain
+    on organisations (custom_domain, assigned_domain)
+    where (deleted_at IS NULL);
+
+create table if not exists projects
+(
+    id TEXT not null primary key,
+    name TEXT not null,
+    type TEXT not null,
+    logo_url TEXT,
+    retained_events INTEGER default 0,
+    organisation_id TEXT not null,
+    project_configuration_id TEXT not null,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    constraint name_org_id_key unique (name, organisation_id, deleted_at),
+    FOREIGN KEY(organisation_id) REFERENCES organisations(id),
+    FOREIGN KEY(project_configuration_id) REFERENCES project_configurations(id)
+);
+
+-- todo(raymond): deprecate me
+create table if not exists applications
+(
+    id TEXT not null primary key,
+    project_id TEXT not null,
+    title TEXT not null,
+    support_email TEXT,
+    slack_webhook_url TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    FOREIGN KEY(project_id) REFERENCES projects(id)
+);
+
+-- todo(raymond): deprecate me
+create table if not exists devices
+(
+    id TEXT not null primary key,
+    project_id TEXT not null,
+    host_name TEXT not null,
+    status TEXT not null,
+    created_at TEXT
not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + last_seen_at TEXT not null, + deleted_at TEXT, + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists endpoints +( + id TEXT not null primary key, + status TEXT not null, + owner_id TEXT, + description TEXT, + rate_limit INTEGER not null, + advanced_signatures BOOLEAN not null, + slack_webhook_url TEXT, + support_email TEXT, + app_id TEXT, + project_id TEXT not null, + authentication_type TEXT, + authentication_type_api_key_header_name TEXT, + authentication_type_api_key_header_value TEXT, + secrets TEXT not null, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + http_timeout INTEGER not null, + rate_limit_duration INTEGER not null, + name TEXT not null, + url TEXT not null, + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists api_keys +( + id TEXT not null primary key, + name TEXT not null, + key_type TEXT not null, + mask_id TEXT not null, + role_type TEXT, + role_project TEXT, + role_endpoint TEXT, + hash TEXT not null, + salt TEXT not null, + user_id TEXT, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + expires_at TEXT, + deleted_at TEXT, + constraint api_keys_mask_id_key unique (mask_id, deleted_at), + FOREIGN KEY(role_project) REFERENCES projects(id), + FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), + FOREIGN KEY(user_id) REFERENCES users(id) +); + +create index if not exists idx_api_keys_mask_id + on api_keys (mask_id); + +create index if not exists idx_endpoints_app_id_key + on endpoints (app_id); + +create index if not exists idx_endpoints_owner_id_key + on endpoints (owner_id); + +create index if not exists idx_endpoints_project_id_key + on endpoints (project_id); + +create table if not exists event_types +( + id TEXT not null primary key, + name TEXT not null, + description TEXT, + project_id TEXT not null, + category TEXT, + created_at TEXT default (strftime('%Y-%m-%dT%H:%M:%fZ')) not null, + updated_at TEXT default (strftime('%Y-%m-%dT%H:%M:%fZ')) not null, + deprecated_at TEXT, + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create index if not exists idx_event_types_category + on event_types (category); + +create index if not exists idx_event_types_category_deprecated + on event_types (category) + where (deprecated_at IS NOT NULL); + +create index if not exists idx_event_types_category_not_deprecated + on event_types (category) + where (deprecated_at IS NULL); + +create index if not exists idx_event_types_name + on event_types (name); + +create index if not exists idx_event_types_name_deprecated + on event_types (name) + where (deprecated_at IS NOT NULL); + +create index if not exists idx_event_types_name_not_deprecated + on event_types (name) + where (deprecated_at IS NULL); + +create table if not exists jobs +( + id TEXT not null primary key, + type TEXT not null, + status TEXT not null, + project_id TEXT not null, + started_at TEXT, + completed_at TEXT, + failed_at TEXT, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists meta_events +( + id TEXT not null primary key, + event_type TEXT not null, 
+ project_id TEXT not null, + metadata TEXT not null, + attempt TEXT, + status TEXT not null, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists organisation_invites +( + id TEXT not null primary key, + organisation_id TEXT not null, + invitee_email TEXT not null, + token TEXT not null, + role_type TEXT not null, + role_project TEXT, + role_endpoint TEXT, + status TEXT not null, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + expires_at TEXT not null, + deleted_at TEXT, + constraint organisation_invites_token_key unique (token, deleted_at), + FOREIGN KEY(role_project) REFERENCES projects(id), + FOREIGN KEY(organisation_id) REFERENCES organisations(id), + FOREIGN KEY(role_endpoint) REFERENCES endpoints(id) +); + +create index if not exists idx_organisation_invites_token_key + on organisation_invites (token); + +create unique index if not exists organisation_invites_invitee_email + on organisation_invites (organisation_id, invitee_email, deleted_at); + +create table if not exists organisation_members +( + id TEXT not null primary key, + role_type TEXT not null, + role_project TEXT, + role_endpoint TEXT, + user_id TEXT not null, + organisation_id TEXT not null, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + constraint organisation_members_user_id_org_id_key + unique (organisation_id, user_id, deleted_at), + FOREIGN KEY(role_project) REFERENCES projects(id), + FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), + FOREIGN KEY(organisation_id) REFERENCES organisations(id), + FOREIGN KEY(user_id) REFERENCES users(id) +); + +create index if not exists idx_organisation_members_deleted_at_key + on organisation_members (deleted_at); + +create index if not exists idx_organisation_members_organisation_id_key + on organisation_members (organisation_id); + +create index if not exists idx_organisation_members_user_id_key + on organisation_members (user_id); + +create table if not exists portal_links +( + id TEXT not null primary key, + project_id TEXT not null, + name TEXT not null, + token TEXT not null, + endpoints TEXT, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + owner_id TEXT, + can_manage_endpoint BOOLEAN default false, + constraint portal_links_token + unique (token, deleted_at), + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create index if not exists idx_portal_links_owner_id_key + on portal_links (owner_id); + +create index if not exists idx_portal_links_project_id + on portal_links (project_id); + +create index if not exists idx_portal_links_token + on portal_links (token); + +create table if not exists portal_links_endpoints +( + portal_link_id TEXT not null, + endpoint_id TEXT not null, + FOREIGN KEY(portal_link_id) REFERENCES portal_links(id), + FOREIGN KEY(endpoint_id) REFERENCES endpoints(id) +); + +create index if not exists idx_portal_links_endpoints_enpdoint_id + on portal_links_endpoints (endpoint_id); + +create index if not exists idx_portal_links_endpoints_portal_link_id + on portal_links_endpoints (portal_link_id); + +create table if not exists sources +( + id TEXT not null 
primary key,
+    name TEXT not null,
+    type TEXT not null,
+    mask_id TEXT not null,
+    provider TEXT not null,
+    is_disabled BOOLEAN default false,
+    forward_headers TEXT[],
+    project_id TEXT not null,
+    source_verifier_id TEXT,
+    pub_sub TEXT,
+    deleted_at TEXT,
+    custom_response_body TEXT,
+    custom_response_content_type TEXT,
+    idempotency_keys TEXT[],
+    body_function TEXT,
+    header_function TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    constraint sources_mask_id unique (mask_id, deleted_at),
+    FOREIGN KEY(project_id) REFERENCES projects(id),
+    FOREIGN KEY(source_verifier_id) REFERENCES source_verifiers(id)
+);
+
+create table if not exists events
+(
+    id TEXT not null primary key,
+    event_type TEXT not null,
+    data TEXT not null,
+    project_id TEXT not null,
+    raw TEXT not null,
+    endpoints TEXT,
+    source_id TEXT,
+    headers TEXT,
+    deleted_at TEXT,
+    url_query_params TEXT,
+    idempotency_key TEXT,
+    acknowledged_at TEXT,
+    status TEXT,
+    metadata TEXT,
+    is_duplicate_event BOOLEAN default false,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    FOREIGN KEY(source_id) REFERENCES sources(id),
+    FOREIGN KEY(project_id) REFERENCES projects(id)
+);
+
+create index if not exists idx_events_created_at_key
+    on events (created_at);
+
+create index if not exists idx_events_deleted_at_key
+    on events (deleted_at);
+
+create index if not exists idx_events_project_id_deleted_at_key
+    on events (project_id, deleted_at);
+
+create index if not exists idx_events_project_id_key
+    on events (project_id);
+
+create index if not exists idx_events_project_id_source_id
+    on events (project_id, source_id);
+
+create index if not exists idx_events_source_id
+    on events (source_id);
+
+create index if not exists idx_events_source_id_key
+    on events (source_id);
+
+create index if not exists idx_idempotency_key_key
+    on events (idempotency_key);
+
+create index if not exists idx_project_id_on_not_deleted
+    on events (project_id)
+    where (deleted_at IS NULL);
+
+create table if not exists events_search
+(
+    id TEXT not null primary key,
+    event_type TEXT not null,
+    endpoints TEXT,
+    project_id TEXT not null,
+    source_id TEXT,
+    headers TEXT,
+    raw TEXT not null,
+    data TEXT not null,
+    url_query_params TEXT,
+    idempotency_key TEXT,
+    is_duplicate_event BOOLEAN default false,
+    search_token TEXT,
+    created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')),
+    deleted_at TEXT,
+    FOREIGN KEY(source_id) REFERENCES sources(id),
+    FOREIGN KEY(project_id) REFERENCES projects(id)
+);
+
+create index if not exists idx_events_search_created_at_key
+    on events_search (created_at);
+
+create index if not exists idx_events_search_deleted_at_key
+    on events_search (deleted_at);
+
+create index if not exists idx_events_search_project_id_deleted_at_key
+    on events_search (project_id, deleted_at);
+
+create index if not exists idx_events_search_project_id_key
+    on events_search (project_id);
+
+create index if not exists idx_events_search_source_id_key
+    on events_search (source_id);
+
+create index if not exists idx_events_search_token_key
+    on events_search (search_token);
+
+create index if not exists idx_sources_mask_id
+    on sources (mask_id);
+
+create index if not exists idx_sources_project_id
+    on sources (project_id);
+
+create index if not exists
idx_sources_source_verifier_id + on sources (source_verifier_id); + +create table if not exists subscriptions +( + id TEXT not null primary key, + name TEXT not null, + type TEXT not null, + project_id TEXT not null, + endpoint_id TEXT, + device_id TEXT, + source_id TEXT, + alert_config_count INTEGER not null, + alert_config_threshold TEXT not null, + retry_config_type TEXT not null, + retry_config_duration INTEGER not null, + retry_config_retry_count INTEGER not null, + filter_config_event_types TEXT[] not null, + filter_config_filter_headers TEXT not null, + filter_config_filter_body TEXT not null, + rate_limit_config_count INTEGER not null, + rate_limit_config_duration INTEGER not null, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + function TEXT, + filter_config_filter_is_flattened BOOLEAN default false, + FOREIGN KEY(source_id) REFERENCES sources(id), + FOREIGN KEY(device_id) REFERENCES devices(id), + FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists event_deliveries +( + id TEXT not null primary key, + status TEXT not null, + description TEXT not null, + project_id TEXT not null, + endpoint_id TEXT, + event_id TEXT not null, + device_id TEXT, + subscription_id TEXT not null, + metadata TEXT not null, + headers TEXT, + attempts TEXT, + cli_metadata TEXT, + url_query_params TEXT, + idempotency_key TEXT, + latency TEXT, + event_type TEXT, + acknowledged_at TEXT, + latency_seconds NUMERIC, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + FOREIGN KEY(subscription_id) REFERENCES subscriptions(id), + FOREIGN KEY(device_id) REFERENCES devices(id), + FOREIGN KEY(event_id) REFERENCES events(id), + FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create table if not exists delivery_attempts +( + id TEXT not null primary key, + url TEXT not null, + method TEXT not null, + api_version TEXT not null, + project_id TEXT not null, + endpoint_id TEXT not null, + event_delivery_id TEXT not null, + ip_address TEXT, + request_http_header TEXT, + response_http_header TEXT, + http_status TEXT, + response_data TEXT, + error TEXT, + status BOOLEAN, + created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + deleted_at TEXT, + FOREIGN KEY(event_delivery_id) REFERENCES event_deliveries(id), + FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), + FOREIGN KEY(project_id) REFERENCES projects(id) +); + +create index if not exists idx_delivery_attempts_created_at + on delivery_attempts (created_at); + +create index if not exists idx_delivery_attempts_created_at_id_event_delivery_id + on delivery_attempts (created_at, id, project_id, event_delivery_id) + where (deleted_at IS NULL); + +create index if not exists idx_delivery_attempts_event_delivery_id + on delivery_attempts (event_delivery_id); + +create index if not exists idx_delivery_attempts_event_delivery_id_created_at + on delivery_attempts (event_delivery_id, created_at); + +create index if not exists idx_delivery_attempts_event_delivery_id_created_at_desc + on delivery_attempts (event_delivery_id asc, created_at desc); + +create index if not exists event_deliveries_event_type_1 + on event_deliveries (event_type); + +create index if not 
exists idx_event_deliveries_created_at_key + on event_deliveries (created_at); + +create index if not exists idx_event_deliveries_device_id_key + on event_deliveries (device_id); + +create index if not exists idx_event_deliveries_endpoint_id_key + on event_deliveries (endpoint_id); + +create index if not exists idx_event_deliveries_event_id_key + on event_deliveries (event_id); + +create index if not exists idx_event_deliveries_project_id_endpoint_id + on event_deliveries (project_id, endpoint_id); + +create index if not exists idx_event_deliveries_project_id_endpoint_id_status + on event_deliveries (project_id, endpoint_id, status); + +create index if not exists idx_event_deliveries_project_id_event_id + on event_deliveries (project_id, event_id); + +create index if not exists idx_event_deliveries_project_id_key + on event_deliveries (project_id); + +create index if not exists idx_event_deliveries_status + on event_deliveries (status); + +create index if not exists idx_event_deliveries_status_key + on event_deliveries (status); + +create index if not exists idx_subscriptions_filter_config_event_types_key + on subscriptions (filter_config_event_types); + +create index if not exists idx_subscriptions_id_deleted_at + on subscriptions (id, deleted_at) + where (deleted_at IS NOT NULL); + +create index if not exists idx_subscriptions_name_key + on subscriptions (name) + where (deleted_at IS NULL); + +create index if not exists idx_subscriptions_updated_at + on subscriptions (updated_at) + where (deleted_at IS NULL); + +create index if not exists idx_subscriptions_updated_at_id_project_id + on subscriptions (updated_at, id, project_id) + where (deleted_at IS NULL); + +-- +migrate Down +drop table if exists configurations; +drop table if exists events_endpoints; +drop table if exists gorp_migrations; +drop table if exists project_configurations; +drop table if exists source_verifiers; +drop table if exists token_bucket; +drop table if exists users; +drop table if exists organisations; +drop table if exists projects; +drop table if exists applications; +drop table if exists devices; +drop table if exists endpoints; +drop table if exists api_keys; +drop table if exists event_types; +drop table if exists jobs; +drop table if exists meta_events; +drop table if exists organisation_invites; +drop table if exists organisation_members; +drop table if exists portal_links; +drop table if exists portal_links_endpoints; +drop table if exists sources; +drop table if exists events; +drop table if exists events_search; +drop table if exists subscriptions; +drop table if exists event_deliveries; +drop table if exists delivery_attempts; \ No newline at end of file diff --git a/type.go b/type.go index 65b0260284..234ed6fa19 100644 --- a/type.go +++ b/type.go @@ -16,8 +16,11 @@ type CacheKey string //go:embed VERSION var F embed.FS -//go:embed sql/*.sql -var MigrationFiles embed.FS +//go:embed sql/postgres/*.sql +var PostgresMigrationFiles embed.FS + +//go:embed sql/sqlite3/*.sql +var SQLiteMigrationFiles embed.FS func (t TaskName) SetPrefix(prefix string) TaskName { var name strings.Builder @@ -54,8 +57,7 @@ func readVersion(fs embed.FS) ([]byte, error) { return data, nil } -// TODO(subomi): This needs to be refactored for everywhere we depend -// on this code. +// GetVersion todo(subomi): This needs to be refactored for everywhere we depend on this code. 
func GetVersion() string { v := "0.1.0" From e4504ab509b520f8346f9ed914b27284f54b8d54 Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Wed, 13 Nov 2024 16:05:46 +0100 Subject: [PATCH 3/7] feat: added repos --- database/sqlite3/api_key.go | 391 +++++++ database/sqlite3/api_key_test.go | 271 +++++ database/sqlite3/configuration.go | 227 ++++ database/sqlite3/configuration_test.go | 124 +++ database/sqlite3/delivery_attempts.go | 195 ++++ database/sqlite3/delivery_attempts_test.go | 122 +++ database/sqlite3/device.go | 314 ++++++ database/sqlite3/device_test.go | 212 ++++ database/sqlite3/endpoint.go | 497 +++++++++ database/sqlite3/endpoint_test.go | 605 ++++++++++ database/sqlite3/event.go | 673 ++++++++++++ database/sqlite3/event_delivery.go | 1031 ++++++++++++++++++ database/sqlite3/event_delivery_test.go | 518 +++++++++ database/sqlite3/event_test.go | 470 ++++++++ database/sqlite3/event_types.go | 184 ++++ database/sqlite3/export.go | 143 +++ database/sqlite3/job.go | 368 +++++++ database/sqlite3/job_test.go | 345 ++++++ database/sqlite3/meta_event.go | 225 ++++ database/sqlite3/meta_event_test.go | 207 ++++ database/sqlite3/organisation.go | 285 +++++ database/sqlite3/organisation_invite.go | 342 ++++++ database/sqlite3/organisation_invite_test.go | 227 ++++ database/sqlite3/organisation_member.go | 499 +++++++++ database/sqlite3/organisation_member_test.go | 323 ++++++ database/sqlite3/organisation_test.go | 249 +++++ database/sqlite3/portal_link.go | 493 +++++++++ database/sqlite3/portal_link_test.go | 240 ++++ database/sqlite3/project.go | 509 +++++++++ database/sqlite3/project_test.go | 500 +++++++++ database/sqlite3/source.go | 559 ++++++++++ database/sqlite3/source_test.go | 288 +++++ database/sqlite3/sqlite3.go | 21 + database/sqlite3/sqlite_test.go | 78 ++ database/sqlite3/subscription.go | 844 ++++++++++++++ database/sqlite3/subscription_test.go | 526 +++++++++ database/sqlite3/users.go | 177 +++ database/sqlite3/users_test.go | 292 +++++ 38 files changed, 13574 insertions(+) create mode 100644 database/sqlite3/api_key.go create mode 100644 database/sqlite3/api_key_test.go create mode 100644 database/sqlite3/configuration.go create mode 100644 database/sqlite3/configuration_test.go create mode 100644 database/sqlite3/delivery_attempts.go create mode 100644 database/sqlite3/delivery_attempts_test.go create mode 100644 database/sqlite3/device.go create mode 100644 database/sqlite3/device_test.go create mode 100644 database/sqlite3/endpoint.go create mode 100644 database/sqlite3/endpoint_test.go create mode 100644 database/sqlite3/event.go create mode 100644 database/sqlite3/event_delivery.go create mode 100644 database/sqlite3/event_delivery_test.go create mode 100644 database/sqlite3/event_test.go create mode 100644 database/sqlite3/event_types.go create mode 100644 database/sqlite3/export.go create mode 100644 database/sqlite3/job.go create mode 100644 database/sqlite3/job_test.go create mode 100644 database/sqlite3/meta_event.go create mode 100644 database/sqlite3/meta_event_test.go create mode 100644 database/sqlite3/organisation.go create mode 100644 database/sqlite3/organisation_invite.go create mode 100644 database/sqlite3/organisation_invite_test.go create mode 100644 database/sqlite3/organisation_member.go create mode 100644 database/sqlite3/organisation_member_test.go create mode 100644 database/sqlite3/organisation_test.go create mode 100644 database/sqlite3/portal_link.go create mode 100644 database/sqlite3/portal_link_test.go create mode 100644 
database/sqlite3/project.go
 create mode 100644 database/sqlite3/project_test.go
 create mode 100644 database/sqlite3/source.go
 create mode 100644 database/sqlite3/source_test.go
 create mode 100644 database/sqlite3/sqlite_test.go
 create mode 100644 database/sqlite3/subscription.go
 create mode 100644 database/sqlite3/subscription_test.go
 create mode 100644 database/sqlite3/users.go
 create mode 100644 database/sqlite3/users_test.go

diff --git a/database/sqlite3/api_key.go b/database/sqlite3/api_key.go
new file mode 100644
index 0000000000..80691aa5e1
--- /dev/null
+++ b/database/sqlite3/api_key.go
@@ -0,0 +1,391 @@
+package sqlite3
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+
+	"github.com/frain-dev/convoy/auth"
+	"github.com/frain-dev/convoy/database"
+	"github.com/frain-dev/convoy/datastore"
+	"github.com/frain-dev/convoy/util"
+	"github.com/jmoiron/sqlx"
+)
+
+const (
+	createAPIKey = `
+	INSERT INTO api_keys (id,name,key_type,mask_id,role_type,role_project,role_endpoint,hash,salt,user_id,expires_at)
+	VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11);
+	`
+
+	// sqlite has no NOW(); use the same strftime expression as the schema defaults.
+	updateAPIKeyById = `
+	UPDATE api_keys SET
+		name = $2,
+		role_type = $3,
+		role_project = $4,
+		role_endpoint = $5,
+		updated_at = strftime('%Y-%m-%dT%H:%M:%fZ')
+	WHERE id = $1 AND deleted_at IS NULL;
+	`
+
+	fetchAPIKey = `
+	SELECT
+	    id,
+	    name,
+	    key_type,
+	    mask_id,
+	    COALESCE(role_type,'') AS "role.type",
+	    COALESCE(role_project,'') AS "role.project",
+	    COALESCE(role_endpoint,'') AS "role.endpoint",
+	    hash,
+	    salt,
+	    COALESCE(user_id, '') AS user_id,
+	    created_at,
+	    updated_at,
+	    expires_at
+	FROM api_keys
+	WHERE deleted_at IS NULL
+	`
+
+	deleteAPIKeys = `
+	UPDATE api_keys SET
+		deleted_at = strftime('%Y-%m-%dT%H:%M:%fZ')
+	WHERE id IN (?);
+	`
+
+	fetchAPIKeysPaged = `
+	SELECT
+	    id,
+	    name,
+	    key_type,
+	    mask_id,
+	    COALESCE(role_type,'') AS "role.type",
+	    COALESCE(role_project,'') AS "role.project",
+	    COALESCE(role_endpoint,'') AS "role.endpoint",
+	    hash,
+	    salt,
+	    COALESCE(user_id, '') AS user_id,
+	    created_at,
+	    updated_at,
+	    expires_at
+	FROM api_keys
+	WHERE deleted_at IS NULL`
+
+	baseApiKeysFilter = `
+	AND (role_project = :project_id OR :project_id = '')
+	AND (role_endpoint = :endpoint_id OR :endpoint_id = '')
+	AND (user_id = :user_id OR :user_id = '')
+	AND (key_type = :key_type OR :key_type = '')`
+
+	baseFetchAPIKeysPagedForward = `
+	%s
+	%s
+	AND id <= :cursor
+	GROUP BY id
+	ORDER BY id DESC
+	LIMIT :limit
+	`
+
+	baseFetchAPIKeysPagedBackward = `
+	WITH api_keys AS (
+	    %s
+	    %s
+	    AND id >= :cursor
+	    GROUP BY id
+	    ORDER BY id ASC
+	    LIMIT :limit
+	)
+
+	SELECT * FROM api_keys ORDER BY id DESC
+	`
+
+	countPrevAPIKeys = `
+	SELECT COUNT(DISTINCT(id)) AS count
+	FROM api_keys s
+	WHERE s.deleted_at IS NULL
+	%s
+	AND id > :cursor
+	GROUP BY id
+	ORDER BY id DESC
+	LIMIT 1`
+)
+
+var (
+	ErrAPIKeyNotCreated = errors.New("api key could not be created")
+	ErrAPIKeyNotUpdated = errors.New("api key could not be updated")
+	ErrAPIKeyNotRevoked = errors.New("api key could not be revoked")
+)
+
+type apiKeyRepo struct {
+	db *sqlx.DB
+}
+
+func NewAPIKeyRepo(db database.Database) datastore.APIKeyRepository {
+	return &apiKeyRepo{db: db.GetDB()}
+}
+
+func (a *apiKeyRepo) CreateAPIKey(ctx context.Context, key *datastore.APIKey) error {
+	var (
+		userID     *string
+		endpointID *string
+		projectID  *string
+		roleType   *auth.RoleType
+	)
+
+	if !util.IsStringEmpty(key.UserID) {
+		userID = &key.UserID
+	}
+
+	if !util.IsStringEmpty(key.Role.Endpoint) {
+		endpointID = &key.Role.Endpoint
+	}
+
+	if !util.IsStringEmpty(key.Role.Project) {
+		projectID = &key.Role.Project
+	}
+ + if !util.IsStringEmpty(string(key.Role.Type)) { + roleType = &key.Role.Type + } + + result, err := a.db.ExecContext( + ctx, createAPIKey, key.UID, key.Name, key.Type, key.MaskID, + roleType, projectID, endpointID, key.Hash, + key.Salt, userID, key.ExpiresAt, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrAPIKeyNotCreated + } + + return nil +} + +func (a *apiKeyRepo) UpdateAPIKey(ctx context.Context, key *datastore.APIKey) error { + var endpointID *string + var projectID *string + var roleType *auth.RoleType + + if !util.IsStringEmpty(key.Role.Endpoint) { + endpointID = &key.Role.Endpoint + } + + if !util.IsStringEmpty(key.Role.Project) { + projectID = &key.Role.Project + } + + if !util.IsStringEmpty(string(key.Role.Type)) { + roleType = &key.Role.Type + } + + result, err := a.db.ExecContext( + ctx, updateAPIKeyById, key.UID, key.Name, roleType, projectID, endpointID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrAPIKeyNotUpdated + } + + return nil +} + +func (a *apiKeyRepo) FindAPIKeyByID(ctx context.Context, id string) (*datastore.APIKey, error) { + apiKey := &datastore.APIKey{} + err := a.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND id = $1;", fetchAPIKey), id).StructScan(apiKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrAPIKeyNotFound + } + return nil, err + } + + return apiKey, nil +} + +func (a *apiKeyRepo) FindAPIKeyByMaskID(ctx context.Context, maskID string) (*datastore.APIKey, error) { + apiKey := &datastore.APIKey{} + err := a.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND mask_id = $1;", fetchAPIKey), maskID).StructScan(apiKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrAPIKeyNotFound + } + return nil, err + } + + return apiKey, nil +} + +func (a *apiKeyRepo) FindAPIKeyByHash(ctx context.Context, hash string) (*datastore.APIKey, error) { + apiKey := &datastore.APIKey{} + err := a.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND hash = $1;", fetchAPIKey), hash).StructScan(apiKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrAPIKeyNotFound + } + return nil, err + } + + return apiKey, nil +} + +func (a *apiKeyRepo) RevokeAPIKeys(ctx context.Context, ids []string) error { + query, args, err := sqlx.In(deleteAPIKeys, ids) + if err != nil { + return err + } + + result, err := a.db.ExecContext(ctx, a.db.Rebind(query), args...) 
+ if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrAPIKeyNotRevoked + } + + return nil +} + +func (a *apiKeyRepo) LoadAPIKeysPaged(ctx context.Context, filter *datastore.ApiKeyFilter, pageable *datastore.Pageable) ([]datastore.APIKey, datastore.PaginationData, error) { + var query, filterQuery string + var err error + var args []interface{} + + arg := map[string]interface{}{ + "endpoint_ids": filter.EndpointIDs, + "project_id": filter.ProjectID, + "endpoint_id": filter.EndpointID, + "user_id": filter.UserID, + "key_type": filter.KeyType, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + if pageable.Direction == datastore.Next { + query = baseFetchAPIKeysPagedForward + } else { + query = baseFetchAPIKeysPagedBackward + } + + filterQuery = baseApiKeysFilter + if len(filter.EndpointIDs) > 0 { + filterQuery += ` AND role_endpoint IN (:endpoint_ids)` + } + + query = fmt.Sprintf(query, fetchAPIKeysPaged, filterQuery) + + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = a.db.Rebind(query) + + rows, err := a.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var apiKeys []datastore.APIKey + + for rows.Next() { + ak := ApiKeyPaginated{} + err = rows.StructScan(&ak) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + apiKeys = append(apiKeys, ak.APIKey) + } + + var count datastore.PrevRowCount + if len(apiKeys) > 0 { + var countQuery string + var qargs []interface{} + first := apiKeys[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevAPIKeys, filterQuery) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = a.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := a.db.QueryxContext(ctx, countQuery, qargs...) 
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(apiKeys)) + for i := range apiKeys { + ids[i] = apiKeys[i].UID + } + + if len(apiKeys) > pageable.PerPage { + apiKeys = apiKeys[:len(apiKeys)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(*pageable, ids) + + return apiKeys, *pagination, nil +} + +func (a *apiKeyRepo) FindAPIKeyByProjectID(ctx context.Context, projectID string) (*datastore.APIKey, error) { + apiKey := &datastore.APIKey{} + err := a.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND role_project = $1;", fetchAPIKey), projectID).StructScan(apiKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrAPIKeyNotFound + } + return nil, err + } + + return apiKey, nil +} + +type ApiKeyPaginated struct { + Count int `db:"count"` + datastore.APIKey +} diff --git a/database/sqlite3/api_key_test.go b/database/sqlite3/api_key_test.go new file mode 100644 index 0000000000..5587469f9b --- /dev/null +++ b/database/sqlite3/api_key_test.go @@ -0,0 +1,271 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "errors" + "testing" + "time" + + "gopkg.in/guregu/null.v4" + + "github.com/frain-dev/convoy/auth" + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreateAPIKey(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + newApiKey, err := apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.NoError(t, err) + + apiKey.ExpiresAt = null.Time{} + newApiKey.CreatedAt = time.Time{} + newApiKey.UpdatedAt = time.Time{} + newApiKey.ExpiresAt = null.Time{} + + require.Equal(t, apiKey, newApiKey) +} + +func Test_FindAPIKeyByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + _, err := apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrAPIKeyNotFound)) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + newApiKey, err := apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.NoError(t, err) + + apiKey.ExpiresAt = null.Time{} + newApiKey.CreatedAt = time.Time{} + newApiKey.UpdatedAt = time.Time{} + newApiKey.ExpiresAt = null.Time{} + + require.Equal(t, apiKey, newApiKey) +} + +func Test_FindAPIKeyByMaskID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + _, err := apiKeyRepo.FindAPIKeyByMaskID(context.Background(), apiKey.MaskID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrAPIKeyNotFound)) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + newApiKey, err := apiKeyRepo.FindAPIKeyByMaskID(context.Background(), apiKey.MaskID) + require.NoError(t, err) + + 
apiKey.ExpiresAt = null.Time{} + newApiKey.CreatedAt = time.Time{} + newApiKey.UpdatedAt = time.Time{} + newApiKey.ExpiresAt = null.Time{} + + require.Equal(t, apiKey, newApiKey) +} + +func Test_FindAPIKeyByHash(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + _, err := apiKeyRepo.FindAPIKeyByHash(context.Background(), apiKey.Hash) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrAPIKeyNotFound)) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + newApiKey, err := apiKeyRepo.FindAPIKeyByHash(context.Background(), apiKey.Hash) + require.NoError(t, err) + + apiKey.ExpiresAt = null.Time{} + newApiKey.CreatedAt = time.Time{} + newApiKey.UpdatedAt = time.Time{} + newApiKey.ExpiresAt = null.Time{} + + require.Equal(t, apiKey, newApiKey) +} + +func Test_UpdateAPIKey(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + apiKey.Name = "Updated-Test-Api-Key" + apiKey.Role = auth.Role{ + Type: auth.RoleSuperUser, + Project: project.UID, + } + + require.NoError(t, apiKeyRepo.UpdateAPIKey(context.Background(), apiKey)) + + newApiKey, err := apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.NoError(t, err) + + apiKey.ExpiresAt = null.Time{} + newApiKey.CreatedAt = time.Time{} + newApiKey.UpdatedAt = time.Time{} + newApiKey.ExpiresAt = null.Time{} + + require.Equal(t, apiKey, newApiKey) +} + +func Test_RevokeAPIKey(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + apiKey := generateApiKey(project, endpoint) + + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + + _, err := apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.NoError(t, err) + + require.NoError(t, apiKeyRepo.RevokeAPIKeys(context.Background(), []string{apiKey.UID})) + + _, err = apiKeyRepo.FindAPIKeyByID(context.Background(), apiKey.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrAPIKeyNotFound)) +} + +func Test_LoadAPIKeysPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + expected Expected + }{ + { + name: "Load API Keys Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load API Keys Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load API Keys Paged - 5 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + + apiKeyRepo := NewAPIKeyRepo(db) + for i := 0; i < tc.count; i++ { + apiKey := &datastore.APIKey{ + UID: ulid.Make().String(), + 
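+				// ULIDs sort lexicographically by creation time, which is
+				// what lets the id-based cursor queries page these records
+				// in order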
MaskID: ulid.Make().String(), + Name: "Test Api Key", + Type: datastore.ProjectKey, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + }, + Hash: ulid.Make().String(), + Salt: ulid.Make().String(), + ExpiresAt: null.NewTime(time.Now().Add(5*time.Minute), true), + } + require.NoError(t, apiKeyRepo.CreateAPIKey(context.Background(), apiKey)) + } + + _, pageable, err := apiKeyRepo.LoadAPIKeysPaged(context.Background(), &datastore.ApiKeyFilter{ProjectID: project.UID}, &tc.pageData) + + require.NoError(t, err) + + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generateApiKey(project *datastore.Project, endpoint *datastore.Endpoint) *datastore.APIKey { + return &datastore.APIKey{ + UID: ulid.Make().String(), + MaskID: ulid.Make().String(), + Name: "Test Api Key", + Type: datastore.ProjectKey, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + Endpoint: endpoint.UID, + }, + Hash: ulid.Make().String(), + Salt: ulid.Make().String(), + ExpiresAt: null.NewTime(time.Now().Add(5*time.Minute), true), + } +} diff --git a/database/sqlite3/configuration.go b/database/sqlite3/configuration.go new file mode 100644 index 0000000000..17bb050750 --- /dev/null +++ b/database/sqlite3/configuration.go @@ -0,0 +1,227 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "github.com/frain-dev/convoy/util" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" + "gopkg.in/guregu/null.v4" +) + +const ( + createConfiguration = ` + INSERT INTO configurations( + id, is_analytics_enabled, is_signup_enabled, + storage_policy_type, on_prem_path, s3_prefix, + s3_bucket, s3_access_key, s3_secret_key, + s3_region, s3_session_token, s3_endpoint, + retention_policy_policy, retention_policy_enabled, + cb_sample_rate,cb_error_timeout, + cb_failure_threshold, cb_success_threshold, + cb_observability_window, + cb_consecutive_failure_threshold, cb_minimum_request_count + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21); + ` + + fetchConfiguration = ` + SELECT + id, + is_analytics_enabled, + is_signup_enabled, + retention_policy_enabled AS "retention_policy.enabled", + retention_policy_policy AS "retention_policy.policy", + storage_policy_type AS "storage_policy.type", + on_prem_path AS "storage_policy.on_prem.path", + s3_bucket AS "storage_policy.s3.bucket", + s3_access_key AS "storage_policy.s3.access_key", + s3_secret_key AS "storage_policy.s3.secret_key", + s3_region AS "storage_policy.s3.region", + s3_session_token AS "storage_policy.s3.session_token", + s3_endpoint AS "storage_policy.s3.endpoint", + s3_prefix AS "storage_policy.s3.prefix", + cb_sample_rate AS "circuit_breaker.sample_rate", + cb_error_timeout AS "circuit_breaker.error_timeout", + cb_failure_threshold AS "circuit_breaker.failure_threshold", + cb_success_threshold AS "circuit_breaker.success_threshold", + cb_observability_window AS "circuit_breaker.observability_window", + cb_minimum_request_count as "circuit_breaker.minimum_request_count", + cb_consecutive_failure_threshold AS "circuit_breaker.consecutive_failure_threshold", + created_at, + updated_at, + deleted_at + FROM configurations + WHERE deleted_at IS NULL LIMIT 1; + ` + + updateConfiguration = ` + UPDATE + configurations + SET + is_analytics_enabled = $2, + is_signup_enabled = $3, + storage_policy_type = $4, + on_prem_path = $5, + s3_bucket = $6, + s3_access_key = $7, + s3_secret_key = $8, + 
s3_region = $9, + s3_session_token = $10, + s3_endpoint = $11, + s3_prefix = $12, + retention_policy_policy = $13, + retention_policy_enabled = $14, + cb_sample_rate = $15, + cb_error_timeout = $16, + cb_failure_threshold = $17, + cb_success_threshold = $18, + cb_observability_window = $19, + cb_consecutive_failure_threshold = $20, + cb_minimum_request_count = $21, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` +) + +type configRepo struct { + db *sqlx.DB +} + +func NewConfigRepo(db database.Database) datastore.ConfigurationRepository { + return &configRepo{db: db.GetDB()} +} + +func (c *configRepo) CreateConfiguration(ctx context.Context, config *datastore.Configuration) error { + if config.StoragePolicy.Type == datastore.OnPrem { + config.StoragePolicy.S3 = &datastore.S3Storage{ + Prefix: null.NewString("", false), + Bucket: null.NewString("", false), + AccessKey: null.NewString("", false), + SecretKey: null.NewString("", false), + Region: null.NewString("", false), + SessionToken: null.NewString("", false), + Endpoint: null.NewString("", false), + } + } else { + config.StoragePolicy.OnPrem = &datastore.OnPremStorage{ + Path: null.NewString("", false), + } + } + + rc := config.GetRetentionPolicyConfig() + cb := config.GetCircuitBreakerConfig() + + r, err := c.db.ExecContext(ctx, createConfiguration, + config.UID, + util.BoolToText(config.IsAnalyticsEnabled), + config.IsSignupEnabled, + config.StoragePolicy.Type, + config.StoragePolicy.OnPrem.Path, + config.StoragePolicy.S3.Prefix, + config.StoragePolicy.S3.Bucket, + config.StoragePolicy.S3.AccessKey, + config.StoragePolicy.S3.SecretKey, + config.StoragePolicy.S3.Region, + config.StoragePolicy.S3.SessionToken, + config.StoragePolicy.S3.Endpoint, + rc.Policy, + rc.IsRetentionPolicyEnabled, + cb.SampleRate, + cb.ErrorTimeout, + cb.FailureThreshold, + cb.SuccessThreshold, + cb.ObservabilityWindow, + cb.ConsecutiveFailureThreshold, + cb.MinimumRequestCount, + ) + if err != nil { + return err + } + + nRows, err := r.RowsAffected() + if err != nil { + return err + } + + if nRows < 1 { + return errors.New("configuration not created") + } + + return nil +} + +func (c *configRepo) LoadConfiguration(ctx context.Context) (*datastore.Configuration, error) { + config := &datastore.Configuration{} + err := c.db.QueryRowxContext(ctx, fetchConfiguration).StructScan(config) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrConfigNotFound + } + return nil, err + } + + return config, nil +} + +func (c *configRepo) UpdateConfiguration(ctx context.Context, cfg *datastore.Configuration) error { + if cfg.StoragePolicy.Type == datastore.OnPrem { + cfg.StoragePolicy.S3 = &datastore.S3Storage{ + Prefix: null.NewString("", false), + Bucket: null.NewString("", false), + AccessKey: null.NewString("", false), + SecretKey: null.NewString("", false), + Region: null.NewString("", false), + SessionToken: null.NewString("", false), + Endpoint: null.NewString("", false), + } + } else { + cfg.StoragePolicy.OnPrem = &datastore.OnPremStorage{ + Path: null.NewString("", false), + } + } + + rc := cfg.GetRetentionPolicyConfig() + cb := cfg.GetCircuitBreakerConfig() + + result, err := c.db.ExecContext(ctx, updateConfiguration, + cfg.UID, + util.BoolToText(cfg.IsAnalyticsEnabled), + cfg.IsSignupEnabled, + cfg.StoragePolicy.Type, + cfg.StoragePolicy.OnPrem.Path, + cfg.StoragePolicy.S3.Bucket, + cfg.StoragePolicy.S3.AccessKey, + cfg.StoragePolicy.S3.SecretKey, + cfg.StoragePolicy.S3.Region, + cfg.StoragePolicy.S3.SessionToken, + 
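+		// positional args must stay in the exact order of the $1..$21
+		// placeholders in updateConfiguration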
cfg.StoragePolicy.S3.Endpoint, + cfg.StoragePolicy.S3.Prefix, + rc.Policy, + rc.IsRetentionPolicyEnabled, + cb.SampleRate, + cb.ErrorTimeout, + cb.FailureThreshold, + cb.SuccessThreshold, + cb.ObservabilityWindow, + cb.ConsecutiveFailureThreshold, + cb.MinimumRequestCount, + ) + if err != nil { + return err + } + + nRows, err := result.RowsAffected() + if err != nil { + return err + } + + if nRows < 1 { + return errors.New("configuration not updated") + } + + return nil +} diff --git a/database/sqlite3/configuration_test.go b/database/sqlite3/configuration_test.go new file mode 100644 index 0000000000..6e8e93f4c0 --- /dev/null +++ b/database/sqlite3/configuration_test.go @@ -0,0 +1,124 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "errors" + "testing" + "time" + + "gopkg.in/guregu/null.v4" + + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func Test_CreateConfiguration(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + configRepo := NewConfigRepo(db) + config := generateConfig() + + require.NoError(t, configRepo.CreateConfiguration(context.Background(), config)) + + newConfig, err := configRepo.LoadConfiguration(context.Background()) + require.NoError(t, err) + + newConfig.CreatedAt = time.Time{} + newConfig.UpdatedAt = time.Time{} + + config.CreatedAt = time.Time{} + config.UpdatedAt = time.Time{} + + require.Equal(t, config, newConfig) +} + +func Test_LoadConfiguration(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + configRepo := NewConfigRepo(db) + config := generateConfig() + + _, err := configRepo.LoadConfiguration(context.Background()) + + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrConfigNotFound)) + + require.NoError(t, configRepo.CreateConfiguration(context.Background(), config)) + + newConfig, err := configRepo.LoadConfiguration(context.Background()) + require.NoError(t, err) + + newConfig.CreatedAt = time.Time{} + newConfig.UpdatedAt = time.Time{} + + config.CreatedAt = time.Time{} + config.UpdatedAt = time.Time{} + + require.Equal(t, config, newConfig) +} + +func Test_UpdateConfiguration(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + configRepo := NewConfigRepo(db) + config := generateConfig() + + require.NoError(t, configRepo.CreateConfiguration(context.Background(), config)) + + config.IsAnalyticsEnabled = false + require.NoError(t, configRepo.UpdateConfiguration(context.Background(), config)) + + newConfig, err := configRepo.LoadConfiguration(context.Background()) + require.NoError(t, err) + + newConfig.CreatedAt = time.Time{} + newConfig.UpdatedAt = time.Time{} + + config.CreatedAt = time.Time{} + config.UpdatedAt = time.Time{} + + require.Equal(t, config, newConfig) +} + +func generateConfig() *datastore.Configuration { + return &datastore.Configuration{ + UID: ulid.Make().String(), + IsAnalyticsEnabled: true, + IsSignupEnabled: false, + StoragePolicy: &datastore.StoragePolicyConfiguration{ + Type: datastore.OnPrem, + S3: &datastore.S3Storage{ + Prefix: null.NewString("random7", true), + Bucket: null.NewString("random1", true), + AccessKey: null.NewString("random2", true), + SecretKey: null.NewString("random3", true), + Region: null.NewString("random4", true), + SessionToken: null.NewString("random5", true), + Endpoint: null.NewString("random6", true), + }, + OnPrem: &datastore.OnPremStorage{ + Path: null.NewString("path", true), + }, + }, + RetentionPolicy: 
&datastore.RetentionPolicyConfiguration{
+			Policy:                   "720h",
+			IsRetentionPolicyEnabled: true,
+		},
+		CircuitBreakerConfig: &datastore.CircuitBreakerConfig{
+			SampleRate:                  30,
+			ErrorTimeout:                30,
+			FailureThreshold:            10,
+			SuccessThreshold:            5,
+			ObservabilityWindow:         5,
+			ConsecutiveFailureThreshold: 10,
+		},
+	}
+}
diff --git a/database/sqlite3/delivery_attempts.go b/database/sqlite3/delivery_attempts.go
new file mode 100644
index 0000000000..78b9c233fb
--- /dev/null
+++ b/database/sqlite3/delivery_attempts.go
@@ -0,0 +1,195 @@
+package sqlite3
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"github.com/frain-dev/convoy/database"
+	"github.com/frain-dev/convoy/datastore"
+	"github.com/frain-dev/convoy/pkg/circuit_breaker"
+	"github.com/jmoiron/sqlx"
+	"io"
+	"time"
+)
+
+type deliveryAttemptRepo struct {
+	db *sqlx.DB
+}
+
+func NewDeliveryAttemptRepo(db database.Database) datastore.DeliveryAttemptsRepository {
+	return &deliveryAttemptRepo{db: db.GetDB()}
+}
+
+var (
+	_ datastore.DeliveryAttemptsRepository = (*deliveryAttemptRepo)(nil)
+)
+
+const (
+	createDeliveryAttempt = `
+    INSERT INTO delivery_attempts (id, url, method, api_version, endpoint_id, event_delivery_id, project_id, ip_address, request_http_header, response_http_header, http_status, response_data, error, status)
+    VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14);
+    `
+
+	softDeleteProjectDeliveryAttempts = `
+    UPDATE delivery_attempts SET deleted_at = NOW() WHERE project_id = $1 AND created_at >= $2 AND created_at <= $3 AND deleted_at IS NULL;
+    `
+
+	hardDeleteProjectDeliveryAttempts = `
+    DELETE FROM delivery_attempts WHERE project_id = $1 AND created_at >= $2 AND created_at <= $3;
+    `
+
+	// fetches the ten most recent attempts for an event delivery, returned oldest first
+	findDeliveryAttempts = `WITH att AS (SELECT * FROM delivery_attempts WHERE event_delivery_id = $1 ORDER BY created_at DESC LIMIT 10) SELECT * FROM att ORDER BY created_at;`
+
+	findOneDeliveryAttempt = `SELECT * FROM delivery_attempts WHERE id = $1 AND event_delivery_id = $2;`
+)
+
+func (d *deliveryAttemptRepo) CreateDeliveryAttempt(ctx context.Context, attempt *datastore.DeliveryAttempt) error {
+	result, err := d.db.ExecContext(
+		ctx, createDeliveryAttempt, attempt.UID, attempt.URL, attempt.Method, attempt.APIVersion, attempt.EndpointID,
+		attempt.EventDeliveryId, attempt.ProjectId, attempt.IPAddress, attempt.RequestHeader, attempt.ResponseHeader, attempt.HttpResponseCode,
+		attempt.ResponseData, attempt.Error, attempt.Status,
+	)
+	if err != nil {
+		return err
+	}
+
+	rowsAffected, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+
+	if rowsAffected < 1 {
+		return ErrEventDeliveryNotCreated
+	}
+
+	return nil
+}
+
+func (d *deliveryAttemptRepo) FindDeliveryAttemptById(ctx context.Context, eventDeliveryId string, id string) (*datastore.DeliveryAttempt, error) {
+	attempt := &datastore.DeliveryAttempt{}
+	err := d.db.QueryRowxContext(ctx, findOneDeliveryAttempt, id, eventDeliveryId).StructScan(attempt)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, datastore.ErrDeliveryAttemptNotFound
+		}
+		return nil, err
+	}
+
+	return attempt, nil
+}
+
+func (d *deliveryAttemptRepo) FindDeliveryAttempts(ctx context.Context, eventDeliveryId string) ([]datastore.DeliveryAttempt, error) {
+	var attempts []datastore.DeliveryAttempt
+	rows, err := d.db.QueryxContext(ctx, findDeliveryAttempts, eventDeliveryId)
+	if err != nil {
+		return nil, err
+	}
+	defer closeWithError(rows)
+
+	for rows.Next() {
+		var attempt datastore.DeliveryAttempt
+
+		err = rows.StructScan(&attempt)
+		if err != nil {
+			return nil, err
+		}
+
+		attempts = append(attempts, attempt)
+	}
+
+	return attempts, nil
+}
+
+func (d *deliveryAttemptRepo) DeleteProjectDeliveriesAttempts(ctx context.Context, projectID string, filter *datastore.DeliveryAttemptsFilter, hardDelete bool) error {
+	var result sql.Result
+	var err error
+
+	start := time.Unix(filter.CreatedAtStart, 0)
+	end := time.Unix(filter.CreatedAtEnd, 0)
+
+	if hardDelete {
+		result, err = d.db.ExecContext(ctx, hardDeleteProjectDeliveryAttempts, projectID, start, end)
+	} else {
+		result, err = d.db.ExecContext(ctx, softDeleteProjectDeliveryAttempts, projectID, start, end)
+	}
+
+	if err != nil {
+		return err
+	}
+
+	rowsAffected, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+
+	if rowsAffected < 1 {
+		return datastore.ErrDeliveryAttemptsNotDeleted
+	}
+
+	return nil
+}
+
+func (d *deliveryAttemptRepo) GetFailureAndSuccessCounts(ctx context.Context, lookBackDuration uint64, resetTimes map[string]time.Time) (map[string]circuit_breaker.PollResult, error) {
+	resultsMap := map[string]circuit_breaker.PollResult{}
+
+	// NOTE: NOW() and MAKE_INTERVAL are Postgres idioms; a SQLite build would
+	// need an equivalent such as created_at >= datetime('now', '-' || $1 || ' minutes')
+	query := `
+	SELECT
+		endpoint_id AS key,
+		project_id AS tenant_id,
+		COUNT(CASE WHEN status = false THEN 1 END) AS failures,
+		COUNT(CASE WHEN status = true THEN 1 END) AS successes
+	FROM delivery_attempts
+	WHERE created_at >= NOW() - MAKE_INTERVAL(mins := $1)
+	GROUP BY endpoint_id, project_id;
+	`
+
+	rows, err := d.db.QueryxContext(ctx, query, lookBackDuration)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var rowValue circuit_breaker.PollResult
+		if rowScanErr := rows.StructScan(&rowValue); rowScanErr != nil {
+			return nil, rowScanErr
+		}
+		resultsMap[rowValue.Key] = rowValue
+	}
+
+	// NOTE: this re-queries once per endpoint with a reset time (an N+1 pattern);
+	// tolerable while reset maps stay small, but a batched query would be better
+	query2 := `
+	SELECT
+		endpoint_id AS key,
+		project_id AS tenant_id,
+		COUNT(CASE WHEN status = false THEN 1 END) AS failures,
+		COUNT(CASE WHEN status = true THEN 1 END) AS successes
+	FROM delivery_attempts
+	WHERE endpoint_id = '%s' AND created_at >= TIMESTAMP '%s' AT TIME ZONE 'UTC'
+	GROUP BY endpoint_id, project_id;
+	`
+
+	customFormat := "2006-01-02 15:04:05"
+	for k, t := range resetTimes {
+		// remove the old key so it doesn't pollute the results
+		delete(resultsMap, k)
+		// endpoint ids are internally generated ULIDs, which is the only reason
+		// interpolating them with Sprintf is tolerable here
+		qq := fmt.Sprintf(query2, k, t.Format(customFormat))
+
+		var rowValue circuit_breaker.PollResult
+		err = d.db.QueryRowxContext(ctx, qq).StructScan(&rowValue)
+		if err != nil {
+			if errors.Is(err, sql.ErrNoRows) {
+				continue
+			}
+			return nil, err
+		}
+
+		resultsMap[k] = rowValue
+	}
+
+	return resultsMap, nil
+}
+
+func (d *deliveryAttemptRepo) ExportRecords(ctx context.Context, projectID string, createdAt time.Time, w io.Writer) (int64, error) {
+	return exportRecords(ctx, d.db, "delivery_attempts", projectID, createdAt, w)
+}
diff --git a/database/sqlite3/delivery_attempts_test.go b/database/sqlite3/delivery_attempts_test.go
new file mode 100644
index 0000000000..792d36eedd
--- /dev/null
+++ b/database/sqlite3/delivery_attempts_test.go
@@ -0,0 +1,122 @@
+//go:build integration
+// +build integration
+
+package sqlite3
+
+import (
+	"context"
+	"testing"
+
+	"github.com/frain-dev/convoy/datastore"
+	"github.com/oklog/ulid/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCreateDeliveryAttempt(t *testing.T) {
+	db, closeFn := getDB(t)
+	defer closeFn()
+
+	attemptsRepo := NewDeliveryAttemptRepo(db)
+	ctx := context.Background()
+
+	source := seedSource(t, db)
+	project := seedProject(t, db)
+	device := seedDevice(t, db)
+	endpoint := seedEndpoint(t, db)
+	event := seedEvent(t, db, project)
+	sub := seedSubscription(t, db, project, source, endpoint, device)
+	ed := generateEventDelivery(project, endpoint, event, device, sub)
+
+	uid := ulid.Make().String()
+
+	edRepo := NewEventDeliveryRepo(db)
+	err := edRepo.CreateEventDelivery(ctx, ed)
+	require.NoError(t, err)
+
+	attempt := &datastore.DeliveryAttempt{
+		UID:              uid,
+		EventDeliveryId:  ed.UID,
+		URL:              "https://example.com",
+		Method:           "POST",
+		ProjectId:        project.UID,
+		EndpointID:       endpoint.UID,
+		APIVersion:       "2024-01-01",
+		IPAddress:        "192.0.0.1",
+		RequestHeader:    map[string]string{"Content-Type": "application/json"},
+		ResponseHeader:   map[string]string{"Content-Type": "application/json"},
+		HttpResponseCode: "200",
+		ResponseData:     "{\"status\":\"ok\"}",
+		Status:           true,
+	}
+
+	err = attemptsRepo.CreateDeliveryAttempt(ctx, attempt)
+	require.NoError(t, err)
+
+	att, err := attemptsRepo.FindDeliveryAttemptById(ctx, ed.UID, uid)
+	require.NoError(t, err)
+
+	require.Equal(t, att.ResponseData, attempt.ResponseData)
+}
+
+func TestFindDeliveryAttempts(t *testing.T) {
+	db, closeFn := getDB(t)
+	defer closeFn()
+
+	attemptsRepo := NewDeliveryAttemptRepo(db)
+	ctx := context.Background()
+
+	source := seedSource(t, db)
+	project := seedProject(t, db)
+	device := seedDevice(t, db)
+	endpoint := seedEndpoint(t, db)
+	event := seedEvent(t, db, project)
+	sub := seedSubscription(t, db, project, source, endpoint, device)
+	ed := generateEventDelivery(project, endpoint, event, device, sub)
+
+	edRepo := NewEventDeliveryRepo(db)
+	err := edRepo.CreateEventDelivery(ctx, ed)
+	require.NoError(t, err)
+
+	attempts := []datastore.DeliveryAttempt{
+		{
+			UID:              ulid.Make().String(),
+			EventDeliveryId:  ed.UID,
+			URL:              "https://example.com",
+			Method:           "POST",
+			EndpointID:       endpoint.UID,
+			ProjectId:        project.UID,
+			APIVersion:       "2024-01-01",
+			IPAddress:        "192.168.0.1",
+			RequestHeader:    map[string]string{"Content-Type": "application/json"},
+			ResponseHeader:   map[string]string{"Content-Type": "application/json"},
+			HttpResponseCode: "200",
+			ResponseData:     "{\"status\":\"ok\"}",
+			Status:           true,
+		},
+		{
+			UID:              ulid.Make().String(),
+			EventDeliveryId:  ed.UID,
+			URL:              "https://main.com",
+			Method:           "POST",
+			EndpointID:       endpoint.UID,
+			ProjectId:        project.UID,
+			APIVersion:       "2024-04-04",
+			IPAddress:        "127.0.0.1",
+			RequestHeader:    map[string]string{"Content-Type": "application/json"},
+			ResponseHeader:   map[string]string{"Content-Type": "application/json"},
+			HttpResponseCode: "400",
+			ResponseData:     "{\"status\":\"Not Found\"}",
+			Error:            "",
+			Status:           false,
+		},
+	}
+
+	for _, a := range attempts {
+		err = attemptsRepo.CreateDeliveryAttempt(ctx, &a)
+		require.NoError(t, err)
+	}
+
+	atts, err := attemptsRepo.FindDeliveryAttempts(ctx, ed.UID)
+	require.NoError(t, err)
+
+	require.Equal(t, atts[0].ResponseData, attempts[0].ResponseData)
+	require.Equal(t, atts[1].HttpResponseCode, attempts[1].HttpResponseCode)
}
diff --git a/database/sqlite3/device.go b/database/sqlite3/device.go
new file mode 100644
index 0000000000..e717336ed4
--- /dev/null
+++ b/database/sqlite3/device.go
@@ -0,0 +1,314 @@
+package sqlite3
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"github.com/frain-dev/convoy/database"
+	"github.com/frain-dev/convoy/datastore"
+	"github.com/jmoiron/sqlx"
+)
+
+var (
+	ErrDeviceNotCreated = errors.New("device could not be created")
+	ErrDeviceNotFound   = errors.New("device not found")
+	ErrDeviceNotUpdated = errors.New("device could not be updated")
+	ErrDeviceNotDeleted = errors.New("device could not be deleted")
+)
+
+const (
+	createDevice = `
+	INSERT INTO devices (id, project_id, host_name,
status, last_seen_at) + VALUES ($1, $2, $3, $4, $5) + ` + + updateDevice = ` + UPDATE devices SET + host_name = $3, + status = $4, + updated_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + updateDeviceLastSeen = ` + UPDATE devices SET + status = $3, + last_seen_at = NOW(), + updated_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + deleteDevice = ` + UPDATE devices SET + deleted_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + fetchDeviceById = ` + SELECT * FROM devices + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + fetchDeviceByHostName = ` + SELECT * FROM devices + WHERE host_name = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + fetchDevicesPaginated = ` + SELECT * FROM devices WHERE deleted_at IS NULL` + + baseDevicesFilter = ` + AND project_id = :project_id` + + baseFetchDevicesPagedForward = ` + %s + %s + AND id <= :cursor + GROUP BY id + ORDER BY id DESC + LIMIT :limit + ` + + baseFetchDevicesPagedBackward = ` + WITH devices AS ( + %s + %s + AND id >= :cursor + GROUP BY id + ORDER BY id ASC + LIMIT :limit + ) + + SELECT * FROM devices ORDER BY id DESC + ` + + countPrevDevices = ` + SELECT COUNT(DISTINCT(id)) AS count + FROM devices + WHERE deleted_at IS NULL + %s + AND id > :cursor GROUP BY id ORDER BY id DESC LIMIT 1` +) + +type deviceRepo struct { + db *sqlx.DB +} + +func NewDeviceRepo(db database.Database) datastore.DeviceRepository { + return &deviceRepo{db: db.GetDB()} +} + +func (d *deviceRepo) CreateDevice(ctx context.Context, device *datastore.Device) error { + r, err := d.db.ExecContext(ctx, createDevice, + device.UID, + device.ProjectID, + device.HostName, + device.Status, + device.LastSeenAt, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrDeviceNotCreated + } + + return nil +} + +func (d *deviceRepo) UpdateDevice(ctx context.Context, device *datastore.Device, endpointID, projectID string) error { + r, err := d.db.ExecContext(ctx, updateDevice, + device.UID, + projectID, + device.HostName, + device.Status, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrDeviceNotUpdated + } + + return nil +} + +func (d *deviceRepo) UpdateDeviceLastSeen(ctx context.Context, device *datastore.Device, endpointID, projectID string, status datastore.DeviceStatus) error { + r, err := d.db.ExecContext(ctx, updateDeviceLastSeen, + device.UID, + projectID, + status, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrDeviceNotUpdated + } + + return nil +} + +func (d *deviceRepo) DeleteDevice(ctx context.Context, uid string, endpointID, projectID string) error { + r, err := d.db.ExecContext(ctx, deleteDevice, uid, projectID) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrDeviceNotDeleted + } + + return nil +} + +func (d *deviceRepo) FetchDeviceByID(ctx context.Context, uid string, endpointID, projectID string) (*datastore.Device, error) { + device := &datastore.Device{} + err := d.db.QueryRowxContext(ctx, fetchDeviceById, uid, projectID).StructScan(device) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrDeviceNotFound + } + return nil, err + } + + 
return device, nil +} + +func (d *deviceRepo) FetchDeviceByHostName(ctx context.Context, hostName string, endpointID, projectID string) (*datastore.Device, error) { + device := &datastore.Device{} + err := d.db.QueryRowxContext(ctx, fetchDeviceByHostName, hostName, projectID).StructScan(device) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrDeviceNotFound + } + return nil, err + } + + return device, nil +} + +func (d *deviceRepo) LoadDevicesPaged(ctx context.Context, projectID string, filter *datastore.ApiKeyFilter, pageable datastore.Pageable) ([]datastore.Device, datastore.PaginationData, error) { + var query, filterQuery string + var args []interface{} + var err error + + arg := map[string]interface{}{ + "project_id": projectID, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + if pageable.Direction == datastore.Next { + query = baseFetchDevicesPagedForward + } else { + query = baseFetchDevicesPagedBackward + } + + filterQuery = baseDevicesFilter + + query = fmt.Sprintf(query, fetchDevicesPaginated, filterQuery) + + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = d.db.Rebind(query) + + rows, err := d.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var devices []datastore.Device + for rows.Next() { + var data DevicePaginated + + err = rows.StructScan(&data) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + devices = append(devices, data.Device) + } + + var count datastore.PrevRowCount + if len(devices) > 0 { + var countQuery string + var qargs []interface{} + first := devices[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevDevices, filterQuery) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = d.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := d.db.QueryxContext(ctx, countQuery, qargs...) 
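+		// like the other paged queries, at most one aggregate row is
+		// expected; a missing row simply leaves PrevRowCount at zero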
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(devices)) + for i := range devices { + ids[i] = devices[i].UID + } + + if len(devices) > pageable.PerPage { + devices = devices[:len(devices)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return devices, *pagination, nil +} + +type DevicePaginated struct { + Count int + datastore.Device +} diff --git a/database/sqlite3/device_test.go b/database/sqlite3/device_test.go new file mode 100644 index 0000000000..f2cb4e9d3b --- /dev/null +++ b/database/sqlite3/device_test.go @@ -0,0 +1,212 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func Test_CreateDevice(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + repo := NewDeviceRepo(db) + device := generateDevice(t, db) + + require.NoError(t, repo.CreateDevice(context.Background(), device)) + + newDevice, err := repo.FetchDeviceByID(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + require.InDelta(t, device.LastSeenAt.Unix(), newDevice.LastSeenAt.Unix(), float64(time.Hour)) + newDevice.CreatedAt, newDevice.UpdatedAt = time.Time{}, time.Time{} + device.LastSeenAt, newDevice.LastSeenAt = time.Time{}, time.Time{} + + require.Equal(t, device, newDevice) +} + +func Test_UpdateDevice(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + repo := NewDeviceRepo(db) + device := generateDevice(t, db) + + require.NoError(t, repo.CreateDevice(context.Background(), device)) + + device.Status = datastore.DeviceStatusOffline + err := repo.UpdateDevice(context.Background(), device, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + newDevice, err := repo.FetchDeviceByID(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + require.InDelta(t, device.LastSeenAt.Unix(), newDevice.LastSeenAt.Unix(), float64(time.Hour)) + newDevice.CreatedAt, newDevice.UpdatedAt = time.Time{}, time.Time{} + device.LastSeenAt, newDevice.LastSeenAt = time.Time{}, time.Time{} + + require.Equal(t, device, newDevice) +} + +func Test_UpdateDeviceLastSeen(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + deviceRepo := NewDeviceRepo(db) + device := generateDevice(t, db) + + require.NoError(t, deviceRepo.CreateDevice(context.Background(), device)) + + device.Status = datastore.DeviceStatusOffline + err := deviceRepo.UpdateDeviceLastSeen(context.Background(), device, device.EndpointID, device.ProjectID, datastore.DeviceStatusOffline) + require.NoError(t, err) + + newDevice, err := deviceRepo.FetchDeviceByID(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + require.InDelta(t, device.LastSeenAt.Unix(), newDevice.LastSeenAt.Unix(), float64(time.Hour)) + newDevice.CreatedAt, newDevice.UpdatedAt = time.Time{}, time.Time{} + device.LastSeenAt, newDevice.LastSeenAt = time.Time{}, time.Time{} + + require.Equal(t, device, newDevice) +} + +func Test_DeleteDevice(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + repo := NewDeviceRepo(db) + device 
:= generateDevice(t, db) + + require.NoError(t, repo.CreateDevice(context.Background(), device)) + + err := repo.DeleteDevice(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + _, err = repo.FetchDeviceByID(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.Equal(t, datastore.ErrDeviceNotFound, err) +} + +func Test_FetchDeviceByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + repo := NewDeviceRepo(db) + device := generateDevice(t, db) + + require.NoError(t, repo.CreateDevice(context.Background(), device)) + + newDevice, err := repo.FetchDeviceByID(context.Background(), device.UID, device.EndpointID, device.ProjectID) + require.NoError(t, err) + + require.InDelta(t, device.LastSeenAt.Unix(), newDevice.LastSeenAt.Unix(), float64(time.Hour)) + newDevice.CreatedAt, newDevice.UpdatedAt = time.Time{}, time.Time{} + device.LastSeenAt, newDevice.LastSeenAt = time.Time{}, time.Time{} + + require.Equal(t, device, newDevice) +} + +func Test_LoadDevicesPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + expected Expected + }{ + { + name: "Load Devices Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Devices Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Devices Paged - 5 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Devices Paged - 1 record", + pageData: datastore.Pageable{PerPage: 3}, + count: 1, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + repo := NewDeviceRepo(db) + project := seedProject(t, db) + + for i := 0; i < tc.count; i++ { + device := &datastore.Device{ + UID: ulid.Make().String(), + ProjectID: project.UID, + HostName: "", + Status: datastore.DeviceStatusOnline, + LastSeenAt: time.Now(), + } + + require.NoError(t, repo.CreateDevice(context.Background(), device)) + } + + _, pageable, err := repo.LoadDevicesPaged(context.Background(), project.UID, &datastore.ApiKeyFilter{}, tc.pageData) + require.NoError(t, err) + + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generateDevice(t *testing.T, db database.Database) *datastore.Device { + project := seedProject(t, db) + + return &datastore.Device{ + UID: ulid.Make().String(), + ProjectID: project.UID, + HostName: "", + Status: datastore.DeviceStatusOnline, + LastSeenAt: time.Now(), + } +} diff --git a/database/sqlite3/endpoint.go b/database/sqlite3/endpoint.go new file mode 100644 index 0000000000..c0a230cd41 --- /dev/null +++ b/database/sqlite3/endpoint.go @@ -0,0 +1,497 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/database/hooks" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" + "gopkg.in/guregu/null.v4" +) + +var ( 
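+	// sentinel errors surfaced by the endpoint repository; callers are
+	// expected to match them with errors.Is, e.g.
+	//   if errors.Is(err, ErrEndpointExists) { /* respond with 409 */ }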
+ ErrEndpointNotCreated = errors.New("endpoint could not be created") + ErrEndpointNotUpdated = errors.New("endpoint could not be updated") + ErrEndpointExists = errors.New("an endpoint with that name already exists") +) + +const ( + createEndpoint = ` + INSERT INTO endpoints ( + id, name, status, secrets, owner_id, url, description, http_timeout, + rate_limit, rate_limit_duration, advanced_signatures, slack_webhook_url, + support_email, app_id, project_id, authentication_type, authentication_type_api_key_header_name, + authentication_type_api_key_header_value + ) + VALUES + ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, + $14, $15, $16, $17, $18 + ); + ` + + baseEndpointFetch = ` + SELECT + e.id, e.name, e.status, e.owner_id, + e.url, e.description, e.http_timeout, + e.rate_limit, e.rate_limit_duration, e.advanced_signatures, + e.slack_webhook_url, e.support_email, e.app_id, + e.project_id, e.secrets, e.created_at, e.updated_at, + e.authentication_type AS "authentication.type", + e.authentication_type_api_key_header_name AS "authentication.api_key.header_name", + e.authentication_type_api_key_header_value AS "authentication.api_key.header_value" + FROM endpoints AS e + WHERE e.deleted_at IS NULL + ` + + fetchEndpointById = baseEndpointFetch + ` AND e.id = $1 AND e.project_id = $2;` + + fetchEndpointsById = baseEndpointFetch + ` AND e.id IN (?) AND e.project_id = ? GROUP BY e.id ORDER BY e.id;` + + fetchEndpointsByAppId = baseEndpointFetch + ` AND e.app_id = $1 AND e.project_id = $2 GROUP BY e.id ORDER BY e.id;` + + fetchEndpointsByOwnerId = baseEndpointFetch + ` AND e.project_id = $1 AND e.owner_id = $2 GROUP BY e.id ORDER BY e.id;` + + fetchEndpointByTargetURL = ` + SELECT e.id, e.name, e.status, e.owner_id, e.url, + e.description, e.http_timeout, e.rate_limit, e.rate_limit_duration, + e.advanced_signatures, e.slack_webhook_url, e.support_email, + e.app_id, e.project_id, e.secrets, e.created_at, e.updated_at, + e.authentication_type AS "authentication.type", + e.authentication_type_api_key_header_name AS "authentication.api_key.header_name", + e.authentication_type_api_key_header_value AS "authentication.api_key.header_value" + FROM endpoints AS e WHERE e.deleted_at IS NULL AND e.url = $1 AND e.project_id = $2; + ` + + updateEndpoint = ` + UPDATE endpoints SET + name = $3, status = $4, owner_id = $5, + url = $6, description = $7, http_timeout = $8, + rate_limit = $9, rate_limit_duration = $10, advanced_signatures = $11, + slack_webhook_url = $12, support_email = $13, + authentication_type = $14, authentication_type_api_key_header_name = $15, + authentication_type_api_key_header_value = $16, secrets = $17, + updated_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + updateEndpointStatus = ` + UPDATE endpoints SET status = $3 + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL RETURNING + id, name, status, owner_id, url, + description, http_timeout, rate_limit, rate_limit_duration, + advanced_signatures, slack_webhook_url, support_email, + app_id, project_id, secrets, created_at, updated_at, + authentication_type AS "authentication.type", + authentication_type_api_key_header_name AS "authentication.api_key.header_name", + authentication_type_api_key_header_value AS "authentication.api_key.header_value"; + ` + + updateEndpointSecrets = ` + UPDATE endpoints SET + secrets = $3, updated_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL RETURNING + id, name, status, owner_id, url, + description, http_timeout, rate_limit, 
rate_limit_duration, + advanced_signatures, slack_webhook_url, support_email, + app_id, project_id, secrets, created_at, updated_at, + authentication_type AS "authentication.type", + authentication_type_api_key_header_name AS "authentication.api_key.header_name", + authentication_type_api_key_header_value AS "authentication.api_key.header_value"; + ` + + deleteEndpoint = ` + UPDATE endpoints SET deleted_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + deleteEndpointSubscriptions = ` + UPDATE subscriptions SET deleted_at = NOW() + WHERE endpoint_id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + countProjectEndpoints = ` + SELECT COUNT(*) AS count FROM endpoints + WHERE project_id = $1 AND deleted_at IS NULL; + ` + + baseFetchEndpointsPaged = ` + SELECT + e.id, e.name, e.status, e.owner_id, + e.url, e.description, e.http_timeout, + e.rate_limit, e.rate_limit_duration, e.advanced_signatures, + e.slack_webhook_url, e.support_email, e.app_id, + e.project_id, e.secrets, e.created_at, e.updated_at, + e.authentication_type AS "authentication.type", + e.authentication_type_api_key_header_name AS "authentication.api_key.header_name", + e.authentication_type_api_key_header_value AS "authentication.api_key.header_value" + FROM endpoints AS e + WHERE e.deleted_at IS NULL + AND e.project_id = :project_id + AND (e.owner_id = :owner_id OR :owner_id = '') + AND (e.name ILIKE :name OR :name = '')` + + fetchEndpointsPagedForward = ` + %s + %s + AND e.id <= :cursor + GROUP BY e.id + ORDER BY e.id DESC + LIMIT :limit + ` + + fetchEndpointsPagedBackward = ` + WITH endpoints AS ( + %s + %s + AND e.id >= :cursor + GROUP BY e.id + ORDER BY e.id ASC + LIMIT :limit + ) + + SELECT * FROM endpoints ORDER BY id DESC + ` + + countPrevEndpoints = ` + SELECT COUNT(DISTINCT(s.id)) AS count + FROM endpoints s + WHERE s.deleted_at IS NULL + AND s.project_id = :project_id + AND (s.name ILIKE :name OR :name = '') + AND s.id > :cursor + GROUP BY s.id + ORDER BY s.id DESC + LIMIT 1` +) + +type endpointRepo struct { + db *sqlx.DB + hook *hooks.Hook +} + +func NewEndpointRepo(db database.Database) datastore.EndpointRepository { + return &endpointRepo{db: db.GetDB(), hook: db.GetHook()} +} + +func (e *endpointRepo) CreateEndpoint(ctx context.Context, endpoint *datastore.Endpoint, projectID string) error { + ac := endpoint.GetAuthConfig() + + args := []interface{}{ + endpoint.UID, endpoint.Name, endpoint.Status, endpoint.Secrets, endpoint.OwnerID, endpoint.Url, + endpoint.Description, endpoint.HttpTimeout, endpoint.RateLimit, endpoint.RateLimitDuration, + endpoint.AdvancedSignatures, endpoint.SlackWebhookURL, endpoint.SupportEmail, endpoint.AppID, + projectID, ac.Type, ac.ApiKey.HeaderName, ac.ApiKey.HeaderValue, + } + + result, err := e.db.ExecContext(ctx, createEndpoint, args...) 
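+	// NOTE: the duplicate-key check below matches Postgres's error text;
+	// SQLite reports the same condition as "UNIQUE constraint failed", so a
+	// SQLite build would likely need to match both messages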
+ if err != nil { + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return ErrEndpointExists + } + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEndpointNotCreated + } + + go e.hook.Fire(datastore.EndpointCreated, endpoint, nil) + + return nil +} + +func (e *endpointRepo) FindEndpointByID(ctx context.Context, id, projectID string) (*datastore.Endpoint, error) { + endpoint := &datastore.Endpoint{} + err := e.db.QueryRowxContext(ctx, fetchEndpointById, id, projectID).StructScan(endpoint) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrEndpointNotFound + } + return nil, err + } + + return endpoint, nil +} + +func (e *endpointRepo) FindEndpointsByID(ctx context.Context, ids []string, projectID string) ([]datastore.Endpoint, error) { + query, args, err := sqlx.In(fetchEndpointsById, ids, projectID) + if err != nil { + return nil, err + } + + query = e.db.Rebind(query) + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + + return e.scanEndpoints(rows) +} + +func (e *endpointRepo) FindEndpointsByAppID(ctx context.Context, appID, projectID string) ([]datastore.Endpoint, error) { + rows, err := e.db.QueryxContext(ctx, fetchEndpointsByAppId, appID, projectID) + if err != nil { + return nil, err + } + + return e.scanEndpoints(rows) +} + +func (e *endpointRepo) FindEndpointsByOwnerID(ctx context.Context, projectID string, ownerID string) ([]datastore.Endpoint, error) { + rows, err := e.db.QueryxContext(ctx, fetchEndpointsByOwnerId, projectID, ownerID) + if err != nil { + return nil, err + } + + return e.scanEndpoints(rows) +} + +func (e *endpointRepo) UpdateEndpoint(ctx context.Context, endpoint *datastore.Endpoint, projectID string) error { + ac := endpoint.GetAuthConfig() + + r, err := e.db.ExecContext(ctx, updateEndpoint, endpoint.UID, projectID, endpoint.Name, endpoint.Status, endpoint.OwnerID, endpoint.Url, + endpoint.Description, endpoint.HttpTimeout, endpoint.RateLimit, endpoint.RateLimitDuration, + endpoint.AdvancedSignatures, endpoint.SlackWebhookURL, endpoint.SupportEmail, + ac.Type, ac.ApiKey.HeaderName, ac.ApiKey.HeaderValue, endpoint.Secrets, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEndpointNotUpdated + } + + go e.hook.Fire(datastore.EndpointUpdated, endpoint, nil) + return nil +} + +func (e *endpointRepo) UpdateEndpointStatus(ctx context.Context, projectID string, endpointID string, status datastore.EndpointStatus) error { + endpoint := datastore.Endpoint{} + err := e.db.QueryRowxContext(ctx, updateEndpointStatus, endpointID, projectID, status).StructScan(&endpoint) + if err != nil { + return err + } + + return nil +} + +func (e *endpointRepo) DeleteEndpoint(ctx context.Context, endpoint *datastore.Endpoint, projectID string) error { + tx, err := e.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + _, err = tx.ExecContext(ctx, deleteEndpoint, endpoint.UID, projectID) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteEndpointSubscriptions, endpoint.UID, projectID) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deletePortalLinkEndpoints, nil, endpoint.UID) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + go 
e.hook.Fire(datastore.EndpointDeleted, endpoint, nil) + return nil +} + +func (e *endpointRepo) CountProjectEndpoints(ctx context.Context, projectID string) (int64, error) { + var count int64 + + err := e.db.QueryRowxContext(ctx, countProjectEndpoints, projectID).Scan(&count) + if err != nil { + return count, err + } + + return count, nil +} + +func (e *endpointRepo) FindEndpointByTargetURL(ctx context.Context, projectID string, targetURL string) (*datastore.Endpoint, error) { + endpoint := &datastore.Endpoint{} + err := e.db.QueryRowxContext(ctx, fetchEndpointByTargetURL, targetURL, projectID).StructScan(endpoint) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrEndpointNotFound + } + return nil, err + } + + return endpoint, nil +} + +func (e *endpointRepo) LoadEndpointsPaged(ctx context.Context, projectId string, filter *datastore.Filter, pageable datastore.Pageable) ([]datastore.Endpoint, datastore.PaginationData, error) { + q := filter.Query + if !util.IsStringEmpty(q) { + q = fmt.Sprintf("%%%s%%", q) + } + + arg := map[string]interface{}{ + "project_id": projectId, + "owner_id": filter.OwnerID, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + "endpoint_ids": filter.EndpointIDs, + "name": q, + } + + var query, filterQuery string + if pageable.Direction == datastore.Next { + query = fetchEndpointsPagedForward + } else { + query = fetchEndpointsPagedBackward + } + + if len(filter.EndpointIDs) > 0 { + filterQuery = ` AND e.id IN (:endpoint_ids)` + } + + query = fmt.Sprintf(query, baseFetchEndpointsPaged, filterQuery) + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = e.db.Rebind(query) + + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + endpoints, err := e.scanEndpoints(rows) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + ids := make([]string, len(endpoints)) + for i := range endpoints { + ids[i] = endpoints[i].UID + } + + if len(endpoints) > pageable.PerPage { + endpoints = endpoints[:len(endpoints)-1] + } + + var count datastore.PrevRowCount + if len(endpoints) > 0 { + var countQuery string + var qargs []interface{} + first := endpoints[0] + qarg := arg + qarg["cursor"] = first.UID + + countQuery, qargs, err = sqlx.Named(countPrevEndpoints, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = e.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := e.db.QueryxContext(ctx, countQuery, qargs...) 
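+		// as in the other repositories, the count query yields at most one row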
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return endpoints, *pagination, nil +} + +func (e *endpointRepo) UpdateSecrets(ctx context.Context, endpointID string, projectID string, secrets datastore.Secrets) error { + endpoint := datastore.Endpoint{} + err := e.db.QueryRowxContext(ctx, updateEndpointSecrets, endpointID, projectID, secrets).StructScan(&endpoint) + if err != nil { + return err + } + + return nil +} + +func (e *endpointRepo) DeleteSecret(ctx context.Context, endpoint *datastore.Endpoint, secretID, projectID string) error { + sc := endpoint.FindSecret(secretID) + if sc == nil { + return datastore.ErrSecretNotFound + } + + sc.DeletedAt = null.NewTime(time.Now(), true) + + updatedEndpoint := datastore.Endpoint{} + err := e.db.QueryRowxContext(ctx, updateEndpointSecrets, endpoint.UID, projectID, endpoint.Secrets).StructScan(&updatedEndpoint) + if err != nil { + return err + } + + return nil +} + +func (e *endpointRepo) scanEndpoints(rows *sqlx.Rows) ([]datastore.Endpoint, error) { + endpoints := make([]datastore.Endpoint, 0) + defer closeWithError(rows) + + for rows.Next() { + var endpoint datastore.Endpoint + err := rows.StructScan(&endpoint) + if err != nil { + return nil, err + } + + endpoints = append(endpoints, endpoint) + } + + return endpoints, nil +} + +type EndpointPaginated struct { + EndpointSecret +} + +type EndpointSecret struct { + Endpoint datastore.Endpoint `json:"endpoint"` + Secret datastore.Secret `db:"secret"` +} diff --git a/database/sqlite3/endpoint_test.go b/database/sqlite3/endpoint_test.go new file mode 100644 index 0000000000..f96fc1e343 --- /dev/null +++ b/database/sqlite3/endpoint_test.go @@ -0,0 +1,605 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "fmt" + "testing" + "time" + + "gopkg.in/guregu/null.v4" + + "github.com/jaswdr/faker" + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func Test_UpdateEndpoint(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + updatedEndpoint := &datastore.Endpoint{ + UID: endpoint.UID, + ProjectID: endpoint.ProjectID, + OwnerID: "4304jj39h43h", + Url: "https//uere.ccm", + Name: "testing_endpoint_repo", + Secrets: endpoint.Secrets, + AdvancedSignatures: true, + AppID: endpoint.AppID, + Description: "9897fdkhkhd", + SlackWebhookURL: "https:/899gfnnn", + SupportEmail: "ex@convoybbb.com", + HttpTimeout: 88, + RateLimit: 8898, + Status: datastore.ActiveEndpointStatus, + RateLimitDuration: 10, + Authentication: &datastore.EndpointAuthentication{ + Type: datastore.APIKeyAuthentication, + ApiKey: &datastore.ApiKey{ + HeaderValue: "97if7dgfg", + HeaderName: "x-header-p", + }, + }, + } + + require.NoError(t, endpointRepo.UpdateEndpoint(context.Background(), updatedEndpoint, updatedEndpoint.ProjectID)) + + dbEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + require.NotEmpty(t, 
dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt = time.Time{} + dbEndpoint.UpdatedAt = time.Time{} + + for i := range dbEndpoint.Secrets { + secret := &dbEndpoint.Secrets[i] + + require.Equal(t, updatedEndpoint.Secrets[i].Value, secret.Value) + require.NotEmpty(t, secret.CreatedAt) + require.NotEmpty(t, secret.UpdatedAt) + + secret.CreatedAt, secret.UpdatedAt = time.Time{}, time.Time{} + updatedEndpoint.Secrets[i].CreatedAt, updatedEndpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + require.Equal(t, updatedEndpoint, dbEndpoint) +} + +func Test_UpdateEndpointStatus(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + status := datastore.InactiveEndpointStatus + + endpoint.Status = status + + require.NoError(t, endpointRepo.UpdateEndpointStatus(context.Background(), project.UID, endpoint.UID, status)) + + dbEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + require.Equal(t, status, dbEndpoint.Status) +} + +func Test_DeleteEndpoint(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + sub := &datastore.Subscription{ + UID: ulid.Make().String(), + Name: "test_sub", + Type: datastore.SubscriptionTypeAPI, + ProjectID: project.UID, + EndpointID: endpoint.UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"*"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + RateLimitConfig: &datastore.DefaultRateLimitConfig, + } + + subRepo := NewSubscriptionRepo(db) + err = subRepo.CreateSubscription(context.Background(), project.UID, sub) + require.NoError(t, err) + + err = endpointRepo.DeleteEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + _, err = endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.Equal(t, datastore.ErrEndpointNotFound, err) + + _, err = subRepo.FindSubscriptionByID(context.Background(), project.UID, sub.UID) + require.Equal(t, datastore.ErrSubscriptionNotFound, err) +} + +func Test_CreateEndpoint(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + projectRepo := NewProjectRepo(db) + endpointRepo := NewEndpointRepo(db) + + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: "Yet another project", + LogoURL: "s3.com/dsiuirueiy", + OrganisationID: seedOrg(t, db).UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + require.NoError(t, projectRepo.CreateProject(context.Background(), project)) + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + dbEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt = time.Time{} + 
dbEndpoint.UpdatedAt = time.Time{} + + for i := range dbEndpoint.Secrets { + secret := &dbEndpoint.Secrets[i] + + require.Equal(t, endpoint.Secrets[i].Value, secret.Value) + require.NotEmpty(t, secret.CreatedAt) + require.NotEmpty(t, secret.UpdatedAt) + + secret.CreatedAt, secret.UpdatedAt = time.Time{}, time.Time{} + endpoint.Secrets[i].CreatedAt, endpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + require.Equal(t, endpoint, dbEndpoint) +} + +func Test_LoadEndpointsPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + eventRepo := NewEventRepo(db) + + project := seedProject(t, db) + + for i := 0; i < 7; i++ { + endpoint := generateEndpoint(project) + if i == 1 || i == 2 || i == 4 { + endpoint.Name += " daniel" + } + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := generateEvent(t, db) + event.Endpoints = []string{endpoint.UID} + require.NoError(t, eventRepo.CreateEvent(context.Background(), event)) + } + + endpoints, _, err := endpointRepo.LoadEndpointsPaged(context.Background(), project.UID, &datastore.Filter{Query: "daniel"}, datastore.Pageable{ + PerPage: 10, + }) + + require.NoError(t, err) + require.Equal(t, 3, len(endpoints)) + + endpoints, _, err = endpointRepo.LoadEndpointsPaged(context.Background(), project.UID, &datastore.Filter{}, datastore.Pageable{ + PerPage: 10, + }) + + require.NoError(t, err) + + require.True(t, len(endpoints) == 7) +} + +func Test_FindEndpointsByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + newEndpointRepo := NewEndpointRepo(db) + newEventRepo := NewEventRepo(db) + + project := seedProject(t, db) + ids := []string{} + endpointMap := map[string]*datastore.Endpoint{} + for i := 0; i < 7; i++ { + endpoint := generateEndpoint(project) + + if i == 0 || i == 3 || i == 4 { + endpoint.Secrets[0].Value += fmt.Sprintf("ddhdhhss-%d", i) + endpointMap[endpoint.UID] = endpoint + ids = append(ids, endpoint.UID) + } + + err := newEndpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := generateEvent(t, db) + event.Endpoints = []string{endpoint.UID} + require.NoError(t, newEventRepo.CreateEvent(context.Background(), event)) + } + + emptyEndpoints, err := newEndpointRepo.FindEndpointsByID(context.Background(), ids, "") + require.NoError(t, err) + require.Equal(t, 0, len(emptyEndpoints)) + + dbEndpoints, err := newEndpointRepo.FindEndpointsByID(context.Background(), ids, project.UID) + require.NoError(t, err) + require.Equal(t, 3, len(dbEndpoints)) + + for _, dbEndpoint := range dbEndpoints { + endpoint, ok := endpointMap[dbEndpoint.UID] + require.True(t, ok) + + require.NotEmpty(t, dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt, dbEndpoint.UpdatedAt = time.Time{}, time.Time{} + + for i := range dbEndpoint.Secrets { + s := &dbEndpoint.Secrets[i] + require.NotEmpty(t, s.CreatedAt) + require.NotEmpty(t, s.UpdatedAt) + + s.CreatedAt, s.UpdatedAt = time.Time{}, time.Time{} + endpoint.Secrets[i].CreatedAt, endpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + dbEndpoint.Events = 0 + require.Equal(t, *endpoint, dbEndpoint) + } +} + +func Test_FindEndpointsByAppID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + eventRepo := NewEventRepo(db) + + project := seedProject(t, db) + appID := "vvbbbb" + endpointMap := map[string]*datastore.Endpoint{} + for i := 0; i < 
7; i++ { + endpoint := generateEndpoint(project) + + if i < 4 { + endpoint.AppID = appID + endpointMap[endpoint.UID] = endpoint + } + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := generateEvent(t, db) + event.Endpoints = []string{endpoint.UID} + require.NoError(t, eventRepo.CreateEvent(context.Background(), event)) + } + + emptyEndpoints, err := endpointRepo.FindEndpointsByAppID(context.Background(), appID, "") + require.NoError(t, err) + require.Equal(t, 0, len(emptyEndpoints)) + + dbEndpoints, err := endpointRepo.FindEndpointsByAppID(context.Background(), appID, project.UID) + require.NoError(t, err) + require.Equal(t, 4, len(dbEndpoints)) + + for _, dbEndpoint := range dbEndpoints { + endpoint, ok := endpointMap[dbEndpoint.UID] + require.True(t, ok) + + require.NotEmpty(t, dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt, dbEndpoint.UpdatedAt = time.Time{}, time.Time{} + + for i := range dbEndpoint.Secrets { + s := &dbEndpoint.Secrets[i] + require.NotEmpty(t, s.CreatedAt) + require.NotEmpty(t, s.UpdatedAt) + + s.CreatedAt, s.UpdatedAt = time.Time{}, time.Time{} + endpoint.Secrets[i].CreatedAt, endpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + require.Equal(t, *endpoint, dbEndpoint) + } +} + +func Test_FindEndpointsByOwnerID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + eventRepo := NewEventRepo(db) + + project := seedProject(t, db) + ownerID := "owner-ffdjj" + endpointMap := map[string]*datastore.Endpoint{} + for i := 0; i < 7; i++ { + endpoint := generateEndpoint(project) + + if i < 4 { + endpoint.OwnerID = ownerID + endpointMap[endpoint.UID] = endpoint + } + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := generateEvent(t, db) + event.Endpoints = []string{endpoint.UID} + require.NoError(t, eventRepo.CreateEvent(context.Background(), event)) + } + + emptyEndpoints, err := endpointRepo.FindEndpointsByOwnerID(context.Background(), "", ownerID) + require.NoError(t, err) + require.Equal(t, 0, len(emptyEndpoints)) + + dbEndpoints, err := endpointRepo.FindEndpointsByOwnerID(context.Background(), project.UID, ownerID) + require.NoError(t, err) + require.Equal(t, 4, len(dbEndpoints)) + + for _, dbEndpoint := range dbEndpoints { + endpoint, ok := endpointMap[dbEndpoint.UID] + require.True(t, ok) + + require.NotEmpty(t, dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt, dbEndpoint.UpdatedAt = time.Time{}, time.Time{} + + for i := range dbEndpoint.Secrets { + s := &dbEndpoint.Secrets[i] + require.NotEmpty(t, s.CreatedAt) + require.NotEmpty(t, s.UpdatedAt) + + s.CreatedAt, s.UpdatedAt = time.Time{}, time.Time{} + endpoint.Secrets[i].CreatedAt, endpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + require.Equal(t, *endpoint, dbEndpoint) + } +} + +func Test_CountProjectEndpoints(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + for i := 0; i < 6; i++ { + endpoint := generateEndpoint(project) + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + } + + for i := 0; i < 3; i++ { + endpoint := generateEndpoint(project) + p := seedProject(t, db) + endpoint.ProjectID = p.UID + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, p.UID) + 
require.NoError(t, err) + } + + c, err := endpointRepo.CountProjectEndpoints(context.Background(), project.UID) + require.NoError(t, err) + + require.Equal(t, int64(6), c) +} + +func Test_FindEndpointByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + eventRepo := NewEventRepo(db) + + _, err := endpointRepo.FindEndpointByID(context.Background(), ulid.Make().String(), "") + require.Equal(t, datastore.ErrEndpointNotFound, err) + + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err = endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := generateEvent(t, db) + event.Endpoints = []string{endpoint.UID} + require.NoError(t, eventRepo.CreateEvent(context.Background(), event)) + + dbEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbEndpoint.CreatedAt) + require.NotEmpty(t, dbEndpoint.UpdatedAt) + + dbEndpoint.CreatedAt = time.Time{} + dbEndpoint.UpdatedAt = time.Time{} + + for i := range dbEndpoint.Secrets { + secret := &dbEndpoint.Secrets[i] + + require.Equal(t, endpoint.Secrets[i].Value, secret.Value) + require.NotEmpty(t, secret.CreatedAt) + require.NotEmpty(t, secret.UpdatedAt) + + secret.CreatedAt, secret.UpdatedAt = time.Time{}, time.Time{} + endpoint.Secrets[i].CreatedAt, endpoint.Secrets[i].UpdatedAt = time.Time{}, time.Time{} + } + + require.Equal(t, endpoint, dbEndpoint) +} + +func Test_UpdateSecrets(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + newSecret := datastore.Secret{ + UID: ulid.Make().String(), + Value: "new_secret", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + endpoint.Secrets[0].ExpiresAt = null.NewTime(time.Now(), true) + endpoint.Secrets = append(endpoint.Secrets, newSecret) + + err = endpointRepo.UpdateSecrets(context.Background(), endpoint.UID, project.UID, endpoint.Secrets) + require.NoError(t, err) + + newSecretEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + require.Equal(t, endpoint.Secrets[0].UID, newSecretEndpoint.Secrets[0].UID) + require.Equal(t, endpoint.Secrets[0].Value, newSecretEndpoint.Secrets[0].Value) + require.NotEmpty(t, newSecretEndpoint.Secrets[0].ExpiresAt) + require.NotEmpty(t, newSecretEndpoint.Secrets[0].CreatedAt) + require.NotEmpty(t, newSecretEndpoint.Secrets[0].UpdatedAt) + + require.Equal(t, newSecret.UID, newSecretEndpoint.Secrets[1].UID) + require.Equal(t, newSecret.Value, newSecretEndpoint.Secrets[1].Value) + require.Empty(t, newSecretEndpoint.Secrets[1].ExpiresAt) + require.NotEmpty(t, newSecretEndpoint.Secrets[1].CreatedAt) + require.NotEmpty(t, newSecretEndpoint.Secrets[1].UpdatedAt) +} + +func Test_DeleteSecret(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + endpointRepo := NewEndpointRepo(db) + + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err := endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + newSecret := datastore.Secret{ + UID: ulid.Make().String(), + Value: "new_secret", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + endpoint.Secrets[0].ExpiresAt = 
null.NewTime(time.Now(), true) + endpoint.Secrets = append(endpoint.Secrets, newSecret) + + err = endpointRepo.UpdateSecrets(context.Background(), endpoint.UID, project.UID, endpoint.Secrets) + require.NoError(t, err) + + err = endpointRepo.DeleteSecret(context.Background(), endpoint, endpoint.Secrets[0].UID, project.UID) + require.NoError(t, err) + + deletedSecretEndpoint, err := endpointRepo.FindEndpointByID(context.Background(), endpoint.UID, project.UID) + require.NoError(t, err) + + for _, secret := range deletedSecretEndpoint.Secrets { + require.NotEqual(t, secret.UID, endpoint.Secrets[0].UID) // the deleted secret should not appear in a fetch + } + + require.Equal(t, newSecret.UID, deletedSecretEndpoint.Secrets[0].UID) + require.Equal(t, newSecret.Value, deletedSecretEndpoint.Secrets[0].Value) + require.Empty(t, deletedSecretEndpoint.Secrets[0].ExpiresAt) + require.Empty(t, deletedSecretEndpoint.Secrets[0].DeletedAt) + require.NotEmpty(t, deletedSecretEndpoint.Secrets[0].CreatedAt) + require.NotEmpty(t, deletedSecretEndpoint.Secrets[0].UpdatedAt) +} + +func generateEndpoint(project *datastore.Project) *datastore.Endpoint { + return &datastore.Endpoint{ + UID: ulid.Make().String(), + ProjectID: project.UID, + OwnerID: ulid.Make().String(), + Url: faker.New().Address().StreetAddress(), + Name: fmt.Sprintf("%s-%s", faker.New().Company().Name(), ulid.Make().String()), + AdvancedSignatures: true, + Description: "testing", + SlackWebhookURL: "https:/gggggg", + SupportEmail: "ex@convoy.com", + AppID: "app1", + HttpTimeout: 30, + RateLimit: 300, + Status: datastore.ActiveEndpointStatus, + RateLimitDuration: 10, + Secrets: []datastore.Secret{ + { + UID: ulid.Make().String(), + Value: "kirer", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + Authentication: &datastore.EndpointAuthentication{ + Type: datastore.APIKeyAuthentication, + ApiKey: &datastore.ApiKey{ + HeaderValue: "4387rjejhgjfyuyu34", + HeaderName: "x-header", + }, + }, + } +} + +func seedEndpoint(t *testing.T, db database.Database) *datastore.Endpoint { + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err := NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + return endpoint +} diff --git a/database/sqlite3/event.go b/database/sqlite3/event.go new file mode 100644 index 0000000000..7d2d65b0f7 --- /dev/null +++ b/database/sqlite3/event.go @@ -0,0 +1,673 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" +) + +const ( + PartitionSize = 30_000 + + createEvent = ` + INSERT INTO events (id,event_type,endpoints,project_id, + source_id,headers,raw,data,url_query_params, + idempotency_key,is_duplicate_event,acknowledged_at,metadata,status) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ` + + updateEventEndpoints = ` + UPDATE events SET endpoints=$1 WHERE project_id= $2 AND id=$3 + ` + updateEventStatus = ` + UPDATE events SET status=$1 WHERE project_id= $2 AND id=$3 + ` + + createEventEndpoints = ` + INSERT INTO events_endpoints (endpoint_id, event_id) VALUES (:endpoint_id, :event_id) + ON CONFLICT (endpoint_id, event_id) DO NOTHING + ` + + fetchEventById = ` + SELECT id, event_type, endpoints, project_id, + raw, data, headers, is_duplicate_event, + COALESCE(source_id, '') AS source_id, + COALESCE(idempotency_key, 
'') AS idempotency_key,
+ COALESCE(url_query_params, '') AS url_query_params,
+ created_at,updated_at,acknowledged_at,metadata,status
+ FROM events WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+ `
+
+ fetchEventsByIdempotencyKey = `
+ SELECT id FROM events WHERE idempotency_key = $1 AND project_id = $2 AND deleted_at IS NULL;
+ `
+
+ fetchFirstEventWithIdempotencyKey = `
+ SELECT id FROM events
+ WHERE idempotency_key = $1
+ AND is_duplicate_event IS FALSE
+ AND project_id = $2
+ AND deleted_at IS NULL
+ ORDER BY id
+ LIMIT 1;
+ `
+
+ fetchEventsByIds = `
+ SELECT ev.id, ev.project_id,
+ ev.is_duplicate_event, ev.event_type,
+ COALESCE(ev.source_id, '') AS source_id,
+ COALESCE(ev.idempotency_key, '') AS idempotency_key,
+ COALESCE(ev.url_query_params, '') AS url_query_params,
+ ev.headers, ev.raw, ev.data, ev.created_at,
+ ev.updated_at, ev.deleted_at,ev.acknowledged_at,
+ COALESCE(s.id, '') AS "source_metadata.id",
+ COALESCE(s.name, '') AS "source_metadata.name"
+ FROM events ev
+ LEFT JOIN events_endpoints ee ON ee.event_id = ev.id
+ LEFT JOIN endpoints e ON e.id = ee.endpoint_id
+ LEFT JOIN sources s ON s.id = ev.source_id
+ WHERE ev.deleted_at IS NULL
+ AND ev.id IN (?)
+ AND ev.project_id = ?
+ `
+
+ countProjectMessages = `
+ SELECT COUNT(project_id) FROM events WHERE project_id = $1 AND deleted_at IS NULL;
+ `
+ countEvents = `
+ SELECT COUNT(DISTINCT(ev.id)) FROM events ev
+ LEFT JOIN events_endpoints ee ON ee.event_id = ev.id
+ LEFT JOIN endpoints e ON ee.endpoint_id = e.id
+ WHERE ev.project_id = :project_id
+ AND ev.created_at >= :start_date AND ev.created_at <= :end_date AND ev.deleted_at IS NULL;
+ `
+
+ baseEventsPaged = `
+ SELECT ev.id, ev.project_id,
+ ev.event_type, ev.is_duplicate_event,
+ COALESCE(ev.source_id, '') AS source_id,
+ ev.headers, ev.raw, ev.data, ev.created_at,
+ COALESCE(idempotency_key, '') AS idempotency_key,
+ COALESCE(url_query_params, '') AS url_query_params,
+ ev.updated_at, ev.deleted_at,ev.acknowledged_at,
+ COALESCE(s.id, '') AS "source_metadata.id",
+ COALESCE(s.name, '') AS "source_metadata.name"
+ FROM events ev
+ LEFT JOIN events_endpoints ee ON ee.event_id = ev.id
+ LEFT JOIN endpoints e ON e.id = ee.endpoint_id
+ LEFT JOIN sources s ON s.id = ev.source_id
+ WHERE ev.deleted_at IS NULL`
+
+ baseEventsSearch = `
+ SELECT ev.id, ev.project_id,
+ ev.event_type, ev.is_duplicate_event,
+ COALESCE(ev.source_id, '') AS source_id,
+ ev.headers, ev.raw, ev.data, ev.created_at,
+ COALESCE(idempotency_key, '') AS idempotency_key,
+ COALESCE(url_query_params, '') AS url_query_params,
+ ev.updated_at, ev.deleted_at,
+ COALESCE(s.id, '') AS "source_metadata.id",
+ COALESCE(s.name, '') AS "source_metadata.name"
+ FROM events_search ev
+ LEFT JOIN events_endpoints ee ON ee.event_id = ev.id
+ LEFT JOIN endpoints e ON e.id = ee.endpoint_id
+ LEFT JOIN sources s ON s.id = ev.source_id
+ WHERE ev.deleted_at IS NULL`
+
+ baseEventsPagedForward = `
+ WITH events AS (
+ %s %s AND ev.id <= :cursor
+ GROUP BY ev.id, s.id
+ ORDER BY ev.id %s
+ LIMIT :limit
+ )
+
+ SELECT * FROM events ORDER BY id %s
+ `
+
+ baseEventsPagedBackward = `
+ WITH events AS (
+ %s %s AND ev.id >= :cursor
+ GROUP BY ev.id, s.id
+ ORDER BY ev.id %s
+ LIMIT :limit
+ )
+
+ SELECT * FROM events ORDER BY id %s
+ `
+
+ baseEventFilter = ` AND ev.project_id = :project_id
+ AND (ev.idempotency_key = :idempotency_key OR :idempotency_key = '')
+ AND ev.created_at >= :start_date
+ AND ev.created_at <= :end_date`
+
+ endpointFilter = ` AND ee.endpoint_id IN 
(:endpoint_ids) ` + + sourceFilter = ` AND ev.source_id IN (:source_ids) ` + + searchFilter = ` AND search_token @@ websearch_to_tsquery('simple',:query) ` + + baseCountPrevEvents = ` + SELECT COUNT(DISTINCT(ev.id)) AS COUNT + FROM events ev + LEFT JOIN events_endpoints ee ON ev.id = ee.event_id + WHERE ev.deleted_at IS NULL + ` + + baseCountPrevEventSearch = ` + SELECT COUNT(DISTINCT(ev.id)) AS COUNT + FROM events_search ev + LEFT JOIN events_endpoints ee ON ev.id = ee.event_id + WHERE ev.deleted_at IS NULL + ` + countPrevEvents = ` AND ev.id > :cursor GROUP BY ev.id ORDER BY ev.id %s LIMIT 1` + + softDeleteProjectEvents = ` + UPDATE events SET deleted_at = NOW() + WHERE project_id = $1 AND created_at >= $2 AND created_at <= $3 + AND deleted_at IS NULL + ` + + hardDeleteProjectEvents = ` + DELETE FROM events WHERE project_id = $1 AND created_at >= $2 AND created_at <= $3 + AND NOT EXISTS ( + SELECT 1 + FROM event_deliveries + WHERE event_id = events.id + ) + ` + + hardDeleteTokenizedEvents = ` + DELETE FROM events_search + WHERE project_id = $1 + ` +) + +type eventRepo struct { + db *sqlx.DB +} + +func NewEventRepo(db database.Database) datastore.EventRepository { + return &eventRepo{db: db.GetDB()} +} + +func (e *eventRepo) CreateEvent(ctx context.Context, event *datastore.Event) error { + var sourceID *string + + if !util.IsStringEmpty(event.SourceID) { + sourceID = &event.SourceID + } + event.Status = datastore.PendingStatus + + tx, isWrapped, err := GetTx(ctx, e.db) + if err != nil { + return err + } + + if !isWrapped { + defer rollbackTx(tx) + } + + _, err = tx.ExecContext(ctx, createEvent, + event.UID, + event.EventType, + event.Endpoints, + event.ProjectID, + sourceID, + event.Headers, + event.Raw, + event.Data, + event.URLQueryParams, + event.IdempotencyKey, + event.IsDuplicateEvent, + event.AcknowledgedAt, + event.Metadata, + event.Status, + ) + if err != nil { + return err + } + + endpoints := event.Endpoints + var j int + for i := 0; i < len(endpoints); i += PartitionSize { + j += PartitionSize + if j > len(endpoints) { + j = len(endpoints) + } + + var ids []interface{} + for _, endpointID := range endpoints[i:j] { + ids = append(ids, &EventEndpoint{EventID: event.UID, EndpointID: endpointID}) + } + + _, err = tx.NamedExecContext(ctx, createEventEndpoints, ids) + if err != nil { + return err + } + } + + if isWrapped { + return nil + } + + return tx.Commit() +} + +func (e *eventRepo) UpdateEventEndpoints(ctx context.Context, event *datastore.Event, endpoints []string) error { + tx, isWrapped, err := GetTx(ctx, e.db) + if err != nil { + return err + } + + if !isWrapped { + defer rollbackTx(tx) + } + + _, err = tx.ExecContext(ctx, updateEventEndpoints, + event.Endpoints, + event.ProjectID, + event.UID, + ) + if err != nil { + return err + } + + var j int + for i := 0; i < len(endpoints); i += PartitionSize { + j += PartitionSize + if j > len(endpoints) { + j = len(endpoints) + } + + var ids []interface{} + for _, endpointID := range endpoints[i:j] { + ids = append(ids, &EventEndpoint{EventID: event.UID, EndpointID: endpointID}) + } + + _, err = tx.NamedExecContext(ctx, createEventEndpoints, ids) + if err != nil { + return err + } + } + + if isWrapped { + return nil + } + + return tx.Commit() +} + +func (e *eventRepo) UpdateEventStatus(ctx context.Context, event *datastore.Event, status datastore.EventStatus) error { + tx, isWrapped, err := GetTx(ctx, e.db) + if err != nil { + return err + } + + if !isWrapped { + defer rollbackTx(tx) + } + + _, err = tx.ExecContext(ctx, 
updateEventStatus, + status, + event.ProjectID, + event.UID, + ) + if err != nil { + return err + } + + if isWrapped { + return nil + } + + return tx.Commit() +} + +func (e *eventRepo) FindEventByID(ctx context.Context, projectID string, id string) (*datastore.Event, error) { + event := &datastore.Event{} + err := e.db.QueryRowxContext(ctx, fetchEventById, id, projectID).StructScan(event) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrEventNotFound + } + + return nil, err + } + return event, nil +} + +func (e *eventRepo) FindEventsByIDs(ctx context.Context, projectID string, ids []string) ([]datastore.Event, error) { + query, args, err := sqlx.In(fetchEventsByIds, ids, projectID) + if err != nil { + return nil, err + } + + query = e.db.Rebind(query) + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + events := make([]datastore.Event, 0) + for rows.Next() { + var event datastore.Event + + err := rows.StructScan(&event) + if err != nil { + return nil, err + } + + events = append(events, event) + } + + return events, nil +} + +func (e *eventRepo) FindEventsByIdempotencyKey(ctx context.Context, projectID string, idempotencyKey string) ([]datastore.Event, error) { + query, args, err := sqlx.In(fetchEventsByIdempotencyKey, idempotencyKey, projectID) + if err != nil { + return nil, err + } + + query = e.db.Rebind(query) + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + events := make([]datastore.Event, 0) + for rows.Next() { + var event datastore.Event + + err := rows.StructScan(&event) + if err != nil { + return nil, err + } + + events = append(events, event) + } + + return events, nil +} + +func (e *eventRepo) FindFirstEventWithIdempotencyKey(ctx context.Context, projectID string, id string) (*datastore.Event, error) { + event := &datastore.Event{} + err := e.db.QueryRowxContext(ctx, fetchFirstEventWithIdempotencyKey, id, projectID).StructScan(event) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrEventNotFound + } + + return nil, err + } + return event, nil +} + +func (e *eventRepo) CountProjectMessages(ctx context.Context, projectID string) (int64, error) { + var c int64 + + err := e.db.QueryRowxContext(ctx, countProjectMessages, projectID).Scan(&c) + if err != nil { + return c, err + } + + return c, nil +} + +func (e *eventRepo) CountEvents(ctx context.Context, projectID string, filter *datastore.Filter) (int64, error) { + var count int64 + startDate, endDate := getCreatedDateFilter(filter.SearchParams.CreatedAtStart, filter.SearchParams.CreatedAtEnd) + + arg := map[string]interface{}{ + "endpoint_ids": filter.EndpointIDs, + "project_id": projectID, + "source_id": filter.SourceID, + "start_date": startDate, + "end_date": endDate, + } + + query := countEvents + if len(filter.EndpointIDs) > 0 { + query += ` AND e.id IN (:endpoint_ids) ` + } + + if !util.IsStringEmpty(filter.SourceID) { + query += ` AND ev.source_id = :source_id ` + } + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return 0, err + } + + query, args, err = sqlx.In(query, args...) 
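+ // sqlx.Named expanded the :named parameters above and sqlx.In unrolled the
+ // slice arguments (e.g. endpoint_ids) into individual placeholders; Rebind
+ // below converts those placeholders to the driver's native style.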
+ if err != nil { + return 0, err + } + + query = e.db.Rebind(query) + err = e.db.QueryRowxContext(ctx, query, args...).Scan(&count) + if err != nil { + return count, err + } + + return count, nil +} + +func (e *eventRepo) LoadEventsPaged(ctx context.Context, projectID string, filter *datastore.Filter) ([]datastore.Event, datastore.PaginationData, error) { + var query, countQuery, filterQuery string + var err error + var args, qargs []interface{} + + startDate, endDate := getCreatedDateFilter(filter.SearchParams.CreatedAtStart, filter.SearchParams.CreatedAtEnd) + if !util.IsStringEmpty(filter.EndpointID) { + filter.EndpointIDs = append(filter.EndpointIDs, filter.EndpointID) + } + + arg := map[string]interface{}{ + "endpoint_ids": filter.EndpointIDs, + "project_id": projectID, + "source_ids": filter.SourceIDs, + "limit": filter.Pageable.Limit(), + "start_date": startDate, + "end_date": endDate, + "query": filter.Query, + "cursor": filter.Pageable.Cursor(), + "idempotency_key": filter.IdempotencyKey, + } + + base := baseEventsPaged + var baseQueryPagination string + if filter.Pageable.Direction == datastore.Next { + baseQueryPagination = getFwdEventPageQuery(filter.Pageable.SortOrder()) + } else { + baseQueryPagination = getBackwardEventPageQuery(filter.Pageable.SortOrder()) + } + + filterQuery = baseEventFilter + + if len(filter.SourceIDs) > 0 { + filterQuery += sourceFilter + } + + if len(filter.EndpointIDs) > 0 { + filterQuery += endpointFilter + } + + if !util.IsStringEmpty(filter.Query) { + filterQuery += searchFilter + base = baseEventsSearch + } + + preOrder := filter.Pageable.SortOrder() + if filter.Pageable.Direction == datastore.Prev { + preOrder = reverseOrder(preOrder) + } + + query = fmt.Sprintf(baseQueryPagination, base, filterQuery, preOrder, filter.Pageable.SortOrder()) + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = e.db.Rebind(query) + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + events := make([]datastore.Event, 0) + for rows.Next() { + var data datastore.Event + + err = rows.StructScan(&data) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + events = append(events, data) + } + + var count datastore.PrevRowCount + if len(events) > 0 { + first := events[0] + qarg := arg + qarg["cursor"] = first.UID + + baseCountEvents := baseCountPrevEvents + if !util.IsStringEmpty(filter.Query) { + baseCountEvents = baseCountPrevEventSearch + } + + tmp := getCountDeliveriesPrevRowQuery(filter.Pageable.SortOrder()) + tmp = fmt.Sprintf(tmp, filter.Pageable.SortOrder()) + + cq := baseCountEvents + filterQuery + tmp + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + countQuery, qargs, err = sqlx.In(countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = e.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := e.db.QueryxContext(ctx, countQuery, qargs...) 
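+ // The scanned PrevRowCount is handed to PaginationData.Build further down,
+ // so the returned pagination metadata can reflect whether any rows precede
+ // the current page.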
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(events)) + for i := range events { + ids[i] = events[i].UID + } + + if len(events) > filter.Pageable.PerPage { + events = events[:len(events)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(filter.Pageable, ids) + + return events, *pagination, nil +} + +func (e *eventRepo) DeleteProjectEvents(ctx context.Context, projectID string, filter *datastore.EventFilter, hardDelete bool) error { + query := softDeleteProjectEvents + startDate, endDate := getCreatedDateFilter(filter.CreatedAtStart, filter.CreatedAtEnd) + + if hardDelete { + query = hardDeleteProjectEvents + } + + _, err := e.db.ExecContext(ctx, query, projectID, startDate, endDate) + if err != nil { + return err + } + + return nil +} + +func (e *eventRepo) DeleteProjectTokenizedEvents(ctx context.Context, projectID string, filter *datastore.EventFilter) error { + startDate, endDate := getCreatedDateFilter(filter.CreatedAtStart, filter.CreatedAtEnd) + + query := hardDeleteTokenizedEvents + " AND created_at >= $2 AND created_at <= $3" + + _, err := e.db.ExecContext(ctx, query, projectID, startDate, endDate) + if err != nil { + return err + } + + return nil +} + +func (e *eventRepo) CopyRows(_ context.Context, _ string, _ int) error { + return nil +} + +func (e *eventRepo) ExportRecords(ctx context.Context, projectID string, createdAt time.Time, w io.Writer) (int64, error) { + return exportRecords(ctx, e.db, "events", projectID, createdAt, w) +} + +func getCreatedDateFilter(startDate, endDate int64) (time.Time, time.Time) { + return time.Unix(startDate, 0), time.Unix(endDate, 0) +} + +type EventEndpoint struct { + EventID string `db:"event_id"` + EndpointID string `db:"endpoint_id"` +} + +func getFwdEventPageQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(baseEventsPagedForward, "<=", ">=", 1) + } + + return baseEventsPagedForward +} + +func getBackwardEventPageQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(baseEventsPagedBackward, ">=", "<=", 1) + } + + return baseEventsPagedBackward +} + +func getCountDeliveriesPrevRowQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(countPrevEvents, ">", "<", 1) + } + + return countPrevEvents +} diff --git a/database/sqlite3/event_delivery.go b/database/sqlite3/event_delivery.go new file mode 100644 index 0000000000..b40d7c38a5 --- /dev/null +++ b/database/sqlite3/event_delivery.go @@ -0,0 +1,1031 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/frain-dev/convoy/cache" + + "github.com/lib/pq" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/database/hooks" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/pkg/httpheader" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" + "gopkg.in/guregu/null.v4" +) + +type eventDeliveryRepo struct { + db *sqlx.DB + hook *hooks.Hook + cache cache.Cache +} + +var ( + ErrEventDeliveryNotCreated = errors.New("event delivery could not be created") + ErrEventDeliveryStatusNotUpdated = errors.New("event delivery status could not be updated") + ErrEventDeliveryAttemptsNotUpdated = errors.New("event delivery 
attempts could not be updated") + ErrEventDeliveriesNotDeleted = errors.New("event deliveries could not be deleted") +) + +const ( + createEventDelivery = ` + INSERT INTO event_deliveries (id,project_id,event_id,endpoint_id,device_id,subscription_id,headers,status,metadata,cli_metadata,description,url_query_params,idempotency_key,event_type,acknowledged_at) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15); + ` + createEventDeliveries = ` + INSERT INTO event_deliveries (id,project_id,event_id,endpoint_id,device_id,subscription_id,headers,status,metadata,cli_metadata,description,url_query_params,idempotency_key,event_type,acknowledged_at) + VALUES (:id, :project_id, :event_id, :endpoint_id, :device_id, :subscription_id, :headers, :status, :metadata, :cli_metadata, :description, :url_query_params, :idempotency_key, :event_type, :acknowledged_at); + ` + + baseFetchEventDelivery = ` + SELECT + ed.id,ed.project_id,ed.event_id,ed.subscription_id, + ed.headers,ed.attempts,ed.status,ed.metadata,ed.cli_metadata, + COALESCE(ed.url_query_params, '') AS url_query_params, + COALESCE(ed.idempotency_key, '') AS idempotency_key, + ed.description,ed.created_at,ed.updated_at,ed.acknowledged_at, + COALESCE(ed.event_type,'') AS "event_type", + COALESCE(ed.device_id,'') AS "device_id", + COALESCE(ed.endpoint_id,'') AS "endpoint_id", + COALESCE(ep.id, '') AS "endpoint_metadata.id", + COALESCE(ep.name, '') AS "endpoint_metadata.name", + COALESCE(ep.project_id, '') AS "endpoint_metadata.project_id", + COALESCE(ep.support_email, '') AS "endpoint_metadata.support_email", + COALESCE(ep.url, '') AS "endpoint_metadata.url", + COALESCE(ep.owner_id, '') AS "endpoint_metadata.owner_id", + + ev.id AS "event_metadata.id", + ev.event_type AS "event_metadata.event_type", + COALESCE(ed.latency_seconds, 0) AS latency_seconds, + + COALESCE(d.id,'') AS "device_metadata.id", + COALESCE(d.status,'') AS "device_metadata.status", + COALESCE(d.host_name,'') AS "device_metadata.host_name", + + COALESCE(s.id, '') AS "source_metadata.id", + COALESCE(s.name, '') AS "source_metadata.name", + COALESCE(s.idempotency_keys, '{}') AS "source_metadata.idempotency_keys" + FROM event_deliveries ed + LEFT JOIN endpoints ep ON ed.endpoint_id = ep.id + LEFT JOIN events ev ON ed.event_id = ev.id + LEFT JOIN devices d ON ed.device_id = d.id + LEFT JOIN sources s ON s.id = ev.source_id + ` + + baseEventDeliveryPagedForward = ` + WITH event_deliveries AS ( + %s + %s + AND ed.id <= :cursor + GROUP BY ed.id, ep.id, ev.id, d.id, s.id + ORDER BY ed.id %s + LIMIT :limit + ) + + SELECT * FROM event_deliveries ORDER BY id %s + ` + + baseEventDeliveryPagedBackward = ` + WITH event_deliveries AS ( + %s + %s + AND ed.id >= :cursor + GROUP BY ed.id, ep.id, ev.id, d.id, s.id + ORDER BY ed.id %s + LIMIT :limit + ) + + SELECT * FROM event_deliveries ORDER BY id %s + ` + + fetchEventDeliveryByID = baseFetchEventDelivery + ` AND ed.id = $1 AND ed.project_id = $2` + + fetchEventDeliverySlim = ` + SELECT + id,project_id,event_id,subscription_id, + headers,attempts,status,metadata,cli_metadata, + COALESCE(url_query_params, '') AS url_query_params, + COALESCE(idempotency_key, '') AS idempotency_key,created_at,updated_at, + COALESCE(event_type,'') AS "event_type", + COALESCE(device_id,'') AS "device_id", + COALESCE(endpoint_id,'') AS "endpoint_id", + acknowledged_at + FROM event_deliveries + e project_id = $1 AND id = $2 + ` + + baseEventDeliveryFilter = ` AND (ed.project_id = :project_id OR :project_id = '') + AND (ed.event_id = :event_id OR :event_id = '') + 
AND (ed.event_type = :event_type OR :event_type = '') + AND ed.created_at >= :start_date + AND ed.created_at <= :end_date` + + countPrevEventDeliveries = ` + SELECT COUNT(DISTINCT(ed.id)) + FROM event_deliveries ed + LEFT JOIN events ev ON ed.event_id = ev.id + WHERE %s + AND ed.id > :cursor + GROUP BY ed.id, ev.id + ORDER BY ed.id %s + ` + + loadEventDeliveriesIntervals = ` + SELECT + DATE_TRUNC('%s', created_at) AS "data.group_only", + TO_CHAR(DATE_TRUNC('%s', created_at), '%s') AS "data.total_time", + EXTRACT('%s' FROM created_at) AS "data.index", + COUNT(*) AS count + FROM + event_deliveries + WHERE + project_id = $1 AND + created_at >= $2 AND + created_at <= $3 + GROUP BY + "data.group_only", "data.index"; + ` + + fetchEventDeliveries = ` + SELECT + id,project_id,event_id,subscription_id, + headers,attempts,status,metadata,cli_metadata, + COALESCE(ed.idempotency_key, '') AS idempotency_key, + COALESCE(url_query_params, '') AS url_query_params, + description,created_at,updated_at, + COALESCE(event_type,'') AS "event_type", + COALESCE(device_id,'') AS "device_id", + COALESCE(endpoint_id,'') AS "endpoint_id", + acknowledged_at + FROM event_deliveries ed + ` + + fetchDiscardedEventDeliveries = ` + SELECT + id,project_id,event_id,subscription_id, + headers,attempts,status,metadata,cli_metadata, + COALESCE(idempotency_key, '') AS idempotency_key, + COALESCE(url_query_params, '') AS url_query_params, + description,created_at,updated_at, + COALESCE(event_type,'') AS "event_type", + COALESCE(device_id,'') AS "device_id", + acknowledged_at + FROM event_deliveries + WHERE status=$1 AND project_id = $2 AND device_id = $3 + AND created_at >= $4 AND created_at <= $5; + ` + + fetchStuckEventDeliveries = ` + SELECT id, project_id + FROM event_deliveries + WHERE status = $1 + AND created_at <= now() - make_interval(secs := 30) + FOR UPDATE SKIP LOCKED + LIMIT 1000; + ` + + countEventDeliveriesByStatus = ` + SELECT COUNT(id) FROM event_deliveries WHERE status = $1 AND (project_id = $2 OR $2 = '') AND created_at >= $3 AND created_at <= $4 + ` + + countEventDeliveries = ` + SELECT COUNT(id) FROM event_deliveries WHERE (project_id = ? OR ? = '') AND (event_id = ? OR ? = '') AND created_at >= ? AND created_at <= ? + ` + + updateEventDeliveriesStatus = ` + UPDATE event_deliveries SET status = ?, description = ?, updated_at = NOW() WHERE (project_id = ? OR ? 
= '')AND id IN (?); + ` + + updateEventDeliveryMetadata = ` + UPDATE event_deliveries SET status = $1, metadata = $2, latency_seconds = $3, updated_at = NOW() WHERE id = $4 AND project_id = $5; + ` + + hardDeleteProjectEventDeliveries = ` + DELETE FROM event_deliveries WHERE project_id = $1 AND created_at >= $2 AND created_at <= $3; + ` +) + +func NewEventDeliveryRepo(db database.Database) datastore.EventDeliveryRepository { + return &eventDeliveryRepo{db: db.GetDB(), hook: db.GetHook()} +} + +func (e *eventDeliveryRepo) CreateEventDelivery(ctx context.Context, delivery *datastore.EventDelivery) error { + var endpointID *string + var deviceID *string + + if !util.IsStringEmpty(delivery.EndpointID) { + endpointID = &delivery.EndpointID + } + + if !util.IsStringEmpty(delivery.DeviceID) { + deviceID = &delivery.DeviceID + } + + result, err := e.db.ExecContext( + ctx, createEventDelivery, delivery.UID, delivery.ProjectID, + delivery.EventID, endpointID, deviceID, + delivery.SubscriptionID, delivery.Headers, delivery.Status, + delivery.Metadata, delivery.CLIMetadata, delivery.Description, delivery.URLQueryParams, delivery.IdempotencyKey, delivery.EventType, + delivery.AcknowledgedAt, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventDeliveryNotCreated + } + + return nil +} + +// CreateEventDeliveries creates event deliveries in bulk +func (e *eventDeliveryRepo) CreateEventDeliveries(ctx context.Context, deliveries []*datastore.EventDelivery) error { + values := make([]map[string]interface{}, 0, len(deliveries)) + + for _, delivery := range deliveries { + var endpointID *string + var deviceID *string + + if !util.IsStringEmpty(delivery.EndpointID) { + endpointID = &delivery.EndpointID + } + + if !util.IsStringEmpty(delivery.DeviceID) { + deviceID = &delivery.DeviceID + } + + values = append(values, map[string]interface{}{ + "id": delivery.UID, + "project_id": delivery.ProjectID, + "event_id": delivery.EventID, + "endpoint_id": endpointID, + "device_id": deviceID, + "subscription_id": delivery.SubscriptionID, + "headers": delivery.Headers, + "status": delivery.Status, + "metadata": delivery.Metadata, + "cli_metadata": delivery.CLIMetadata, + "description": delivery.Description, + "url_query_params": delivery.URLQueryParams, + "idempotency_key": delivery.IdempotencyKey, + "event_type": delivery.EventType, + "acknowledged_at": delivery.AcknowledgedAt, + }) + } + + tx, isWrapped, err := GetTx(ctx, e.db) + if err != nil { + return err + } + + if !isWrapped { + defer rollbackTx(tx) + } + + var j int + for i := 0; i < len(values); i += PartitionSize { + j += PartitionSize + if j > len(values) { + j = len(values) + } + + var vs []interface{} + for _, v := range values[i:j] { + vs = append(vs, v) + } + + result, err := tx.NamedExecContext(ctx, createEventDeliveries, vs) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if len(vs) > 0 && rowsAffected < 1 { + return ErrEventDeliveryNotCreated + } + } + + if isWrapped { + return nil + } + + return tx.Commit() +} + +func (e *eventDeliveryRepo) FindEventDeliveryByID(ctx context.Context, projectID string, id string) (*datastore.EventDelivery, error) { + eventDelivery := &datastore.EventDelivery{} + err := e.db.QueryRowxContext(ctx, fetchEventDeliveryByID, id, projectID).StructScan(eventDelivery) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, 
datastore.ErrEventDeliveryNotFound + } + return nil, err + } + + return eventDelivery, nil +} + +func (e *eventDeliveryRepo) FindEventDeliveryByIDSlim(ctx context.Context, projectID string, id string) (*datastore.EventDelivery, error) { + eventDelivery := &datastore.EventDelivery{} + err := e.db.QueryRowxContext(ctx, fetchEventDeliverySlim, projectID, id).StructScan(eventDelivery) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrEventDeliveryNotFound + } + return nil, err + } + + return eventDelivery, nil +} + +func (e *eventDeliveryRepo) FindEventDeliveriesByIDs(ctx context.Context, projectID string, ids []string) ([]datastore.EventDelivery, error) { + eventDeliveries := make([]datastore.EventDelivery, 0) + query := fetchEventDeliveries + " WHERE id IN (?) AND project_id = ? AND deleted_at IS NULL" + + query, args, err := sqlx.In(query, ids, projectID) + if err != nil { + return nil, err + } + + query = e.db.Rebind(query) + + rows, err := e.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var ed datastore.EventDelivery + err = rows.StructScan(&ed) + if err != nil { + return nil, err + } + + eventDeliveries = append(eventDeliveries, ed) + } + + return eventDeliveries, nil +} + +func (e *eventDeliveryRepo) FindEventDeliveriesByEventID(ctx context.Context, projectID string, eventID string) ([]datastore.EventDelivery, error) { + eventDeliveries := make([]datastore.EventDelivery, 0) + + q := fetchEventDeliveries + " WHERE event_id = $1 AND project_id = $2 AND deleted_at IS NULL" + rows, err := e.db.QueryxContext(ctx, q, eventID, projectID) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var ed datastore.EventDelivery + err = rows.StructScan(&ed) + if err != nil { + return nil, err + } + + eventDeliveries = append(eventDeliveries, ed) + } + + return eventDeliveries, nil +} + +func (e *eventDeliveryRepo) CountDeliveriesByStatus(ctx context.Context, projectID string, status datastore.EventDeliveryStatus, params datastore.SearchParams) (int64, error) { + count := struct { + Count int64 + }{} + + start := time.Unix(params.CreatedAtStart, 0) + end := time.Unix(params.CreatedAtEnd, 0) + err := e.db.QueryRowxContext(ctx, countEventDeliveriesByStatus, status, projectID, start, end).StructScan(&count) + if err != nil { + return 0, err + } + + return count.Count, nil +} + +func (e *eventDeliveryRepo) FindStuckEventDeliveriesByStatus(ctx context.Context, status datastore.EventDeliveryStatus) ([]datastore.EventDelivery, error) { + eventDeliveries := make([]datastore.EventDelivery, 0) + + rows, err := e.db.QueryxContext(ctx, fetchStuckEventDeliveries, status) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var ed datastore.EventDelivery + err = rows.StructScan(&ed) + if err != nil { + return nil, err + } + + eventDeliveries = append(eventDeliveries, ed) + } + + return eventDeliveries, nil +} + +func (e *eventDeliveryRepo) UpdateStatusOfEventDelivery(ctx context.Context, projectID string, delivery datastore.EventDelivery, status datastore.EventDeliveryStatus) error { + query, args, err := sqlx.In(updateEventDeliveriesStatus, status, delivery.Description, projectID, projectID, []string{delivery.UID}) + if err != nil { + return err + } + + query = e.db.Rebind(query) + + result, err := e.db.ExecContext(ctx, query, args...) 
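+ // The update is keyed on both project_id and the delivery id; RowsAffected
+ // below distinguishes "no matching delivery" from a successful update.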
+ if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventDeliveryStatusNotUpdated + } + + return nil +} + +func (e *eventDeliveryRepo) UpdateStatusOfEventDeliveries(ctx context.Context, projectID string, ids []string, status datastore.EventDeliveryStatus) error { + query, args, err := sqlx.In(updateEventDeliveriesStatus, status, "", projectID, projectID, ids) + if err != nil { + return err + } + + query = e.db.Rebind(query) + + result, err := e.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventDeliveryStatusNotUpdated + } + + return nil +} + +func (e *eventDeliveryRepo) FindDiscardedEventDeliveries(ctx context.Context, projectID, deviceId string, searchParams datastore.SearchParams) ([]datastore.EventDelivery, error) { + eventDeliveries := make([]datastore.EventDelivery, 0) + + start := time.Unix(searchParams.CreatedAtStart, 0) + end := time.Unix(searchParams.CreatedAtEnd, 0) + + rows, err := e.db.QueryxContext(ctx, fetchDiscardedEventDeliveries, datastore.DiscardedEventStatus, projectID, deviceId, start, end) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var ed datastore.EventDelivery + err = rows.StructScan(&ed) + if err != nil { + return nil, err + } + + eventDeliveries = append(eventDeliveries, ed) + } + + return eventDeliveries, nil +} + +func (e *eventDeliveryRepo) UpdateEventDeliveryMetadata(ctx context.Context, projectID string, delivery *datastore.EventDelivery) error { + result, err := e.db.ExecContext(ctx, updateEventDeliveryMetadata, delivery.Status, delivery.Metadata, delivery.LatencySeconds, delivery.UID, projectID) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventDeliveryAttemptsNotUpdated + } + + go e.hook.Fire(datastore.EventDeliveryUpdated, delivery, nil) + return nil +} + +func (e *eventDeliveryRepo) CountEventDeliveries(ctx context.Context, projectID string, endpointIDs []string, eventID string, status []datastore.EventDeliveryStatus, params datastore.SearchParams) (int64, error) { + count := struct { + Count int64 + }{} + + start := time.Unix(params.CreatedAtStart, 0) + end := time.Unix(params.CreatedAtEnd, 0) + + args := []interface{}{ + projectID, projectID, + eventID, eventID, + start, end, + } + + q := countEventDeliveries + + if len(endpointIDs) > 0 { + q += ` AND endpoint_id IN (?)` + args = append(args, endpointIDs) + } + + if len(status) > 0 { + q += ` AND status IN (?)` + args = append(args, status) + } + + query, args, err := sqlx.In(q, args...) 
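+ // sqlx.In expands the optional endpoint_id and status slices appended above
+ // into individual placeholders before the query is rebound and executed.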
+ if err != nil { + return 0, err + } + + query = e.db.Rebind(query) + + err = e.db.QueryRowxContext(ctx, query, args...).StructScan(&count) + if err != nil { + return 0, err + } + + return count.Count, nil +} + +func (e *eventDeliveryRepo) DeleteProjectEventDeliveries(ctx context.Context, projectID string, filter *datastore.EventDeliveryFilter, _ bool) error { + var result sql.Result + var err error + + start := time.Unix(filter.CreatedAtStart, 0) + end := time.Unix(filter.CreatedAtEnd, 0) + + result, err = e.db.ExecContext(ctx, hardDeleteProjectEventDeliveries, projectID, start, end) + + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventDeliveriesNotDeleted + } + + return nil +} + +func (e *eventDeliveryRepo) LoadEventDeliveriesPaged(ctx context.Context, projectID string, endpointIDs []string, eventID, subscriptionID string, status []datastore.EventDeliveryStatus, params datastore.SearchParams, pageable datastore.Pageable, idempotencyKey, eventType string) ([]datastore.EventDelivery, datastore.PaginationData, error) { + eventDeliveriesP := make([]EventDeliveryPaginated, 0) + + start := time.Unix(params.CreatedAtStart, 0) + end := time.Unix(params.CreatedAtEnd, 0) + + arg := map[string]interface{}{ + "endpoint_ids": endpointIDs, + "project_id": projectID, + "limit": pageable.Limit(), + "subscription_id": subscriptionID, + "start_date": start, + "event_id": eventID, + "event_type": eventType, + "end_date": end, + "status": status, + "cursor": pageable.Cursor(), + "idempotency_key": idempotencyKey, + } + + var query, filterQuery string + if pageable.Direction == datastore.Next { + query = getFwdDeliveryPageQuery(pageable.SortOrder()) + } else { + query = getBackwardDeliveryPageQuery(pageable.SortOrder()) + } + + filterQuery = baseEventDeliveryFilter + if len(endpointIDs) > 0 { + filterQuery += ` AND ed.endpoint_id IN (:endpoint_ids)` + } + + if len(status) > 0 { + filterQuery += ` AND ed.status IN (:status)` + } + + if !util.IsStringEmpty(subscriptionID) { + filterQuery += ` AND ed.subscription_id = :subscription_id` + } + + preOrder := pageable.SortOrder() + if pageable.Direction == datastore.Prev { + preOrder = reverseOrder(preOrder) + } + + query = fmt.Sprintf(query, baseFetchEventDelivery, filterQuery, preOrder, pageable.SortOrder()) + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = e.db.Rebind(query) + + rows, err := e.db.QueryxContext(ctx, query, args...) 
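+ // Rows are scanned into the EventDeliveryPaginated wrapper first; its
+ // nullable join columns are flattened into datastore.EventDelivery below.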
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + + for rows.Next() { + var ed EventDeliveryPaginated + err = rows.StructScan(&ed) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + eventDeliveriesP = append(eventDeliveriesP, ed) + } + + eventDeliveries := make([]datastore.EventDelivery, 0, len(eventDeliveriesP)) + + for i := range eventDeliveriesP { + ev := &eventDeliveriesP[i] + var cli *datastore.CLIMetadata + if ev.CLIMetadata != nil { + cli = &datastore.CLIMetadata{ + EventType: ev.CLIMetadata.EventType.ValueOrZero(), + SourceID: ev.CLIMetadata.SourceID.ValueOrZero(), + } + } + + eventDeliveries = append(eventDeliveries, datastore.EventDelivery{ + UID: ev.UID, + ProjectID: ev.ProjectID, + EventID: ev.EventID, + EndpointID: ev.EndpointID, + DeviceID: ev.DeviceID, + SubscriptionID: ev.SubscriptionID, + IdempotencyKey: ev.IdempotencyKey, + Headers: ev.Headers, + URLQueryParams: ev.URLQueryParams, + Latency: ev.Latency, + LatencySeconds: ev.LatencySeconds, + EventType: ev.EventType, + Endpoint: &datastore.Endpoint{ + UID: ev.Endpoint.UID.ValueOrZero(), + ProjectID: ev.Endpoint.ProjectID.ValueOrZero(), + Url: ev.Endpoint.URL.ValueOrZero(), + Name: ev.Endpoint.Name.ValueOrZero(), + SupportEmail: ev.Endpoint.SupportEmail.ValueOrZero(), + OwnerID: ev.Endpoint.OwnerID.ValueOrZero(), + }, + Source: &datastore.Source{ + UID: ev.Source.UID.ValueOrZero(), + Name: ev.Source.Name.ValueOrZero(), + IdempotencyKeys: ev.Source.IdempotencyKeys, + }, + Device: &datastore.Device{ + UID: ev.Device.UID.ValueOrZero(), + HostName: ev.Device.HostName.ValueOrZero(), + Status: datastore.DeviceStatus(ev.Device.Status.ValueOrZero()), + }, + Event: &datastore.Event{EventType: datastore.EventType(ev.Event.EventType.ValueOrZero())}, + Status: ev.Status, + Metadata: ev.Metadata, + CLIMetadata: cli, + Description: ev.Description, + AcknowledgedAt: ev.AcknowledgedAt, + CreatedAt: ev.CreatedAt, + UpdatedAt: ev.UpdatedAt, + DeletedAt: ev.DeletedAt, + }) + } + + var count datastore.PrevRowCount + if len(eventDeliveries) > 0 { + var countQuery string + var qargs []interface{} + first := eventDeliveries[0] + qarg := arg + qarg["cursor"] = first.UID + + tmp := getCountEventPrevRowQuery(pageable.SortOrder()) + + cq := fmt.Sprintf(tmp, filterQuery, pageable.SortOrder()) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery, qargs, err = sqlx.In(countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = e.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := e.db.QueryxContext(ctx, countQuery, qargs...) 
+ if err != nil {
+ return nil, datastore.PaginationData{}, err
+ }
+ defer closeWithError(rows)
+
+ if rows.Next() {
+ err = rows.StructScan(&count)
+ if err != nil {
+ return nil, datastore.PaginationData{}, err
+ }
+ }
+ }
+
+ ids := make([]string, len(eventDeliveries))
+ for i := range eventDeliveries {
+ ids[i] = eventDeliveries[i].UID
+ }
+
+ if len(eventDeliveries) > pageable.PerPage {
+ eventDeliveries = eventDeliveries[:len(eventDeliveries)-1]
+ }
+
+ pagination := &datastore.PaginationData{PrevRowCount: count}
+ pagination = pagination.Build(pageable, ids)
+
+ return eventDeliveries, *pagination, nil
+}
+
+const (
+ dailyIntervalFormat = "yyyy-mm-dd" // 1 day
+ weeklyIntervalFormat = dailyIntervalFormat // 1 week
+ monthlyIntervalFormat = "yyyy-mm" // 1 month
+ yearlyIntervalFormat = "yyyy" // 1 year
+)
+
+func (e *eventDeliveryRepo) LoadEventDeliveriesIntervals(ctx context.Context, projectID string, params datastore.SearchParams, period datastore.Period) ([]datastore.EventInterval, error) {
+ intervals := make([]datastore.EventInterval, 0)
+
+ start := time.Unix(params.CreatedAtStart, 0)
+ end := time.Unix(params.CreatedAtEnd, 0)
+
+ var timeComponent string
+ var format string
+ var extract string
+ switch period {
+ case datastore.Daily:
+ timeComponent = "day"
+ format = dailyIntervalFormat
+ extract = "doy"
+ case datastore.Weekly:
+ timeComponent = "week"
+ format = weeklyIntervalFormat
+ extract = timeComponent
+ case datastore.Monthly:
+ timeComponent = "month"
+ format = monthlyIntervalFormat
+ extract = timeComponent
+ case datastore.Yearly:
+ timeComponent = "year"
+ format = yearlyIntervalFormat
+ extract = timeComponent
+ default:
+ return nil, errors.New("specified data cannot be generated for period")
+ }
+
+ q := fmt.Sprintf(loadEventDeliveriesIntervals, timeComponent, timeComponent, format, extract)
+ rows, err := e.db.QueryxContext(ctx, q, projectID, start, end)
+ if err != nil {
+ return nil, err
+ }
+ defer closeWithError(rows)
+
+ for rows.Next() {
+ var interval datastore.EventInterval
+ err = rows.StructScan(&interval)
+ if err != nil {
+ return nil, err
+ }
+
+ intervals = append(intervals, interval)
+ }
+
+ if len(intervals) < minLen {
+ var d time.Duration
+ switch period {
+ case datastore.Daily:
+ d = time.Hour * 24
+ case datastore.Weekly:
+ d = time.Hour * 24 * 7
+ case datastore.Monthly:
+ d = time.Hour * 24 * 30
+ case datastore.Yearly:
+ d = time.Hour * 24 * 365
+ }
+ intervals, err = padIntervals(intervals, d, period)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return intervals, nil
+}
+
+func (e *eventDeliveryRepo) ExportRecords(ctx context.Context, projectID string, createdAt time.Time, w io.Writer) (int64, error) {
+ return exportRecords(ctx, e.db, "event_deliveries", projectID, createdAt, w)
+}
+
+const minLen = 30
+
+func padIntervals(intervals []datastore.EventInterval, duration time.Duration, period datastore.Period) ([]datastore.EventInterval, error) {
+ var err error
+
+ var format string
+
+ switch period {
+ case datastore.Daily:
+ format = "2006-01-02"
+ case datastore.Weekly:
+ format = "2006-01-02"
+ case datastore.Monthly:
+ format = "2006-01"
+ case datastore.Yearly:
+ format = "2006"
+ default:
+ return nil, errors.New("specified data cannot be generated for period")
+ }
+
+ start := time.Now()
+ if len(intervals) > 0 {
+ start, err = time.Parse(format, intervals[0].Data.Time)
+ if err != nil {
+ return nil, err
+ }
+ start = start.Add(-duration) // take it back once here, since we're getting it from the original slice
+ }
+
+ numPadding := minLen - 
(len(intervals)) + paddedIntervals := make([]datastore.EventInterval, numPadding, numPadding+len(intervals)) + for i := numPadding; i > 0; i-- { + paddedIntervals[i-1] = datastore.EventInterval{ + Data: datastore.EventIntervalData{ + Interval: 0, + Time: start.Format(format), + }, + Count: 0, + } + start = start.Add(-duration) + } + + paddedIntervals = append(paddedIntervals, intervals...) + + return paddedIntervals, nil +} + +type EndpointMetadata struct { + UID null.String `db:"id"` + Name null.String `db:"name"` + URL null.String `db:"url"` + ProjectID null.String `db:"project_id"` + SupportEmail null.String `db:"support_email"` + OwnerID null.String `db:"owner_id"` +} + +type EventMetadata struct { + UID null.String `db:"id"` + EventType null.String `db:"event_type"` +} + +type SourceMetadata struct { + UID null.String `db:"id"` + Name null.String `db:"name"` + IdempotencyKeys pq.StringArray `db:"idempotency_keys"` +} + +type DeviceMetadata struct { + UID null.String `db:"id"` + Status null.String `json:"status" db:"status"` + HostName null.String `json:"host_name" db:"host_name"` +} + +type CLIMetadata struct { + EventType null.String `json:"event_type" db:"event_type"` + SourceID null.String `json:"source_id" db:"source_id"` +} + +type EventDeliveryPaginated struct { + UID string `json:"uid" db:"id"` + ProjectID string `json:"project_id,omitempty" db:"project_id"` + EventID string `json:"event_id,omitempty" db:"event_id"` + EndpointID string `json:"endpoint_id,omitempty" db:"endpoint_id"` + DeviceID string `json:"device_id" db:"device_id"` + SubscriptionID string `json:"subscription_id,omitempty" db:"subscription_id"` + Headers httpheader.HTTPHeader `json:"headers" db:"headers"` + URLQueryParams string `json:"url_query_params" db:"url_query_params"` + IdempotencyKey string `json:"idempotency_key" db:"idempotency_key"` + // Deprecated: Latency is deprecated. 
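+ // (LatencySeconds below is presumably the numeric replacement.)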
+ Latency string `json:"latency" db:"latency"` + LatencySeconds float64 `json:"latency_seconds" db:"latency_seconds"` + EventType datastore.EventType `json:"event_type,omitempty" db:"event_type"` + + Endpoint *EndpointMetadata `json:"endpoint_metadata,omitempty" db:"endpoint_metadata"` + Event *EventMetadata `json:"event_metadata,omitempty" db:"event_metadata"` + Source *SourceMetadata `json:"source_metadata,omitempty" db:"source_metadata"` + Device *DeviceMetadata `json:"device_metadata,omitempty" db:"device_metadata"` + + DeliveryAttempts datastore.DeliveryAttempts `json:"-" db:"attempts"` + Status datastore.EventDeliveryStatus `json:"status" db:"status"` + Metadata *datastore.Metadata `json:"metadata" db:"metadata"` + CLIMetadata *CLIMetadata `json:"cli_metadata" db:"cli_metadata"` + Description string `json:"description,omitempty" db:"description"` + AcknowledgedAt null.Time `json:"acknowledged_at,omitempty" db:"acknowledged_at,omitempty" swaggertype:"string"` + CreatedAt time.Time `json:"created_at,omitempty" db:"created_at,omitempty" swaggertype:"string"` + UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty" swaggertype:"string"` + DeletedAt null.Time `json:"deleted_at,omitempty" db:"deleted_at" swaggertype:"string"` +} + +func (m *CLIMetadata) Scan(value interface{}) error { + if value == nil { + return nil + } + + b, ok := value.([]byte) + if !ok { + return fmt.Errorf("unsupported value type %T", value) + } + + if string(b) == "null" { + return nil + } + + if err := json.Unmarshal(b, &m); err != nil { + return err + } + + return nil +} + +func getFwdDeliveryPageQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(baseEventDeliveryPagedForward, "<=", ">=", 1) + } + + return baseEventDeliveryPagedForward +} + +func getBackwardDeliveryPageQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(baseEventDeliveryPagedBackward, ">=", "<=", 1) + } + + return baseEventDeliveryPagedBackward +} + +func getCountEventPrevRowQuery(sortOrder string) string { + if sortOrder == "ASC" { + return strings.Replace(countPrevEventDeliveries, ">", "<", 1) + } + + return countPrevEventDeliveries +} + +func reverseOrder(sortOrder string) string { + switch sortOrder { + case "ASC": + return "DESC" + default: + return "ASC" + } +} diff --git a/database/sqlite3/event_delivery_test.go b/database/sqlite3/event_delivery_test.go new file mode 100644 index 0000000000..8780454a20 --- /dev/null +++ b/database/sqlite3/event_delivery_test.go @@ -0,0 +1,518 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "gopkg.in/guregu/null.v4" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/frain-dev/convoy/pkg/httpheader" + + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/datastore" +) + +func Test_eventDeliveryRepo_CreateEventDelivery(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + ed := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + + dbEventDelivery, err := edRepo.FindEventDeliveryByID(context.Background(), project.UID, ed.UID) + require.NoError(t, err) + + 
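// normalise DB-generated timestamps and joined metadata
+	// before the deep-equality comparison below
+	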
require.NotEmpty(t, dbEventDelivery.CreatedAt) + require.NotEmpty(t, dbEventDelivery.UpdatedAt) + + dbEventDelivery.CreatedAt, dbEventDelivery.UpdatedAt, dbEventDelivery.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + dbEventDelivery.Event, dbEventDelivery.Endpoint, dbEventDelivery.Source, dbEventDelivery.Device = nil, nil, nil, nil + + require.Equal(t, "", dbEventDelivery.Latency) + require.Equal(t, 0.0, dbEventDelivery.LatencySeconds) + + require.Equal(t, ed.Metadata.NextSendTime.UTC(), dbEventDelivery.Metadata.NextSendTime.UTC()) + ed.Metadata.NextSendTime = time.Time{} + dbEventDelivery.Metadata.NextSendTime = time.Time{} + + ed.CreatedAt, ed.UpdatedAt, ed.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + require.Equal(t, ed, dbEventDelivery) +} + +func generateEventDelivery(project *datastore.Project, endpoint *datastore.Endpoint, event *datastore.Event, device *datastore.Device, sub *datastore.Subscription) *datastore.EventDelivery { + e := &datastore.EventDelivery{ + UID: ulid.Make().String(), + ProjectID: project.UID, + EventID: event.UID, + EndpointID: endpoint.UID, + DeviceID: device.UID, + SubscriptionID: sub.UID, + EventType: event.EventType, + Headers: httpheader.HTTPHeader{"X-sig": []string{"3787 fmmfbf"}}, + URLQueryParams: "name=ref&category=food", + Status: datastore.SuccessEventStatus, + Metadata: &datastore.Metadata{ + Data: []byte(`{"name": "10x"}`), + Raw: `{"name": "10x"}`, + Strategy: datastore.ExponentialStrategyProvider, + NextSendTime: time.Now().Add(time.Hour), + NumTrials: 1, + IntervalSeconds: 10, + RetryLimit: 20, + }, + CLIMetadata: &datastore.CLIMetadata{}, + Description: "test", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + return e +} + +func Test_eventDeliveryRepo_FindEventDeliveriesByIDs(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + edRepo := NewEventDeliveryRepo(db) + edMap := map[string]*datastore.EventDelivery{} + ids := []string{} + for i := 0; i < 8; i++ { + ed := generateEventDelivery(project, endpoint, event, device, sub) + ed.Headers["uid"] = []string{ulid.Make().String()} + if i == 0 || i == 1 || i == 5 { + edMap[ed.UID] = ed + ids = append(ids, ed.UID) + } + + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + } + + dbEventDeliveries, err := edRepo.FindEventDeliveriesByIDs(context.Background(), project.UID, ids) + require.NoError(t, err) + require.Equal(t, 3, len(dbEventDeliveries)) + + for i := range dbEventDeliveries { + + dbEventDelivery := &dbEventDeliveries[i] + ed := edMap[dbEventDelivery.UID] + + require.NotEmpty(t, dbEventDelivery.CreatedAt) + require.NotEmpty(t, dbEventDelivery.UpdatedAt) + + dbEventDelivery.CreatedAt, dbEventDelivery.UpdatedAt, dbEventDelivery.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + dbEventDelivery.Event, dbEventDelivery.Endpoint, dbEventDelivery.Source = nil, nil, nil + + require.Equal(t, ed.Metadata.NextSendTime.UTC(), dbEventDelivery.Metadata.NextSendTime.UTC()) + ed.Metadata.NextSendTime = time.Time{} + dbEventDelivery.Metadata.NextSendTime = time.Time{} + + ed.CreatedAt, ed.UpdatedAt, ed.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + + require.Equal(t, ed.Headers, dbEventDelivery.Headers) + require.Equal(t, ed, dbEventDelivery) + } +} + +func 
Test_eventDeliveryRepo_FindEventDeliveriesByEventID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + sub := seedSubscription(t, db, project, source, endpoint, device) + + edRepo := NewEventDeliveryRepo(db) + edMap := map[string]*datastore.EventDelivery{} + + mainEvent := seedEvent(t, db, project) + for i := 0; i < 8; i++ { + + ed := generateEventDelivery(project, endpoint, seedEvent(t, db, project), device, sub) + if i == 1 || i == 4 || i == 5 { + ed.EventID = mainEvent.UID + edMap[ed.UID] = ed + } + + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + } + + dbEventDeliveries, err := edRepo.FindEventDeliveriesByEventID(context.Background(), project.UID, mainEvent.UID) + require.NoError(t, err) + require.Equal(t, 3, len(dbEventDeliveries)) + + for i := range dbEventDeliveries { + + dbEventDelivery := &dbEventDeliveries[i] + + ed, ok := edMap[dbEventDelivery.UID] + + require.True(t, ok) + + require.NotEmpty(t, dbEventDelivery.CreatedAt) + require.NotEmpty(t, dbEventDelivery.UpdatedAt) + + dbEventDelivery.CreatedAt, dbEventDelivery.UpdatedAt, dbEventDelivery.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + dbEventDelivery.Event, dbEventDelivery.Endpoint, dbEventDelivery.Source = nil, nil, nil + + require.Equal(t, ed.Metadata.NextSendTime.UTC(), dbEventDelivery.Metadata.NextSendTime.UTC()) + ed.Metadata.NextSendTime = time.Time{} + dbEventDelivery.Metadata.NextSendTime = time.Time{} + + ed.CreatedAt, ed.UpdatedAt, ed.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + require.Equal(t, ed, dbEventDelivery) + } +} + +func Test_eventDeliveryRepo_CountDeliveriesByStatus(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + edRepo := NewEventDeliveryRepo(db) + + status := datastore.FailureEventStatus + for i := 0; i < 8; i++ { + + ed := generateEventDelivery(project, endpoint, event, device, sub) + if i == 1 || i == 4 || i == 5 { + ed.Status = status + } + + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + } + + count, err := edRepo.CountDeliveriesByStatus(context.Background(), project.UID, status, datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }) + + require.NoError(t, err) + require.Equal(t, int64(3), count) +} + +func Test_eventDeliveryRepo_UpdateStatusOfEventDelivery(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + ed := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + + err = edRepo.UpdateStatusOfEventDelivery(context.Background(), project.UID, *ed, datastore.RetryEventStatus) + require.NoError(t, err) + + dbEventDelivery, err := edRepo.FindEventDeliveryByID(context.Background(), project.UID, ed.UID) + require.NoError(t, err) + + require.Equal(t, 
datastore.RetryEventStatus, dbEventDelivery.Status) +} + +func Test_eventDeliveryRepo_UpdateStatusOfEventDeliveries(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + ed1 := generateEventDelivery(project, endpoint, event, device, sub) + ed2 := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed1) + require.NoError(t, err) + + err = edRepo.CreateEventDelivery(context.Background(), ed2) + require.NoError(t, err) + + err = edRepo.UpdateStatusOfEventDeliveries(context.Background(), project.UID, []string{ed1.UID, ed2.UID}, datastore.RetryEventStatus) + require.NoError(t, err) + + dbEventDeliveries, err := edRepo.FindEventDeliveriesByIDs(context.Background(), project.UID, []string{ed1.UID, ed2.UID}) + require.NoError(t, err) + + for _, d := range dbEventDeliveries { + require.Equal(t, datastore.RetryEventStatus, d.Status) + } +} + +func Test_eventDeliveryRepo_FindDiscardedEventDeliveries(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + edRepo := NewEventDeliveryRepo(db) + + status := datastore.DiscardedEventStatus + for i := 0; i < 8; i++ { + + ed := generateEventDelivery(project, endpoint, event, device, sub) + if i == 1 || i == 4 || i == 5 { + ed.Status = status + } + + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + } + + dbEventDeliveries, err := edRepo.FindDiscardedEventDeliveries(context.Background(), project.UID, device.UID, datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }) + require.NoError(t, err) + + for _, d := range dbEventDeliveries { + require.Equal(t, datastore.DiscardedEventStatus, d.Status) + } +} + +func Test_eventDeliveryRepo_UpdateEventDeliveryWithAttempt(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + ed := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + + latency := "1h2m" + latencySeconds := 3720.0 + + ed.Latency = latency + ed.LatencySeconds = latencySeconds + + err = edRepo.UpdateEventDeliveryMetadata(context.Background(), project.UID, ed) + require.NoError(t, err) + + dbEventDelivery, err := edRepo.FindEventDeliveryByID(context.Background(), project.UID, ed.UID) + require.NoError(t, err) + + require.Equal(t, "", dbEventDelivery.Latency) + require.Equal(t, latencySeconds, dbEventDelivery.LatencySeconds) +} + +func Test_eventDeliveryRepo_CountEventDeliveries(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, 
project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + ed1 := generateEventDelivery(project, endpoint, event, device, sub) + ed2 := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed1) + require.NoError(t, err) + + err = edRepo.CreateEventDelivery(context.Background(), ed2) + require.NoError(t, err) + + c, err := edRepo.CountEventDeliveries(context.Background(), project.UID, []string{ed1.EndpointID, ed2.EndpointID}, event.UID, []datastore.EventDeliveryStatus{datastore.SuccessEventStatus}, datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }) + require.NoError(t, err) + + require.Equal(t, int64(2), c) +} + +func Test_eventDeliveryRepo_DeleteProjectEventDeliveries(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + // soft delete + ed1 := generateEventDelivery(project, endpoint, event, device, sub) + ed2 := generateEventDelivery(project, endpoint, event, device, sub) + + edRepo := NewEventDeliveryRepo(db) + err := edRepo.CreateEventDelivery(context.Background(), ed1) + require.NoError(t, err) + + err = edRepo.CreateEventDelivery(context.Background(), ed2) + require.NoError(t, err) + + err = edRepo.DeleteProjectEventDeliveries(context.Background(), project.UID, &datastore.EventDeliveryFilter{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }, false) + + require.NoError(t, err) + + // hard delete + + ed1 = generateEventDelivery(project, endpoint, event, device, sub) + ed2 = generateEventDelivery(project, endpoint, event, device, sub) + + err = edRepo.CreateEventDelivery(context.Background(), ed1) + require.NoError(t, err) + + err = edRepo.CreateEventDelivery(context.Background(), ed2) + require.NoError(t, err) + + err = edRepo.DeleteProjectEventDeliveries(context.Background(), project.UID, &datastore.EventDeliveryFilter{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }, true) + + require.NoError(t, err) +} + +func Test_eventDeliveryRepo_LoadEventDeliveriesPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + source := seedSource(t, db) + project := seedProject(t, db) + device := seedDevice(t, db) + endpoint := seedEndpoint(t, db) + event := seedEvent(t, db, project) + sub := seedSubscription(t, db, project, source, endpoint, device) + + edRepo := NewEventDeliveryRepo(db) + edMap := map[string]*datastore.EventDelivery{} + for i := 0; i < 8; i++ { + ed := generateEventDelivery(project, endpoint, event, device, sub) + edMap[ed.UID] = ed + + err := edRepo.CreateEventDelivery(context.Background(), ed) + require.NoError(t, err) + } + + dbEventDeliveries, _, err := edRepo.LoadEventDeliveriesPaged( + context.Background(), project.UID, []string{endpoint.UID}, event.UID, sub.UID, + []datastore.EventDeliveryStatus{datastore.SuccessEventStatus}, + datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }, + datastore.Pageable{ + PerPage: 10, + }, + "", "", + ) + + require.NoError(t, err) + require.Equal(t, 8, len(dbEventDeliveries)) + + for i := range 
dbEventDeliveries { + + dbEventDelivery := &dbEventDeliveries[i] + ed := edMap[dbEventDelivery.UID] + + require.NotEmpty(t, dbEventDelivery.CreatedAt) + require.NotEmpty(t, dbEventDelivery.UpdatedAt) + + dbEventDelivery.CreatedAt, dbEventDelivery.UpdatedAt, dbEventDelivery.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + + require.Equal(t, event.EventType, dbEventDelivery.Event.EventType) + require.Equal(t, endpoint.UID, dbEventDelivery.Endpoint.UID) + dbEventDelivery.Event, dbEventDelivery.Endpoint, dbEventDelivery.Source, dbEventDelivery.Device = nil, nil, nil, nil + + require.Equal(t, ed.Metadata.NextSendTime.UTC(), dbEventDelivery.Metadata.NextSendTime.UTC()) + ed.Metadata.NextSendTime = time.Time{} + dbEventDelivery.Metadata.NextSendTime = time.Time{} + + ed.CreatedAt, ed.UpdatedAt, ed.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + ed.Latency = "" + dbEventDelivery.Latency = "" + + require.Equal(t, ed, dbEventDelivery) + } + + evType := "file" + event = seedEventWithEventType(t, db, project, evType) + + ed := generateEventDelivery(project, endpoint, event, device, sub) + + err = edRepo.CreateEventDeliveries(context.Background(), []*datastore.EventDelivery{ed}) + require.NoError(t, err) + + filteredDeliveries, _, err := edRepo.LoadEventDeliveriesPaged( + context.Background(), project.UID, []string{endpoint.UID}, event.UID, sub.UID, + []datastore.EventDeliveryStatus{datastore.SuccessEventStatus}, + datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(time.Hour).Unix(), + }, + datastore.Pageable{ + PerPage: 10, + }, + "", evType, + ) + + require.NoError(t, err) + require.Equal(t, 1, len(filteredDeliveries)) + require.Equal(t, ed.UID, filteredDeliveries[0].UID) +} diff --git a/database/sqlite3/event_test.go b/database/sqlite3/event_test.go new file mode 100644 index 0000000000..4a5ca3c16f --- /dev/null +++ b/database/sqlite3/event_test.go @@ -0,0 +1,470 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "encoding/json" + "errors" + "gopkg.in/guregu/null.v4" + "testing" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/pkg/httpheader" + "github.com/frain-dev/convoy/util" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreateEvent(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + eventRepo := NewEventRepo(db) + event := generateEvent(t, db) + ctx := context.Background() + + require.NoError(t, eventRepo.CreateEvent(ctx, event)) + + newEvent, err := eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.NoError(t, err) + + newEvent.CreatedAt = time.Time{} + newEvent.UpdatedAt = time.Time{} + newEvent.AcknowledgedAt = null.Time{} + event.CreatedAt, event.UpdatedAt, event.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{} + + require.Equal(t, event, newEvent) +} + +func Test_FindEventByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + eventRepo := NewEventRepo(db) + event := generateEvent(t, db) + ctx := context.Background() + + _, err := eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrEventNotFound)) + + require.NoError(t, eventRepo.CreateEvent(ctx, event)) + + newEvent, err := eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.NoError(t, err) + + newEvent.CreatedAt = time.Time{} + newEvent.UpdatedAt = time.Time{} + 
newEvent.AcknowledgedAt = null.Time{}
+
+	event.CreatedAt, event.UpdatedAt, event.AcknowledgedAt = time.Time{}, time.Time{}, null.Time{}
+	require.Equal(t, event, newEvent)
+}
+
+func Test_FindEventsByIDs(t *testing.T) {
+	db, closeFn := getDB(t)
+	defer closeFn()
+
+	eventRepo := NewEventRepo(db)
+	ctx := context.Background()
+	event := generateEvent(t, db)
+
+	err := eventRepo.CreateEvent(ctx, event)
+	require.NoError(t, err)
+
+	records, err := eventRepo.FindEventsByIDs(ctx, event.ProjectID, []string{event.UID})
+	require.NoError(t, err)
+
+	require.Equal(t, 1, len(records))
+}
+
+func Test_CountProjectMessages(t *testing.T) {
+	data := json.RawMessage([]byte(`{
+		"event_id": "123456",
+		"endpoint_id": "123456"
+	}`))
+
+	tests := []struct {
+		name  string
+		count int
+		data  json.RawMessage
+	}{
+		{
+			name:  "Count Project Messages - 10 records",
+			count: 10,
+			data:  data,
+		},
+
+		{
+			name:  "Count Project Messages - 12 records",
+			count: 12,
+			data:  data,
+		},
+
+		{
+			name:  "Count Project Messages - 5 records",
+			count: 5,
+			data:  data,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			db, closeFn := getDB(t)
+			defer closeFn()
+
+			project := seedProject(t, db)
+			endpoint := generateEndpoint(project)
+			eventRepo := NewEventRepo(db)
+
+			err := NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID)
+			require.NoError(t, err)
+
+			for i := 0; i < tc.count; i++ {
+				event := &datastore.Event{
+					UID:       ulid.Make().String(),
+					EventType: "test-event",
+					Endpoints: []string{endpoint.UID},
+					ProjectID: project.UID,
+					Headers:   httpheader.HTTPHeader{},
+					Raw:       string(tc.data),
+					Data:      data,
+					CreatedAt: time.Now(),
+					UpdatedAt: time.Now(),
+				}
+
+				err := eventRepo.CreateEvent(context.Background(), event)
+				require.NoError(t, err)
+			}
+
+			count, err := eventRepo.CountProjectMessages(context.Background(), project.UID)
+
+			require.NoError(t, err)
+			require.Equal(t, tc.count, int(count))
+		})
+	}
+}
+
+func Test_CountEvents(t *testing.T) {
+	data := json.RawMessage([]byte(`{
+		"event_id": "123456",
+		"endpoint_id": "123456"
+	}`))
+
+	tests := []struct {
+		name  string
+		count int
+		data  json.RawMessage
+	}{
+		{
+			name:  "Count Events - 11 records",
+			count: 11,
+			data:  data,
+		},
+
+		{
+			name:  "Count Events - 12 records",
+			count: 12,
+			data:  data,
+		},
+
+		{
+			name:  "Count Events - 10 records",
+			count: 10,
+			data:  data,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			db, closeFn := getDB(t)
+			defer closeFn()
+
+			project := seedProject(t, db)
+			endpoint := generateEndpoint(project)
+			eventRepo := NewEventRepo(db)
+
+			err := NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID)
+			require.NoError(t, err)
+
+			for i := 0; i < tc.count; i++ {
+				event := &datastore.Event{
+					UID:       ulid.Make().String(),
+					EventType: "test-event",
+					Endpoints: []string{endpoint.UID},
+					ProjectID: project.UID,
+					Headers:   httpheader.HTTPHeader{},
+					Raw:       string(tc.data),
+					Data:      data,
+					CreatedAt: time.Now(),
+					UpdatedAt: time.Now(),
+				}
+
+				err := eventRepo.CreateEvent(context.Background(), event)
+				require.NoError(t, err)
+			}
+
+			count, err := eventRepo.CountEvents(context.Background(), project.UID, &datastore.Filter{
+				EndpointID: endpoint.UID,
+				SearchParams: datastore.SearchParams{
+					CreatedAtStart: time.Now().Add(-5 * time.Minute).Unix(),
+					CreatedAtEnd:   time.Now().Add(5 * time.Minute).Unix(),
+				},
+			})
+
+			require.NoError(t, err)
+			require.Equal(t, tc.count, int(count))
+		})
+	}
+}
+
+func Test_LoadEventsPaged(t *testing.T) {
+	type Expected struct {
+		paginationData datastore.PaginationData
+	}
+
+	data := json.RawMessage([]byte(`{
+		"event_id": "123456",
+		"endpoint_id": "123456"
+	}`))
+
+	tests := []struct {
+		name          string
+		pageData      datastore.Pageable
+		count         int
+		expectedCount int
+		endpointID    string
+		sourceID      string
+		expected      Expected
+	}{
+		{
+			name: "Load Events Paged - 10 records",
+			pageData: datastore.Pageable{
+				PerPage:    3,
+				Direction:  datastore.Next,
+				NextCursor: datastore.DefaultCursor,
+			},
+			count: 10,
+			expected: Expected{
+				paginationData: datastore.PaginationData{
+					PerPage: 3,
+				},
+			},
+		},
+
+		{
+			name: "Load Events Paged - 12 records",
+			pageData: datastore.Pageable{
+				PerPage:    4,
+				Direction:  datastore.Next,
+				NextCursor: datastore.DefaultCursor,
+			},
+			count: 12,
+			expected: Expected{
+				paginationData: datastore.PaginationData{
+					PerPage: 4,
+				},
+			},
+		},
+
+		{
+			name: "Load Events Paged - 5 records",
+			pageData: datastore.Pageable{
+				PerPage:    3,
+				Direction:  datastore.Next,
+				NextCursor: datastore.DefaultCursor,
+			},
+			count: 5,
+			expected: Expected{
+				paginationData: datastore.PaginationData{
+					PerPage: 3,
+				},
+			},
+		},
+
+		{
+			name: "Filter Events Paged By Endpoint ID - 1 record",
+			pageData: datastore.Pageable{
+				PerPage:    3,
+				Direction:  datastore.Next,
+				NextCursor: datastore.DefaultCursor,
+			},
+			count:      1,
+			endpointID: ulid.Make().String(),
+			sourceID:   ulid.Make().String(),
+			expected: Expected{
+				paginationData: datastore.PaginationData{
+					PerPage: 3,
+				},
+			},
+		},
+
+		{
+			name: "Filter Events Paged By Source ID - 3 records",
+			pageData: datastore.Pageable{
+				PerPage:    3,
+				Direction:  datastore.Next,
+				NextCursor: datastore.DefaultCursor,
+			},
+			count:    3,
+			sourceID: ulid.Make().String(),
+			expected: Expected{
+				paginationData: datastore.PaginationData{
+					PerPage: 3,
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			db, closeFn := getDB(t)
+			defer closeFn()
+
+			project := seedProject(t, db)
+			endpoint := generateEndpoint(project)
+			source := generateSource(t, db)
+			source.ProjectID = project.UID
+			if !util.IsStringEmpty(tc.sourceID) {
+				source.UID = tc.sourceID
+			}
+			eventRepo := NewEventRepo(db)
+
+			if !util.IsStringEmpty(tc.endpointID) {
+				endpoint.UID = tc.endpointID
+			}
+
+			err := NewSourceRepo(db).CreateSource(context.Background(), source)
+			require.NoError(t, err)
+
+			err = NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID)
+			require.NoError(t, err)
+
+			for i := 0; i < tc.count; i++ {
+				event := &datastore.Event{
+					UID:       ulid.Make().String(),
+					EventType: "test-event",
+					Endpoints: []string{endpoint.UID},
+					ProjectID: project.UID,
+					Headers:   httpheader.HTTPHeader{},
+					SourceID:  source.UID,
+					Raw:       string(data),
+					Data:      data,
+					CreatedAt: time.Now(),
+					UpdatedAt: time.Now(),
+				}
+
+				err := eventRepo.CreateEvent(context.Background(), event)
+				require.NoError(t, err)
+			}
+
+			_, pageable, err := eventRepo.LoadEventsPaged(context.Background(), project.UID, &datastore.Filter{
+				EndpointID: endpoint.UID,
+				SourceIDs:  []string{source.UID},
+				SearchParams: datastore.SearchParams{
+					CreatedAtStart: time.Now().Add(-time.Hour).Unix(),
+					CreatedAtEnd:   time.Now().Add(5 * time.Minute).Unix(),
+				},
+				Pageable: tc.pageData,
+			})
+
+			require.NoError(t, err)
+
+			require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage)
+		})
+	}
+}
+
+func Test_SoftDeleteProjectEvents(t *testing.T) {
+	db, closeFn := getDB(t)
+	defer closeFn()
+
+	eventRepo := NewEventRepo(db)
+	event := generateEvent(t, db)
+	ctx := 
context.Background() + + require.NoError(t, eventRepo.CreateEvent(ctx, event)) + + _, err := eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.NoError(t, err) + + require.NoError(t, eventRepo.DeleteProjectEvents(ctx, event.ProjectID, &datastore.EventFilter{ + CreatedAtStart: event.CreatedAt.Unix(), + CreatedAtEnd: time.Now().Add(5 * time.Minute).Unix(), + }, false)) + + _, err = eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrEventNotFound)) +} + +func Test_HardDeleteProjectEvents(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + eventRepo := NewEventRepo(db) + event := generateEvent(t, db) + ctx := context.Background() + + require.NoError(t, eventRepo.CreateEvent(ctx, event)) + + _, err := eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.NoError(t, err) + + require.NoError(t, eventRepo.DeleteProjectEvents(ctx, event.ProjectID, &datastore.EventFilter{ + CreatedAtStart: time.Now().Unix(), + CreatedAtEnd: time.Now().Add(5 * time.Minute).Unix(), + }, true)) + + _, err = eventRepo.FindEventByID(ctx, event.ProjectID, event.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrEventNotFound)) +} + +func generateEvent(t *testing.T, db database.Database) *datastore.Event { + project := seedProject(t, db) + endpoint := generateEndpoint(project) + + err := NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + data := json.RawMessage([]byte(`{ + "event_id": "123456", + "endpoint_id": "123456" + }`)) + + return &datastore.Event{ + UID: ulid.Make().String(), + EventType: "test-event", + Endpoints: []string{endpoint.UID}, + URLQueryParams: "name=ref&category=food", + ProjectID: project.UID, + Headers: httpheader.HTTPHeader{}, + Raw: string(data), + Data: data, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} + +func seedEvent(t *testing.T, db database.Database, project *datastore.Project) *datastore.Event { + ev := generateEvent(t, db) + ev.ProjectID = project.UID + + require.NoError(t, NewEventRepo(db).CreateEvent(context.Background(), ev)) + return ev +} + +func seedEventWithEventType(t *testing.T, db database.Database, project *datastore.Project, eventType string) *datastore.Event { + ev := generateEvent(t, db) + ev.EventType = datastore.EventType(eventType) + ev.ProjectID = project.UID + + require.NoError(t, NewEventRepo(db).CreateEvent(context.Background(), ev)) + return ev +} diff --git a/database/sqlite3/event_types.go b/database/sqlite3/event_types.go new file mode 100644 index 0000000000..3aafcd8338 --- /dev/null +++ b/database/sqlite3/event_types.go @@ -0,0 +1,184 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" + "github.com/oklog/ulid/v2" + "time" +) + +var ( + ErrEventTypeNotFound = errors.New("event type not found") + ErrEventTypeNotCreated = errors.New("event type could not be created") + ErrEventTypeNotUpdated = errors.New("event type could not be updated") +) + +const ( + createEventType = ` + INSERT INTO event_types (id, name, description, category, project_id, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, now(), now()); + ` + + updateEventType = ` + UPDATE event_types SET + description = $3, + category = $4, + updated_at = NOW() + WHERE id = $1 and project_id = $2; + ` + + deprecateEventType = ` + UPDATE event_types SET + deprecated_at 
= NOW() + WHERE id = $1 and project_id = $2 + returning *; + ` + + fetchEventTypeById = ` + SELECT * FROM event_types + WHERE id = $1 and project_id = $2; + ` + + fetchAllEventTypes = ` + SELECT * FROM event_types where project_id = $1; + ` +) + +type eventTypesRepo struct { + db *sqlx.DB +} + +func NewEventTypesRepo(db database.Database) datastore.EventTypesRepository { + return &eventTypesRepo{db: db.GetDB()} +} + +func (e *eventTypesRepo) CreateEventType(ctx context.Context, eventType *datastore.ProjectEventType) error { + r, err := e.db.ExecContext(ctx, createEventType, + eventType.UID, + eventType.Name, + eventType.Description, + eventType.Category, + eventType.ProjectId, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventTypeNotCreated + } + + return nil +} + +func (e *eventTypesRepo) CreateDefaultEventType(ctx context.Context, projectId string) error { + eventType := &datastore.ProjectEventType{ + UID: ulid.Make().String(), + Name: "*", + ProjectId: projectId, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + r, err := e.db.ExecContext(ctx, createEventType, + eventType.UID, + eventType.Name, + eventType.Description, + eventType.Category, + eventType.ProjectId, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventTypeNotCreated + } + + return nil +} + +func (e *eventTypesRepo) UpdateEventType(ctx context.Context, eventType *datastore.ProjectEventType) error { + r, err := e.db.ExecContext(ctx, updateEventType, + eventType.UID, + eventType.ProjectId, + eventType.Description, + eventType.Category, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrEventTypeNotUpdated + } + + return nil +} + +func (e *eventTypesRepo) DeprecateEventType(ctx context.Context, id, projectId string) (*datastore.ProjectEventType, error) { + eventType := &datastore.ProjectEventType{} + err := e.db.QueryRowxContext(ctx, deprecateEventType, id, projectId).StructScan(eventType) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrEventTypeNotFound + } + return nil, err + } + + return eventType, nil +} + +func (e *eventTypesRepo) FetchEventTypeById(ctx context.Context, id, projectId string) (*datastore.ProjectEventType, error) { + eventType := &datastore.ProjectEventType{} + err := e.db.QueryRowxContext(ctx, fetchEventTypeById, id, projectId).StructScan(eventType) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrEventTypeNotFound + } + return nil, err + } + + return eventType, nil +} + +func (e *eventTypesRepo) FetchAllEventTypes(ctx context.Context, projectId string) ([]datastore.ProjectEventType, error) { + var eventTypes []datastore.ProjectEventType + rows, err := e.db.QueryxContext(ctx, fetchAllEventTypes, projectId) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var eventType datastore.ProjectEventType + + err = rows.StructScan(&eventType) + if err != nil { + return nil, err + } + + eventTypes = append(eventTypes, eventType) + } + + return eventTypes, nil +} diff --git a/database/sqlite3/export.go b/database/sqlite3/export.go new file mode 100644 index 0000000000..2a646569df --- /dev/null +++ b/database/sqlite3/export.go @@ -0,0 +1,143 @@ +package sqlite3 + +import ( + "context" + "encoding/json" + "fmt" 
+ "io" + "math" + "time" + + "github.com/tidwall/gjson" + + "github.com/jmoiron/sqlx" +) + +const ( + exportRepoQ = ` + SELECT TO_JSONB(ed) - 'id' || JSONB_BUILD_OBJECT('uid', ed.id) AS json_output + FROM %s AS ed %s + ORDER BY id ASC + LIMIT $4 + ` + + count = ` + SELECT COUNT(*) FROM %s %s + ` + + where = ` WHERE deleted_at IS NULL AND project_id = $1 AND created_at < $2 AND (id > $3 OR $3 = '')` +) + +// ExportRecords exports the records from the given table and writes them in json format to the passed writer. +// It's the caller's responsibility to close the writer. +func exportRecords(ctx context.Context, db *sqlx.DB, tableName, projectID string, createdAt time.Time, w io.Writer) (int64, error) { + c := &struct { + Count int64 `db:"count"` + }{} + + countQuery := fmt.Sprintf(count, tableName, where) + err := db.QueryRowxContext(ctx, countQuery, projectID, createdAt, "").StructScan(c) + if err != nil { + return 0, err + } + + if c.Count == 0 { // nothing to export + return 0, nil + } + + var ( + batchSize = 3000 + numDocs int64 + numBatches = int(math.Ceil(float64(c.Count) / float64(batchSize))) + ) + + _, err = w.Write([]byte(`[`)) + if err != nil { + return 0, err + } + + q := fmt.Sprintf(exportRepoQ, tableName, where) + var ( + n int64 + lastID string + ) + + for i := 0; i < numBatches; i++ { + n, lastID, err = querybatch(ctx, db, q, projectID, lastID, createdAt, batchSize, w) + if err != nil { + return 0, fmt.Errorf("failed to query batch %d: %v", i, err) + } + numDocs += n + } + + _, err = w.Write([]byte(`]`)) + if err != nil { + return 0, err + } + + return numDocs, nil +} + +var commaJSON = []byte(`,`) + +func querybatch(ctx context.Context, db *sqlx.DB, q, projectID, lastID string, createdAt time.Time, batchSize int, w io.Writer) (int64, string, error) { + var numDocs int64 + + // Calling rows.Close() manually in places before we return is important here to prevent + // a memory leak, we cannot use defer in a loop because this can fill up the function stack quickly + rows, err := db.QueryxContext(ctx, q, projectID, createdAt, lastID, batchSize) + if err != nil { + return 0, "", err + } + defer closeWithError(rows) + + var record json.RawMessage + records := make([]byte, 0, 1000) + + // scan the first record and append it without appending a comma + if rows.Next() { + numDocs++ + err = rows.Scan(&record) + if err != nil { + return 0, "", err + } + + records = append(records, record...) + } + + i := 0 + // scan remaining records and prefix a comma before writing it + for rows.Next() { + numDocs++ + i++ + err = rows.Scan(&record) + if err != nil { + return 0, "", err + } + + records = append(records, append(commaJSON, record...)...) 
+
+func querybatch(ctx context.Context, db *sqlx.DB, q, projectID, lastID string, createdAt time.Time, batchSize int, w io.Writer) (int64, string, error) {
+	var numDocs int64
+
+	// querybatch runs once per batch, so this deferred close fires at the end
+	// of each batch instead of accumulating inside exportRecords' loop.
+	rows, err := db.QueryxContext(ctx, q, projectID, createdAt, lastID, batchSize)
+	if err != nil {
+		return 0, "", err
+	}
+	defer closeWithError(rows)
+
+	var record json.RawMessage
+	records := make([]byte, 0, 1000)
+
+	// scan the first record and append it without a leading comma
+	if rows.Next() {
+		numDocs++
+		err = rows.Scan(&record)
+		if err != nil {
+			return 0, "", err
+		}
+
+		records = append(records, record...)
+	}
+
+	i := 0
+	// scan the remaining records and prefix a comma before each one
+	for rows.Next() {
+		numDocs++
+		i++
+		err = rows.Scan(&record)
+		if err != nil {
+			return 0, "", err
+		}
+
+		records = append(records, append(commaJSON, record...)...)
+
+		// after gathering 100 records, flush them to the writer
+		if i == 100 {
+			i = 0
+
+			_, err = w.Write(records)
+			if err != nil {
+				return 0, "", err
+			}
+			records = records[:0] // reset the records slice to length 0 so we can reuse the allocated memory
+		}
+	}
+
+	// write any remaining unwritten records
+	if len(records) > 0 {
+		_, err = w.Write(records)
+		if err != nil {
+			return 0, "", err
+		}
+	}
+
+	value := gjson.Get(string(record), "uid") // get the id of the last record; we use it for pagination
+
+	return numDocs, value.String(), nil
+}
diff --git a/database/sqlite3/job.go b/database/sqlite3/job.go
new file mode 100644
index 0000000000..a34daa993e
--- /dev/null
+++ b/database/sqlite3/job.go
@@ -0,0 +1,368 @@
+package sqlite3
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+
+	"github.com/frain-dev/convoy/database"
+	"github.com/frain-dev/convoy/datastore"
+	"github.com/jmoiron/sqlx"
+)
+
+var (
+	ErrJobNotFound   = errors.New("job not found")
+	ErrJobNotCreated = errors.New("job could not be created")
+	ErrJobNotUpdated = errors.New("job could not be updated")
+	ErrJobNotDeleted = errors.New("job could not be deleted")
+)
+
+const (
+	createJob = `
+	INSERT INTO jobs (id, type, status, project_id)
+	VALUES ($1, $2, $3, $4)
+	`
+
+	updateJobStartedAt = `
+	UPDATE jobs SET
+	status = 'running',
+	started_at = NOW(),
+	updated_at = NOW()
+	WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+	`
+
+	updateJobCompletedAt = `
+	UPDATE jobs SET
+	status = 'completed',
+	completed_at = NOW(),
+	updated_at = NOW()
+	WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+	`
+
+	updateJobFailedAt = `
+	UPDATE jobs SET
+	status = 'failed',
+	failed_at = NOW(),
+	updated_at = NOW()
+	WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+	`
+
+	deleteJob = `
+	UPDATE jobs SET
+	deleted_at = NOW()
+	WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+	`
+
+	fetchJobById = `
+	SELECT * FROM jobs
+	WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
+	`
+
+	fetchRunningJobsByProjectId = `
+	SELECT * FROM jobs
+	WHERE status = 'running'
+	AND project_id = $1
+	AND deleted_at IS NULL;
+	`
+
+	fetchJobsByProjectId = `
+	SELECT * FROM jobs WHERE project_id = $1 AND deleted_at IS NULL;
+	`
+
+	fetchJobsPaginated = `
+	SELECT * FROM jobs WHERE deleted_at IS NULL`
+
+	baseJobsFilter = `
+	AND project_id = :project_id`
+
+	baseFetchJobsPagedForward = `
+	%s
+	%s
+	AND id <= :cursor
+	GROUP BY id
+	ORDER BY id DESC
+	LIMIT :limit
+	`
+
+	baseFetchJobsPagedBackward = `
+	WITH jobs AS (
+		%s
+		%s
+		AND id >= :cursor
+		GROUP BY id
+		ORDER BY id ASC
+		LIMIT :limit
+	)
+
+	SELECT * FROM jobs ORDER BY id DESC
+	`
+
+	countPrevJobs = `
+	SELECT COUNT(DISTINCT(id)) AS count
+	FROM jobs
+	WHERE deleted_at IS NULL
+	%s
+	AND id > :cursor GROUP BY id ORDER BY id DESC LIMIT 1`
+)
+
+type jobRepo struct {
+	db *sqlx.DB
+}
+
+func NewJobRepo(db database.Database) datastore.JobRepository {
+	return &jobRepo{db: db.GetDB()}
+}
+
+func (j *jobRepo) CreateJob(ctx context.Context, job *datastore.Job) error {
+	r, err := j.db.ExecContext(ctx, createJob,
+		job.UID,
+		job.Type,
+		job.Status,
+		job.ProjectID,
+	)
+	if err != nil {
+		return err
+	}
+
+	rowsAffected, err := r.RowsAffected()
+	if err != nil {
+		return err
+	}
+
+	if rowsAffected < 1 {
+		return ErrJobNotCreated
+	}
+
+	return nil
+}
+
+func (j *jobRepo) MarkJobAsStarted(ctx context.Context, uid, projectID string) error {
+	r, err := j.db.ExecContext(ctx, updateJobStartedAt, uid, projectID)
+	if err != nil {
+		return err
+	}
+
+	
rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrJobNotUpdated + } + + return nil +} + +func (j *jobRepo) MarkJobAsCompleted(ctx context.Context, uid, projectID string) error { + r, err := j.db.ExecContext(ctx, updateJobCompletedAt, uid, projectID) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrJobNotUpdated + } + + return nil +} + +func (j *jobRepo) MarkJobAsFailed(ctx context.Context, uid, projectID string) error { + r, err := j.db.ExecContext(ctx, updateJobFailedAt, uid, projectID) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrJobNotUpdated + } + + return nil +} + +func (j *jobRepo) DeleteJob(ctx context.Context, uid string, projectID string) error { + r, err := j.db.ExecContext(ctx, deleteJob, uid, projectID) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrJobNotDeleted + } + + return nil +} + +func (j *jobRepo) FetchJobById(ctx context.Context, uid string, projectID string) (*datastore.Job, error) { + job := &datastore.Job{} + err := j.db.QueryRowxContext(ctx, fetchJobById, uid, projectID).StructScan(job) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrJobNotFound + } + return nil, err + } + + return job, nil +} + +func (j *jobRepo) FetchRunningJobsByProjectId(ctx context.Context, projectID string) ([]datastore.Job, error) { + var jobs []datastore.Job + rows, err := j.db.QueryxContext(ctx, fetchRunningJobsByProjectId, projectID) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var job datastore.Job + + err = rows.StructScan(&job) + if err != nil { + return nil, err + } + + jobs = append(jobs, job) + } + + return jobs, nil +} + +func (j *jobRepo) FetchJobsByProjectId(ctx context.Context, projectID string) ([]datastore.Job, error) { + var jobs []datastore.Job + rows, err := j.db.QueryxContext(ctx, fetchJobsByProjectId, projectID) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var job datastore.Job + + err = rows.StructScan(&job) + if err != nil { + return nil, err + } + + jobs = append(jobs, job) + } + + return jobs, nil +} + +func (j *jobRepo) LoadJobsPaged(ctx context.Context, projectID string, pageable datastore.Pageable) ([]datastore.Job, datastore.PaginationData, error) { + var query, filterQuery string + var args []interface{} + var err error + + arg := map[string]interface{}{ + "project_id": projectID, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + if pageable.Direction == datastore.Next { + query = baseFetchJobsPagedForward + } else { + query = baseFetchJobsPagedBackward + } + + filterQuery = baseJobsFilter + + query = fmt.Sprintf(query, fetchJobsPaginated, filterQuery) + + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = j.db.Rebind(query) + + rows, err := j.db.QueryxContext(ctx, query, args...) 
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var jobs []datastore.Job + for rows.Next() { + var data JobPaginated + + err = rows.StructScan(&data) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + jobs = append(jobs, data.Job) + } + + var count datastore.PrevRowCount + if len(jobs) > 0 { + var countQuery string + var qargs []interface{} + first := jobs[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevJobs, filterQuery) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = j.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := j.db.QueryxContext(ctx, countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(jobs)) + for i := range jobs { + ids[i] = jobs[i].UID + } + + if len(jobs) > pageable.PerPage { + jobs = jobs[:len(jobs)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return jobs, *pagination, nil +} + +type JobPaginated struct { + Count int + datastore.Job +} diff --git a/database/sqlite3/job_test.go b/database/sqlite3/job_test.go new file mode 100644 index 0000000000..202ec2652e --- /dev/null +++ b/database/sqlite3/job_test.go @@ -0,0 +1,345 @@ +//go:build integration + +package sqlite3 + +import ( + "context" + "gopkg.in/guregu/null.v4" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func Test_CreateJob(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepo := NewJobRepo(db) + job := generateJob(t, db) + + require.NoError(t, jobRepo.CreateJob(context.Background(), job)) + + jobById, err := jobRepo.FetchJobById(context.Background(), job.UID, job.ProjectID) + require.NoError(t, err) + + require.NotNil(t, jobById) + require.Equal(t, datastore.JobStatusReady, jobById.Status) +} + +func TestJobRepo_FetchJobsByProjectId(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + jobRepo := NewJobRepo(db) + + p1 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "P1", + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + p2 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "P2", + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + err := NewProjectRepo(db).CreateProject(context.Background(), p1) + require.NoError(t, err) + + err = NewProjectRepo(db).CreateProject(context.Background(), p2) + require.NoError(t, err) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "create", + Status: datastore.JobStatusRunning, + StartedAt: null.TimeFrom(time.Now()), + ProjectID: p1.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + })) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "update", + Status: datastore.JobStatusCompleted, + StartedAt: null.TimeFrom(time.Now()), + CompletedAt: null.TimeFrom(time.Now()), + ProjectID: p2.UID, + CreatedAt: 
time.Now(), + UpdatedAt: time.Now(), + })) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "update", + Status: datastore.JobStatusFailed, + StartedAt: null.TimeFrom(time.Now()), + FailedAt: null.TimeFrom(time.Now()), + ProjectID: p2.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + })) + + jobs, err := jobRepo.FetchJobsByProjectId(context.Background(), p2.UID) + require.NoError(t, err) + + require.Equal(t, 2, len(jobs)) +} + +func TestJobRepo_FetchRunningJobsByProjectId(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + jobRepo := NewJobRepo(db) + + p1 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "P1", + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + p2 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "P2", + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + err := NewProjectRepo(db).CreateProject(context.Background(), p1) + require.NoError(t, err) + + err = NewProjectRepo(db).CreateProject(context.Background(), p2) + require.NoError(t, err) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "create", + Status: datastore.JobStatusRunning, + StartedAt: null.TimeFrom(time.Now()), + ProjectID: p1.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + })) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "update", + Status: datastore.JobStatusRunning, + StartedAt: null.TimeFrom(time.Now()), + ProjectID: p2.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + })) + + require.NoError(t, jobRepo.CreateJob(context.Background(), &datastore.Job{ + UID: ulid.Make().String(), + Type: "update", + Status: datastore.JobStatusFailed, + StartedAt: null.TimeFrom(time.Now()), + FailedAt: null.TimeFrom(time.Now()), + ProjectID: p2.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + })) + + jobs, err := jobRepo.FetchRunningJobsByProjectId(context.Background(), p2.UID) + require.NoError(t, err) + + require.Equal(t, 1, len(jobs)) +} + +func TestJobRepo_MarkJobAsStarted(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepo := NewJobRepo(db) + job := generateJob(t, db) + + ctx := context.Background() + + require.NoError(t, jobRepo.CreateJob(ctx, job)) + + require.NoError(t, jobRepo.MarkJobAsStarted(ctx, job.UID, job.ProjectID)) + + jobById, err := jobRepo.FetchJobById(ctx, job.UID, job.ProjectID) + require.NoError(t, err) + + require.Equal(t, datastore.JobStatusRunning, jobById.Status) + require.Less(t, time.Time{}.Unix(), jobById.StartedAt.Time.Unix()) + require.True(t, time.Now().After(jobById.StartedAt.Time)) + require.Equal(t, time.Time{}.Unix(), jobById.FailedAt.Time.Unix()) + require.Equal(t, time.Time{}.Unix(), jobById.CompletedAt.Time.Unix()) +} + +func TestJobRepo_MarkJobAsCompleted(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepo := NewJobRepo(db) + job := generateJob(t, db) + + ctx := context.Background() + + require.NoError(t, jobRepo.CreateJob(ctx, job)) + + require.NoError(t, jobRepo.MarkJobAsStarted(ctx, job.UID, job.ProjectID)) + require.NoError(t, jobRepo.MarkJobAsCompleted(ctx, job.UID, job.ProjectID)) + + jobById, err := jobRepo.FetchJobById(ctx, job.UID, job.ProjectID) + require.NoError(t, err) + + require.Equal(t, datastore.JobStatusCompleted, 
jobById.Status) + require.Less(t, time.Time{}.Unix(), jobById.StartedAt.Time.Unix()) + require.True(t, time.Now().After(jobById.StartedAt.Time)) + require.Equal(t, time.Time{}.Unix(), jobById.FailedAt.Time.Unix()) +} + +func TestJobRepo_MarkJobAsFailed(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepo := NewJobRepo(db) + job := generateJob(t, db) + + ctx := context.Background() + + require.NoError(t, jobRepo.CreateJob(ctx, job)) + + require.NoError(t, jobRepo.MarkJobAsStarted(ctx, job.UID, job.ProjectID)) + require.NoError(t, jobRepo.MarkJobAsFailed(ctx, job.UID, job.ProjectID)) + + jobById, err := jobRepo.FetchJobById(ctx, job.UID, job.ProjectID) + require.NoError(t, err) + + require.Equal(t, datastore.JobStatusFailed, jobById.Status) + require.Less(t, time.Time{}.Unix(), jobById.StartedAt.Time.Unix()) + require.True(t, time.Now().After(jobById.StartedAt.Time)) + require.Equal(t, time.Time{}.Unix(), jobById.CompletedAt.Time.Unix()) +} + +func TestJobRepo_DeleteJob(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepo := NewJobRepo(db) + job := generateJob(t, db) + + require.NoError(t, jobRepo.CreateJob(context.Background(), job)) + + err := jobRepo.DeleteJob(context.Background(), job.UID, job.ProjectID) + require.NoError(t, err) + + _, err = jobRepo.FetchJobById(context.Background(), job.UID, job.ProjectID) + require.Equal(t, ErrJobNotFound, err) +} + +func Test_LoadJobsPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + expected Expected + }{ + { + name: "Load Jobs Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Jobs Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Jobs Paged - 5 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Jobs Paged - 1 record", + pageData: datastore.Pageable{PerPage: 3}, + count: 1, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + jobRepository := NewJobRepo(db) + project := seedProject(t, db) + + for i := 0; i < tc.count; i++ { + job := &datastore.Job{ + UID: ulid.Make().String(), + ProjectID: project.UID, + Status: datastore.JobStatusReady, + } + + require.NoError(t, jobRepository.CreateJob(context.Background(), job)) + } + + _, pageable, err := jobRepository.LoadJobsPaged(context.Background(), project.UID, tc.pageData) + require.NoError(t, err) + + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generateJob(t *testing.T, db database.Database) *datastore.Job { + project := seedProject(t, db) + + return &datastore.Job{ + UID: ulid.Make().String(), + Type: "search_tokenizer", + Status: datastore.JobStatusReady, + ProjectID: project.UID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} diff --git a/database/sqlite3/meta_event.go b/database/sqlite3/meta_event.go new file mode 100644 index 0000000000..ced8600983 --- /dev/null +++ b/database/sqlite3/meta_event.go @@ -0,0 +1,225 @@ 
+package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" +) + +var ( + ErrMetaEventNotCreated = errors.New("meta event could not be created") + ErrMetaEventNotUpdated = errors.New("meta event could not be updated") +) + +const ( + createMetaEvent = ` + INSERT INTO meta_events (id, event_type, project_id, metadata, status) + VALUES ($1, $2, $3, $4, $5) + ` + fetchMetaEventById = ` + SELECT id, project_id, event_type, metadata, + attempt, status, created_at, updated_at + FROM meta_events WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + baseMetaEventsPaged = ` + SELECT mv.id, mv.project_id, mv.event_type, + mv.metadata, mv.attempt, mv.status, + mv.created_at, mv.updated_at FROM meta_events mv + WHERE mv.deleted_at IS NULL + ` + baseMetaEventsPagedForward = `%s %s AND mv.id <= :cursor + GROUP BY mv.id + ORDER BY mv.id DESC + LIMIT :limit + ` + baseMetaEventsPagedBackward = ` + WITH meta_events AS ( + %s %s AND mv.id >= :cursor + GROUP BY mv.id + ORDER BY mv.id ASC + LIMIT :limit + ) + + SELECT * from meta_events ORDER BY id DESC + ` + baseMetaEventFilter = ` AND mv.project_id = :project_id + AND mv.created_at >= :start_date + AND mv.created_at <= :end_date` + + baseCountPrevMetaEvents = ` + SELECT COUNT(DISTINCT(mv.id)) AS count + FROM meta_events mv WHERE mv.deleted_at IS NULL + ` + countPrevMetaEvents = ` AND mv.id > :cursor GROUP BY mv.id ORDER BY mv.id DESC LIMIT 1` + + updateMetaEventStatus = ` + UPDATE meta_events SET status = $1 WHERE id = $2 AND project_id = $3 + AND deleted_at IS NULL; + ` + updateMetaEvent = ` + UPDATE meta_events SET + event_type = $3, + metadata = $4, + attempt = $5, + status = $6, + updated_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` +) + +type metaEventRepo struct { + db *sqlx.DB +} + +func NewMetaEventRepo(db database.Database) datastore.MetaEventRepository { + return &metaEventRepo{db: db.GetDB()} +} + +func (m *metaEventRepo) CreateMetaEvent(ctx context.Context, metaEvent *datastore.MetaEvent) error { + r, err := m.db.ExecContext(ctx, createMetaEvent, metaEvent.UID, metaEvent.EventType, metaEvent.ProjectID, + metaEvent.Metadata, metaEvent.Status, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrMetaEventNotCreated + } + + return nil +} + +func (m *metaEventRepo) FindMetaEventByID(ctx context.Context, projectID string, id string) (*datastore.MetaEvent, error) { + metaEvent := &datastore.MetaEvent{} + err := m.db.QueryRowxContext(ctx, fetchMetaEventById, id, projectID).StructScan(metaEvent) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrMetaEventNotFound + } + + return nil, err + } + + return metaEvent, nil +} + +func (m *metaEventRepo) LoadMetaEventsPaged(ctx context.Context, projectID string, filter *datastore.Filter) ([]datastore.MetaEvent, datastore.PaginationData, error) { + var query, countQuery, filterQuery string + var err error + var args, qargs []interface{} + + startDate, endDate := getCreatedDateFilter(filter.SearchParams.CreatedAtStart, filter.SearchParams.CreatedAtEnd) + + arg := map[string]interface{}{ + "project_id": projectID, + "start_date": startDate, + "end_date": endDate, + "limit": filter.Pageable.Limit(), + "cursor": filter.Pageable.Cursor(), + } + + var baseQueryPagination string + if filter.Pageable.Direction == datastore.Next 
{ + baseQueryPagination = baseMetaEventsPagedForward + } else { + baseQueryPagination = baseMetaEventsPagedBackward + } + + filterQuery = baseMetaEventFilter + query = fmt.Sprintf(baseQueryPagination, baseMetaEventsPaged, filterQuery) + + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = m.db.Rebind(query) + rows, err := m.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + metaEvents := make([]datastore.MetaEvent, 0) + for rows.Next() { + var data datastore.MetaEvent + + err = rows.StructScan(&data) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + metaEvents = append(metaEvents, data) + } + + var count datastore.PrevRowCount + if len(metaEvents) > 0 { + first := metaEvents[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := baseCountPrevMetaEvents + filterQuery + countPrevMetaEvents + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = m.db.Rebind(countQuery) + rows, err := m.db.QueryxContext(ctx, countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(metaEvents)) + for i := range metaEvents { + ids[i] = metaEvents[i].UID + } + + if len(metaEvents) > filter.Pageable.PerPage { + metaEvents = metaEvents[:len(metaEvents)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(filter.Pageable, ids) + + return metaEvents, *pagination, nil +} + +func (m *metaEventRepo) UpdateMetaEvent(ctx context.Context, projectID string, metaEvent *datastore.MetaEvent) error { + result, err := m.db.ExecContext(ctx, updateMetaEvent, metaEvent.UID, projectID, metaEvent.EventType, metaEvent.Metadata, + metaEvent.Attempt, metaEvent.Status, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrMetaEventNotUpdated + } + + return nil +} diff --git a/database/sqlite3/meta_event_test.go b/database/sqlite3/meta_event_test.go new file mode 100644 index 0000000000..dd7aa2560f --- /dev/null +++ b/database/sqlite3/meta_event_test.go @@ -0,0 +1,207 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreateMetaEvent(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + metaEventRepo := NewMetaEventRepo(db) + metaEvent := generateMetaEvent(t, db) + ctx := context.Background() + + require.NoError(t, metaEventRepo.CreateMetaEvent(ctx, metaEvent)) + + newMetaEvent, err := metaEventRepo.FindMetaEventByID(ctx, metaEvent.ProjectID, metaEvent.UID) + require.NoError(t, err) + + newMetaEvent.CreatedAt, newMetaEvent.UpdatedAt, newMetaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + metaEvent.CreatedAt, metaEvent.UpdatedAt, metaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + + require.Equal(t, metaEvent, newMetaEvent) +} + +func Test_FindMetaEventByID(t *testing.T) { + db, closeFn := getDB(t) + defer 
closeFn() + + metaEventRepo := NewMetaEventRepo(db) + metaEvent := generateMetaEvent(t, db) + ctx := context.Background() + + _, err := metaEventRepo.FindMetaEventByID(ctx, metaEvent.ProjectID, metaEvent.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrMetaEventNotFound)) + + require.NoError(t, metaEventRepo.CreateMetaEvent(ctx, metaEvent)) + + newMetaEvent, err := metaEventRepo.FindMetaEventByID(ctx, metaEvent.ProjectID, metaEvent.UID) + require.NoError(t, err) + + newMetaEvent.CreatedAt, newMetaEvent.UpdatedAt, newMetaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + metaEvent.CreatedAt, metaEvent.UpdatedAt, metaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + + require.Equal(t, metaEvent, newMetaEvent) +} + +func Test_UpdateMetaEvent(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + metaEventRepo := NewMetaEventRepo(db) + metaEvent := generateMetaEvent(t, db) + ctx := context.Background() + + require.NoError(t, metaEventRepo.CreateMetaEvent(ctx, metaEvent)) + + data := json.RawMessage([]byte(`{"event_type": "endpoint.updated"}`)) + + metaEvent.Status = datastore.SuccessEventStatus + metaEvent.EventType = string(datastore.EndpointUpdated) + metaEvent.Metadata = &datastore.Metadata{ + Data: data, + Raw: string(data), + } + err := metaEventRepo.UpdateMetaEvent(ctx, metaEvent.ProjectID, metaEvent) + require.NoError(t, err) + + newMetaEvent, err := metaEventRepo.FindMetaEventByID(ctx, metaEvent.ProjectID, metaEvent.UID) + require.NoError(t, err) + + newMetaEvent.CreatedAt, newMetaEvent.UpdatedAt, newMetaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + metaEvent.CreatedAt, metaEvent.UpdatedAt, metaEvent.Metadata.NextSendTime = time.Time{}, time.Time{}, time.Time{} + + require.Equal(t, metaEvent, newMetaEvent) +} + +func Test_LoadMetaEventsPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + endpointID string + expected Expected + }{ + { + name: "Load Meta Events Paged - 10 records", + pageData: datastore.Pageable{ + PerPage: 3, + Direction: datastore.Next, + NextCursor: datastore.DefaultCursor, + }, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Meta Events Paged - 12 records", + pageData: datastore.Pageable{ + PerPage: 4, + Direction: datastore.Next, + NextCursor: datastore.DefaultCursor, + }, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Meta Events Paged - 5 records", + pageData: datastore.Pageable{ + PerPage: 3, + Direction: datastore.Next, + NextCursor: datastore.DefaultCursor, + }, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + metaEventRepo := NewMetaEventRepo(db) + + for i := 0; i < tc.count; i++ { + metaEvent := &datastore.MetaEvent{ + UID: ulid.Make().String(), + Status: datastore.ScheduledEventStatus, + EventType: string(datastore.EndpointCreated), + ProjectID: project.UID, + Metadata: &datastore.Metadata{}, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err := metaEventRepo.CreateMetaEvent(context.Background(), metaEvent) + require.NoError(t, err) + } + + _, pageable, err := 
metaEventRepo.LoadMetaEventsPaged(context.Background(), project.UID, &datastore.Filter{ + SearchParams: datastore.SearchParams{ + CreatedAtStart: time.Now().Add(-time.Hour).Unix(), + CreatedAtEnd: time.Now().Add(5 * time.Minute).Unix(), + }, + Pageable: tc.pageData, + }) + + require.NoError(t, err) + + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generateMetaEvent(t *testing.T, db database.Database) *datastore.MetaEvent { + project := seedProject(t, db) + + return &datastore.MetaEvent{ + UID: ulid.Make().String(), + Status: datastore.ScheduledEventStatus, + EventType: string(datastore.EndpointCreated), + ProjectID: project.UID, + Metadata: &datastore.Metadata{ + Data: []byte(`{"name": "10x"}`), + Raw: `{"name": "10x"}`, + Strategy: datastore.ExponentialStrategyProvider, + NextSendTime: time.Now().Add(time.Hour), + NumTrials: 1, + IntervalSeconds: 10, + RetryLimit: 20, + }, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} diff --git a/database/sqlite3/organisation.go b/database/sqlite3/organisation.go new file mode 100644 index 0000000000..1d8cd8bbdc --- /dev/null +++ b/database/sqlite3/organisation.go @@ -0,0 +1,285 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/frain-dev/convoy/cache" + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" +) + +var ( + ErrOrganizationNotCreated = errors.New("organization could not be created") + ErrOrganizationNotUpdated = errors.New("organization could not be updated") + ErrOrganizationNotDeleted = errors.New("organization could not be deleted") +) + +const ( + createOrganization = ` + INSERT INTO organisations (id, name, owner_id, custom_domain, assigned_domain) + VALUES ($1, $2, $3, $4, $5); + ` + + fetchOrganisation = ` + SELECT * FROM organisations + WHERE deleted_at IS NULL + ` + + fetchOrganisationsPaged = ` + SELECT * FROM organisations WHERE deleted_at IS NULL + ` + + updateOrganizationById = ` + UPDATE organisations SET + name = $2, + custom_domain = $3, + assigned_domain = $4, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deleteOrganisation = ` + UPDATE organisations SET + deleted_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + baseFetchOrganizationsPagedForward = ` + %s + AND id <= :cursor + GROUP BY id + ORDER BY id DESC + LIMIT :limit + ` + + baseFetchOrganizationsPagedBackward = ` + WITH organizations AS ( + %s + AND id >= :cursor + GROUP BY id + ORDER BY id ASC + LIMIT :limit + ) + + SELECT * FROM organizations ORDER BY id DESC + ` + + countPrevOrganizations = ` + SELECT COUNT(DISTINCT(id)) AS count + FROM organisations + WHERE deleted_at IS NULL + AND id > :cursor + GROUP BY id + ORDER BY id DESC + LIMIT 1` + + countOrganizations = ` + SELECT COUNT(*) AS count + FROM organisations + WHERE deleted_at IS NULL` +) + +type orgRepo struct { + db *sqlx.DB + cache cache.Cache +} + +func NewOrgRepo(db database.Database) datastore.OrganisationRepository { + return &orgRepo{db: db.GetDB()} +} + +func (o *orgRepo) CreateOrganisation(ctx context.Context, org *datastore.Organisation) error { + result, err := o.db.ExecContext(ctx, createOrganization, org.UID, org.Name, org.OwnerID, org.CustomDomain, org.AssignedDomain) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationNotCreated + } + + return nil +} + +func (o *orgRepo) LoadOrganisationsPaged(ctx 
context.Context, pageable datastore.Pageable) ([]datastore.Organisation, datastore.PaginationData, error) { + var query string + if pageable.Direction == datastore.Next { + query = baseFetchOrganizationsPagedForward + } else { + query = baseFetchOrganizationsPagedBackward + } + + query = fmt.Sprintf(query, fetchOrganisationsPaged) + + arg := map[string]interface{}{ + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = o.db.Rebind(query) + + rows, err := o.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + organizations := make([]datastore.Organisation, 0) + for rows.Next() { + var org datastore.Organisation + + err = rows.StructScan(&org) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + organizations = append(organizations, org) + } + + var count datastore.PrevRowCount + if len(organizations) > 0 { + var countQuery string + var qargs []interface{} + + arg["cursor"] = organizations[0].UID + + countQuery, qargs, err = sqlx.Named(countPrevOrganizations, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = o.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := o.db.QueryxContext(ctx, countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(organizations)) + for i := range organizations { + ids[i] = organizations[i].UID + } + + if len(organizations) > pageable.PerPage { + organizations = organizations[:len(organizations)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return organizations, *pagination, nil +} + +func (o *orgRepo) UpdateOrganisation(ctx context.Context, org *datastore.Organisation) error { + result, err := o.db.ExecContext(ctx, updateOrganizationById, org.UID, org.Name, org.CustomDomain, org.AssignedDomain) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationNotUpdated + } + + return nil +} + +func (o *orgRepo) DeleteOrganisation(ctx context.Context, uid string) error { + result, err := o.db.Exec(deleteOrganisation, uid) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationNotDeleted + } + + return nil +} + +func (o *orgRepo) CountOrganisations(ctx context.Context) (int64, error) { + var count int64 + err := o.db.GetContext(ctx, &count, countOrganizations) + if err != nil { + return 0, err + } + + return count, nil +} + +func (o *orgRepo) FetchOrganisationByID(ctx context.Context, id string) (*datastore.Organisation, error) { + org := &datastore.Organisation{} + err := o.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND id = $1", fetchOrganisation), id).StructScan(org) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgNotFound + } + return nil, err + } + + return org, nil +} + +func (o *orgRepo) 
FetchOrganisationByAssignedDomain(ctx context.Context, domain string) (*datastore.Organisation, error) { + org := &datastore.Organisation{} + err := o.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND assigned_domain = $1", fetchOrganisation), domain).StructScan(org) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgNotFound + } + return nil, err + } + + return org, nil +} + +func (o *orgRepo) FetchOrganisationByCustomDomain(ctx context.Context, domain string) (*datastore.Organisation, error) { + org := &datastore.Organisation{} + err := o.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND custom_domain = $1", fetchOrganisation), domain).StructScan(org) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgNotFound + } + return nil, err + } + + return org, nil +} diff --git a/database/sqlite3/organisation_invite.go b/database/sqlite3/organisation_invite.go new file mode 100644 index 0000000000..548f4ccd47 --- /dev/null +++ b/database/sqlite3/organisation_invite.go @@ -0,0 +1,342 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/frain-dev/convoy/util" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" +) + +var ( + ErrOrganizationInviteNotCreated = errors.New("organization invite could not be created") + ErrOrganizationInviteNotUpdated = errors.New("organization invite could not be updated") + ErrOrganizationInviteNotDeleted = errors.New("organization invite could not be deleted") +) + +const ( + createOrganisationInvite = ` + INSERT INTO organisation_invites (id, organisation_id, invitee_email, token, role_type, role_project, role_endpoint, status, expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9); + ` + + updateOrganisationInvite = ` + UPDATE organisation_invites + SET + role_type = $2, + role_project = $3, + role_endpoint = $4, + status = $5, + expires_at = $6, + updated_at = NOW(), + deleted_at = $7 + WHERE id = $1 AND deleted_at IS NULL; + ` + + fetchOrganisationInviteById = ` + SELECT + id, + organisation_id, + invitee_email, + token, + status, + role_type AS "role.type", + COALESCE(role_project,'') AS "role.project", + COALESCE(role_endpoint,'') AS "role.endpoint", + created_at, updated_at, expires_at + FROM organisation_invites + WHERE id = $1 AND deleted_at IS NULL; + ` + + fetchOrganisationInviteByToken = ` + SELECT + id, + organisation_id, + invitee_email, + token, + status, + role_type AS "role.type", + COALESCE(role_project,'') AS "role.project", + COALESCE(role_endpoint,'') AS "role.endpoint", + created_at, updated_at, expires_at + FROM organisation_invites + WHERE token = $1 AND deleted_at IS NULL; + ` + + fetchOrganisationInvitesPaginated = ` + SELECT + id, + organisation_id, + invitee_email, + status, + role_type AS "role.type", + COALESCE(role_project,'') AS "role.project", + COALESCE(role_endpoint,'') AS "role.endpoint", + created_at, updated_at, expires_at + FROM organisation_invites + WHERE organisation_id = :org_id + AND status = :status + AND deleted_at IS NULL + ` + + baseFetchInvitesPagedForward = ` + %s + AND id <= :cursor + GROUP BY id + ORDER BY id DESC + LIMIT :limit + ` + + baseFetchInvitesPagedBackward = ` + WITH organisation_invites AS ( + %s + AND id >= :cursor + GROUP BY id + ORDER BY id ASC + LIMIT :limit + ) + + SELECT * FROM organisation_invites ORDER BY id DESC + ` + + countPrevOrganisationInvites = ` + SELECT COUNT(DISTINCT(id)) AS count + FROM organisation_invites + WHERE 
organisation_id = :org_id + AND deleted_at IS NULL + AND id > :cursor + GROUP BY id + ORDER BY id DESC + LIMIT 1 + ` + + deleteOrganisationInvite = ` + UPDATE organisation_invites SET + deleted_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` +) + +type orgInviteRepo struct { + db *sqlx.DB +} + +func NewOrgInviteRepo(db database.Database) datastore.OrganisationInviteRepository { + return &orgInviteRepo{db: db.GetDB()} +} + +func (i *orgInviteRepo) CreateOrganisationInvite(ctx context.Context, iv *datastore.OrganisationInvite) error { + var endpointID *string + var projectID *string + if !util.IsStringEmpty(iv.Role.Endpoint) { + endpointID = &iv.Role.Endpoint + } + + if !util.IsStringEmpty(iv.Role.Project) { + projectID = &iv.Role.Project + } + + r, err := i.db.ExecContext(ctx, createOrganisationInvite, + iv.UID, + iv.OrganisationID, + iv.InviteeEmail, + iv.Token, + iv.Role.Type, + projectID, + endpointID, + iv.Status, + iv.ExpiresAt, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationInviteNotCreated + } + + return nil +} + +func (i *orgInviteRepo) LoadOrganisationsInvitesPaged(ctx context.Context, orgID string, inviteStatus datastore.InviteStatus, pageable datastore.Pageable) ([]datastore.OrganisationInvite, datastore.PaginationData, error) { + arg := map[string]interface{}{ + "org_id": orgID, + "status": inviteStatus, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + var query string + if pageable.Direction == datastore.Next { + query = baseFetchInvitesPagedForward + } else { + query = baseFetchInvitesPagedBackward + } + + query = fmt.Sprintf(query, fetchOrganisationInvitesPaginated) + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = i.db.Rebind(query) + + rows, err := i.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var invites []datastore.OrganisationInvite + for rows.Next() { + var iv datastore.OrganisationInvite + + err = rows.StructScan(&iv) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + invites = append(invites, iv) + } + + var count datastore.PrevRowCount + if len(invites) > 0 { + var countQuery string + var qargs []interface{} + first := invites[0] + qarg := arg + qarg["cursor"] = first.UID + + countQuery, qargs, err = sqlx.Named(countPrevOrganisationInvites, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = i.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := i.db.QueryxContext(ctx, countQuery, qargs...) 
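+			// countPrevOrganisationInvites is capped at LIMIT 1, so the query yields
+			// at most one row; the StructScan below fills the prev-row count that
+			// feeds the pagination data built at the end of this method.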
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(invites)) + for i := range invites { + ids[i] = invites[i].UID + } + + if len(invites) > pageable.PerPage { + invites = invites[:len(invites)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return invites, *pagination, nil +} + +func (i *orgInviteRepo) UpdateOrganisationInvite(ctx context.Context, iv *datastore.OrganisationInvite) error { + var endpointID *string + var projectID *string + if !util.IsStringEmpty(iv.Role.Endpoint) { + endpointID = &iv.Role.Endpoint + } + + if !util.IsStringEmpty(iv.Role.Project) { + projectID = &iv.Role.Project + } + + r, err := i.db.ExecContext(ctx, + updateOrganisationInvite, + iv.UID, + iv.Role.Type, + projectID, + endpointID, + iv.Status, + iv.ExpiresAt, + iv.DeletedAt, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationInviteNotUpdated + } + + return nil +} + +func (i *orgInviteRepo) DeleteOrganisationInvite(ctx context.Context, id string) error { + r, err := i.db.ExecContext(ctx, deleteOrganisationInvite, id) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrOrganizationInviteNotDeleted + } + + return nil +} + +func (i *orgInviteRepo) FetchOrganisationInviteByID(ctx context.Context, id string) (*datastore.OrganisationInvite, error) { + invite := &datastore.OrganisationInvite{} + err := i.db.QueryRowxContext(ctx, fetchOrganisationInviteById, id).StructScan(invite) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgInviteNotFound + } + return nil, err + } + + return invite, nil +} + +func (i *orgInviteRepo) FetchOrganisationInviteByToken(ctx context.Context, token string) (*datastore.OrganisationInvite, error) { + invite := &datastore.OrganisationInvite{} + err := i.db.QueryRowxContext(ctx, fetchOrganisationInviteByToken, token).StructScan(invite) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgInviteNotFound + } + return nil, err + } + + return invite, nil +} diff --git a/database/sqlite3/organisation_invite_test.go b/database/sqlite3/organisation_invite_test.go new file mode 100644 index 0000000000..95392c9b6b --- /dev/null +++ b/database/sqlite3/organisation_invite_test.go @@ -0,0 +1,227 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "fmt" + "testing" + + "github.com/frain-dev/convoy/auth" + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func TestLoadOrganisationsInvitesPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + inviteRepo := NewOrgInviteRepo(db) + project := seedProject(t, db) + + uids := []string{} + for i := 1; i < 3; i++ { + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + Status: datastore.InviteStatusPending, + } + uids = append(uids, iv.UID) + err := 
inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + } + + for i := 1; i < 3; i++ { + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + Status: datastore.InviteStatusDeclined, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + } + + organisationInvites, _, err := inviteRepo.LoadOrganisationsInvitesPaged(context.Background(), org.UID, datastore.InviteStatusPending, datastore.Pageable{ + PerPage: 100, + }) + + require.NoError(t, err) + require.Equal(t, 2, len(organisationInvites)) + for _, invite := range organisationInvites { + require.Contains(t, uids, invite.UID) + } +} + +func TestCreateOrganisationInvite(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + + org := seedOrg(t, db) + inviteRepo := NewOrgInviteRepo(db) + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + Status: datastore.InviteStatusPending, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + invite, err := inviteRepo.FetchOrganisationInviteByID(context.Background(), iv.UID) + require.NoError(t, err) + + require.Equal(t, iv.UID, invite.UID) + require.Equal(t, iv.Token, invite.Token) +} + +func TestUpdateOrganisationInvite(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + project := seedProject(t, db) + + inviteRepo := NewOrgInviteRepo(db) + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + }, + Status: datastore.InviteStatusPending, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + role := auth.Role{ + Type: auth.RoleSuperUser, + Project: seedProject(t, db).UID, + Endpoint: "", + } + status := datastore.InviteStatusAccepted + + iv.Role = role + iv.Status = status + + err = inviteRepo.UpdateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + invite, err := inviteRepo.FetchOrganisationInviteByID(context.Background(), iv.UID) + require.NoError(t, err) + + require.Equal(t, invite.UID, iv.UID) + require.Equal(t, invite.Role, role) + require.Equal(t, invite.Status, status) +} + +func TestDeleteOrganisationInvite(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + project := seedProject(t, db) + + inviteRepo := NewOrgInviteRepo(db) + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + }, + Status: datastore.InviteStatusPending, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + err = inviteRepo.DeleteOrganisationInvite(context.Background(), iv.UID) + require.NoError(t, err) + + _, err = 
inviteRepo.FetchOrganisationInviteByID(context.Background(), iv.UID) + require.Equal(t, datastore.ErrOrgInviteNotFound, err) +} + +func TestFetchOrganisationInviteByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + + org := seedOrg(t, db) + inviteRepo := NewOrgInviteRepo(db) + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + }, + Status: datastore.InviteStatusPending, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + invite, err := inviteRepo.FetchOrganisationInviteByID(context.Background(), iv.UID) + require.NoError(t, err) + + require.Equal(t, iv.UID, invite.UID) + require.Equal(t, iv.Token, invite.Token) + require.Equal(t, iv.InviteeEmail, invite.InviteeEmail) +} + +func TestFetchOrganisationInviteByToken(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + project := seedProject(t, db) + + inviteRepo := NewOrgInviteRepo(db) + iv := &datastore.OrganisationInvite{ + UID: ulid.Make().String(), + InviteeEmail: fmt.Sprintf("%s@gmail.com", ulid.Make().String()), + Token: ulid.Make().String(), + OrganisationID: org.UID, + Role: auth.Role{ + Type: auth.RoleAdmin, + Project: project.UID, + }, + Status: datastore.InviteStatusPending, + } + + err := inviteRepo.CreateOrganisationInvite(context.Background(), iv) + require.NoError(t, err) + + invite, err := inviteRepo.FetchOrganisationInviteByToken(context.Background(), iv.Token) + require.NoError(t, err) + + require.Equal(t, iv.UID, invite.UID) + require.Equal(t, iv.Token, invite.Token) + require.Equal(t, iv.InviteeEmail, invite.InviteeEmail) +} diff --git a/database/sqlite3/organisation_member.go b/database/sqlite3/organisation_member.go new file mode 100644 index 0000000000..ccfd2abc07 --- /dev/null +++ b/database/sqlite3/organisation_member.go @@ -0,0 +1,499 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/frain-dev/convoy/util" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" +) + +var ( + ErrOrganizationMemberNotCreated = errors.New("organization member could not be created") + ErrOrganizationMemberNotUpdated = errors.New("organization member could not be updated") + ErrOrganizationMemberNotDeleted = errors.New("organization member could not be deleted") +) + +const ( + createOrgMember = ` + INSERT INTO organisation_members (id, organisation_id, user_id, role_type, role_project, role_endpoint) + VALUES ($1, $2, $3, $4, $5, $6); + ` + + updateOrgMember = ` + UPDATE organisation_members + SET + role_type = $2, + role_project = $3, + role_endpoint = $4, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deleteOrgMember = ` + UPDATE organisation_members SET + deleted_at = NOW() + WHERE id = $1 AND organisation_id = $2 AND deleted_at IS NULL; + ` + + fetchOrgMemberById = ` + SELECT + o.id AS id, + o.organisation_id AS "organisation_id", + o.role_type AS "role.type", + COALESCE(o.role_project,'') AS "role.project", + COALESCE(o.role_endpoint,'') AS "role.endpoint", + u.id AS "user_id", + u.id AS "user_metadata.user_id", + u.first_name AS "user_metadata.first_name", + u.last_name AS "user_metadata.last_name", + u.email AS "user_metadata.email" + FROM organisation_members o + 
LEFT JOIN users u + ON o.user_id = u.id + WHERE o.id = $1 AND o.organisation_id = $2 AND o.deleted_at IS NULL; + ` + + fetchOrgMemberByUserId = ` + SELECT + o.id AS id, + o.organisation_id AS "organisation_id", + o.role_type AS "role.type", + COALESCE(o.role_project,'') AS "role.project", + COALESCE(o.role_endpoint,'') AS "role.endpoint", + u.id AS "user_id", + u.id AS "user_metadata.user_id", + u.first_name AS "user_metadata.first_name", + u.last_name AS "user_metadata.last_name", + u.email AS "user_metadata.email" + FROM organisation_members o + LEFT JOIN users u + ON o.user_id = u.id + WHERE o.user_id = $1 AND o.organisation_id = $2 AND o.deleted_at IS NULL; + ` + + fetchOrganisationMembersPaged = ` + SELECT + o.id AS id, + o.organisation_id AS "organisation_id", + o.role_type AS "role.type", + COALESCE(o.role_project,'') AS "role.project", + COALESCE(o.role_endpoint,'') AS "role.endpoint", + u.id AS "user_id", + u.id AS "user_metadata.user_id", + u.first_name AS "user_metadata.first_name", + u.last_name AS "user_metadata.last_name", + u.email AS "user_metadata.email" + FROM organisation_members o + LEFT JOIN users u ON o.user_id = u.id + WHERE o.organisation_id = :organisation_id + AND (o.user_id = :user_id OR :user_id = '') + AND o.deleted_at IS NULL + ` + + baseFetchOrganisationMembersPagedForward = ` + %s + AND o.id <= :cursor + GROUP BY o.id, u.id + ORDER BY o.id DESC + LIMIT :limit + ` + + baseFetchOrganisationMembersPagedBackward = ` + WITH organisation_members AS ( + %s + AND o.id >= :cursor + GROUP BY o.id, u.id + ORDER BY o.id ASC + LIMIT :limit + ) + + SELECT * FROM organisation_members ORDER BY id DESC + ` + + countPrevOrganisationMembers = ` + SELECT COUNT(DISTINCT(o.id)) AS count + FROM organisation_members o + LEFT JOIN users u ON o.user_id = u.id + WHERE o.organisation_id = :organisation_id + AND o.deleted_at IS NULL + AND o.id > :cursor + GROUP BY o.id, u.id + ORDER BY o.id DESC + LIMIT 1` + + fetchOrgMemberOrganisations = ` + SELECT o.* FROM organisation_members m + JOIN organisations o ON m.organisation_id = o.id + WHERE m.user_id = :user_id + AND o.deleted_at IS NULL + AND m.deleted_at IS NULL + ` + + baseFetchUserOrganisationsPagedForward = ` + %s + AND o.id <= :cursor + GROUP BY o.id, m.id + ORDER BY o.id DESC + LIMIT :limit + ` + + baseFetchUserOrganisationsPagedBackward = ` + WITH user_organisations AS ( + %s + AND o.id >= :cursor + GROUP BY o.id, m.id + ORDER BY o.id ASC + LIMIT :limit + ) + + SELECT * FROM user_organisations ORDER BY id DESC + ` + + countPrevUserOrgs = ` + SELECT COUNT(DISTINCT(o.id)) AS count + FROM organisation_members m + JOIN organisations o ON m.organisation_id = o.id + WHERE m.user_id = :user_id + AND o.deleted_at IS NULL + AND m.deleted_at IS NULL + AND o.id > :cursor + GROUP BY o.id, m.id + ORDER BY o.id DESC + LIMIT 1` + + fetchUserProjects = ` + SELECT p.id, p.name, p.type, p.retained_events, p.logo_url, + p.organisation_id, p.project_configuration_id, p.created_at, + p.updated_at FROM organisation_members m + RIGHT JOIN projects p ON p.organisation_id = m.organisation_id + WHERE m.user_id = $1 AND m.deleted_at IS NULL AND p.deleted_at IS NULL + ` +) + +type orgMemberRepo struct { + db *sqlx.DB +} + +func NewOrgMemberRepo(db database.Database) datastore.OrganisationMemberRepository { + return &orgMemberRepo{db: db.GetDB()} +} + +func (o *orgMemberRepo) CreateOrganisationMember(ctx context.Context, member *datastore.OrganisationMember) error { + var endpointID *string + var projectID *string + if 
!util.IsStringEmpty(member.Role.Endpoint) { + endpointID = &member.Role.Endpoint + } + + if !util.IsStringEmpty(member.Role.Project) { + projectID = &member.Role.Project + } + + r, err := o.db.ExecContext(ctx, createOrgMember, + member.UID, + member.OrganisationID, + member.UserID, + member.Role.Type, + projectID, + endpointID, + ) + if err != nil { + return err + } + + nRows, err := r.RowsAffected() + if err != nil { + return err + } + + if nRows < 1 { + return ErrOrganizationMemberNotCreated + } + + return nil +} + +func (o *orgMemberRepo) LoadOrganisationMembersPaged(ctx context.Context, organisationID, userID string, pageable datastore.Pageable) ([]*datastore.OrganisationMember, datastore.PaginationData, error) { + var query string + if pageable.Direction == datastore.Next { + query = baseFetchOrganisationMembersPagedForward + } else { + query = baseFetchOrganisationMembersPagedBackward + } + + query = fmt.Sprintf(query, fetchOrganisationMembersPaged) + + arg := map[string]interface{}{ + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + "organisation_id": organisationID, + "user_id": userID, + } + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = o.db.Rebind(query) + + rows, err := o.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var members []*datastore.OrganisationMember + for rows.Next() { + var member datastore.OrganisationMember + + err = rows.StructScan(&member) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + members = append(members, &member) + } + + var count datastore.PrevRowCount + if len(members) > 0 { + var countQuery string + var qargs []interface{} + + arg["cursor"] = members[0].UID + + countQuery, qargs, err = sqlx.Named(countPrevOrganisationMembers, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = o.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := o.db.QueryxContext(ctx, countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(members)) + for i := range members { + ids[i] = members[i].UID + } + + if len(members) > pageable.PerPage { + members = members[:len(members)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return members, *pagination, nil +} + +func (o *orgMemberRepo) LoadUserOrganisationsPaged(ctx context.Context, userID string, pageable datastore.Pageable) ([]datastore.Organisation, datastore.PaginationData, error) { + var query string + if pageable.Direction == datastore.Next { + query = baseFetchUserOrganisationsPagedForward + } else { + query = baseFetchUserOrganisationsPagedBackward + } + + query = fmt.Sprintf(query, fetchOrgMemberOrganisations) + + arg := map[string]interface{}{ + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + "user_id": userID, + } + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + query, args, err = sqlx.In(query, args...) 
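+	// sqlx.Named has already rewritten the :named parameters into ? bindvars;
+	// sqlx.In expands any slice arguments, and the Rebind call below converts
+	// the ? placeholders into this driver's placeholder style.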
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = o.db.Rebind(query) + + rows, err := o.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + organisations := make([]datastore.Organisation, 0) + for rows.Next() { + var org datastore.Organisation + + err = rows.StructScan(&org) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + organisations = append(organisations, org) + } + + var count datastore.PrevRowCount + if len(organisations) > 0 { + var countQuery string + var qargs []interface{} + + arg["cursor"] = organisations[0].UID + + countQuery, qargs, err = sqlx.Named(countPrevUserOrgs, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = o.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := o.db.QueryxContext(ctx, countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(organisations)) + for i := range organisations { + ids[i] = organisations[i].UID + } + + if len(organisations) > pageable.PerPage { + organisations = organisations[:len(organisations)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return organisations, *pagination, nil +} + +func (o *orgMemberRepo) FindUserProjects(ctx context.Context, userID string) ([]datastore.Project, error) { + rows, err := o.db.QueryxContext(ctx, fetchUserProjects, userID) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + var projects []datastore.Project + for rows.Next() { + var proj datastore.Project + + err = rows.StructScan(&proj) + if err != nil { + return nil, err + } + + projects = append(projects, proj) + } + + return projects, nil +} + +func (o *orgMemberRepo) UpdateOrganisationMember(ctx context.Context, member *datastore.OrganisationMember) error { + var endpointID *string + var projectID *string + if !util.IsStringEmpty(member.Role.Endpoint) { + endpointID = &member.Role.Endpoint + } + + if !util.IsStringEmpty(member.Role.Project) { + projectID = &member.Role.Project + } + + r, err := o.db.ExecContext(ctx, + updateOrgMember, + member.UID, + member.Role.Type, + projectID, + endpointID, + ) + if err != nil { + return err + } + + nRows, err := r.RowsAffected() + if err != nil { + return err + } + + if nRows < 1 { + return ErrOrganizationMemberNotUpdated + } + + return nil +} + +func (o *orgMemberRepo) DeleteOrganisationMember(ctx context.Context, uid, orgID string) error { + r, err := o.db.ExecContext(ctx, deleteOrgMember, uid, orgID) + if err != nil { + return err + } + + nRows, err := r.RowsAffected() + if err != nil { + return err + } + + if nRows < 1 { + return ErrOrganizationMemberNotDeleted + } + + return nil +} + +func (o *orgMemberRepo) FetchOrganisationMemberByID(ctx context.Context, uid, orgID string) (*datastore.OrganisationMember, error) { + member := &datastore.OrganisationMember{} + err := o.db.QueryRowxContext(ctx, fetchOrgMemberById, uid, orgID).StructScan(member) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgMemberNotFound + } + return nil, err + } + + return member, nil +} + +func (o *orgMemberRepo) FetchOrganisationMemberByUserID(ctx context.Context, userID, orgID 
string) (*datastore.OrganisationMember, error) { + member := &datastore.OrganisationMember{} + err := o.db.QueryRowxContext(ctx, fetchOrgMemberByUserId, userID, orgID).StructScan(member) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrOrgMemberNotFound + } + return nil, err + } + + return member, nil +} diff --git a/database/sqlite3/organisation_member_test.go b/database/sqlite3/organisation_member_test.go new file mode 100644 index 0000000000..dd252ee74e --- /dev/null +++ b/database/sqlite3/organisation_member_test.go @@ -0,0 +1,323 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "testing" + + "github.com/frain-dev/convoy/auth" + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func TestLoadOrganisationMembersPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + organisationMemberRepo := NewOrgMemberRepo(db) + org := seedOrg(t, db) + project := seedProject(t, db) + + userMap := map[string]*datastore.UserMetadata{} + userRepo := NewUserRepo(db) + + for i := 1; i < 6; i++ { + user := generateUser(t) + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + + member := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + userMap[user.UID] = &datastore.UserMetadata{ + UserID: user.UID, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), member) + require.NoError(t, err) + } + + members, _, err := organisationMemberRepo.LoadOrganisationMembersPaged(context.Background(), org.UID, "", datastore.Pageable{ + PerPage: 2, + }) + + require.NoError(t, err) + require.Equal(t, 2, len(members)) + + for _, member := range members { + m := userMap[member.UserID] + require.Equal(t, *m, member.UserMetadata) + } +} + +func TestLoadUserOrganisationsPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + organisationMemberRepo := NewOrgMemberRepo(db) + orgRepo := NewOrgRepo(db) + project := seedProject(t, db) + + user := seedUser(t, db) + for i := 0; i < 7; i++ { + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + OwnerID: user.UID, + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + member := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err = organisationMemberRepo.CreateOrganisationMember(context.Background(), member) + require.NoError(t, err) + } + + organisations, _, err := organisationMemberRepo.LoadUserOrganisationsPaged(context.Background(), user.UID, datastore.Pageable{ + PerPage: 10, + }) + + require.NoError(t, err) + require.Equal(t, 7, len(organisations)) +} + +func TestCreateOrganisationMember(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := generateUser(t) + require.NoError(t, NewUserRepo(db).CreateUser(context.Background(), user)) + org := seedOrg(t, db) + project := seedProject(t, db) + + organisationMemberRepo := NewOrgMemberRepo(db) + + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := 
organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + member, err := organisationMemberRepo.FetchOrganisationMemberByID(context.Background(), m.UID, m.OrganisationID) + require.NoError(t, err) + + require.Equal(t, m.UID, member.UID) + require.Equal(t, m.OrganisationID, member.OrganisationID) + require.Equal(t, m.UserID, member.UserID) + require.Equal(t, datastore.UserMetadata{ + UserID: user.UID, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + }, member.UserMetadata) +} + +func TestUpdateOrganisationMember(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := generateUser(t) + org := seedOrg(t, db) + require.NoError(t, NewUserRepo(db).CreateUser(context.Background(), user)) + project := seedProject(t, db) + + organisationMemberRepo := NewOrgMemberRepo(db) + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + role := auth.Role{ + Type: auth.RoleSuperUser, + Project: project.UID, + Endpoint: "", + } + m.Role = role + + err = organisationMemberRepo.UpdateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + member, err := organisationMemberRepo.FetchOrganisationMemberByID(context.Background(), m.UID, m.OrganisationID) + require.NoError(t, err) + + require.Equal(t, m.UID, member.UID) + require.Equal(t, role, member.Role) + require.Equal(t, datastore.UserMetadata{ + UserID: user.UID, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + }, member.UserMetadata) +} + +func TestDeleteOrganisationMember(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + organisationMemberRepo := NewOrgMemberRepo(db) + org := seedOrg(t, db) + project := seedProject(t, db) + + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: org.OwnerID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + err = organisationMemberRepo.DeleteOrganisationMember(context.Background(), m.UID, m.OrganisationID) + require.NoError(t, err) + + _, err = organisationMemberRepo.FetchOrganisationMemberByID(context.Background(), m.UID, m.OrganisationID) + require.Equal(t, datastore.ErrOrgMemberNotFound, err) +} + +func TestFetchOrganisationMemberByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := generateUser(t) + require.NoError(t, NewUserRepo(db).CreateUser(context.Background(), user)) + + org := seedOrg(t, db) + project := seedProject(t, db) + organisationMemberRepo := NewOrgMemberRepo(db) + + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + member, err := organisationMemberRepo.FetchOrganisationMemberByID(context.Background(), m.UID, m.OrganisationID) + require.NoError(t, err) + + require.Equal(t, m.UID, member.UID) + require.Equal(t, m.OrganisationID, member.OrganisationID) + require.Equal(t, m.UserID, member.UserID) + require.Equal(t, datastore.UserMetadata{ + UserID: user.UID, + FirstName: user.FirstName, + 
LastName: user.LastName, + Email: user.Email, + }, member.UserMetadata) +} + +func TestFetchOrganisationMemberByUserID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := generateUser(t) + require.NoError(t, NewUserRepo(db).CreateUser(context.Background(), user)) + + org := seedOrg(t, db) + project := seedProject(t, db) + + organisationMemberRepo := NewOrgMemberRepo(db) + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + member, err := organisationMemberRepo.FetchOrganisationMemberByUserID(context.Background(), m.UserID, m.OrganisationID) + require.NoError(t, err) + + require.Equal(t, m.UID, member.UID) + require.Equal(t, m.OrganisationID, member.OrganisationID) + require.Equal(t, m.UserID, member.UserID) + require.Equal(t, datastore.UserMetadata{ + UserID: user.UID, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + }, member.UserMetadata) +} + +func TestFetchUserProjects(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := generateUser(t) + ctx := context.Background() + + require.NoError(t, NewUserRepo(db).CreateUser(ctx, user)) + + org := seedOrg(t, db) + project := seedProject(t, db) + + organisationMemberRepo := NewOrgMemberRepo(db) + projectRepo := NewProjectRepo(db) + m := &datastore.OrganisationMember{ + UID: ulid.Make().String(), + OrganisationID: org.UID, + UserID: user.UID, + Role: auth.Role{Type: auth.RoleAdmin, Project: project.UID}, + } + + err := organisationMemberRepo.CreateOrganisationMember(context.Background(), m) + require.NoError(t, err) + + project1 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "project1", + Config: &datastore.DefaultProjectConfig, + OrganisationID: org.UID, + } + + project2 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "project2", + Config: &datastore.DefaultProjectConfig, + OrganisationID: org.UID, + } + + err = projectRepo.CreateProject(context.Background(), project1) + require.NoError(t, err) + + err = projectRepo.CreateProject(context.Background(), project2) + require.NoError(t, err) + + projects, err := organisationMemberRepo.FindUserProjects(ctx, user.UID) + require.NoError(t, err) + + require.Equal(t, 2, len(projects)) +} diff --git a/database/sqlite3/organisation_test.go b/database/sqlite3/organisation_test.go new file mode 100644 index 0000000000..59922c5b59 --- /dev/null +++ b/database/sqlite3/organisation_test.go @@ -0,0 +1,249 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "gopkg.in/guregu/null.v4" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/stretchr/testify/require" +) + +func TestLoadOrganisationsPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + orgRepo := NewOrgRepo(db) + + user := seedUser(t, db) + + for i := 1; i < 6; i++ { + org := &datastore.Organisation{ + UID: ulid.Make().String(), + OwnerID: user.UID, + Name: fmt.Sprintf("org%d", i), + CustomDomain: null.NewString(ulid.Make().String(), true), + AssignedDomain: null.NewString(ulid.Make().String(), true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + } + + organisations, _, err := 
orgRepo.LoadOrganisationsPaged(context.Background(), datastore.Pageable{ + PerPage: 2, + }) + + require.NoError(t, err) + require.Equal(t, 2, len(organisations)) +} + +func TestCountOrganisations(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + orgRepo := NewOrgRepo(db) + + user := seedUser(t, db) + count := 10 + for i := 0; i < count; i++ { + org := &datastore.Organisation{ + UID: ulid.Make().String(), + OwnerID: user.UID, + Name: fmt.Sprintf("org%d", i), + CustomDomain: null.NewString(ulid.Make().String(), true), + AssignedDomain: null.NewString(ulid.Make().String(), true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + } + + orgCount, err := orgRepo.CountOrganisations(context.Background()) + + require.NoError(t, err) + require.Equal(t, int64(count), orgCount) +} + +func TestCreateOrganisation(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := seedUser(t, db) + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + Name: "new org", + OwnerID: user.UID, + CustomDomain: null.NewString("https://google.com", true), + AssignedDomain: null.NewString("https://google.com", true), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err := NewOrgRepo(db).CreateOrganisation(context.Background(), org) + require.NoError(t, err) +} + +func TestUpdateOrganisation(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + orgRepo := NewOrgRepo(db) + + user := seedUser(t, db) + + org := &datastore.Organisation{ + Name: "new org", + OwnerID: user.UID, + CustomDomain: null.NewString(ulid.Make().String(), true), + AssignedDomain: null.NewString(ulid.Make().String(), true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + name := "organisation update" + org.Name = name + newDomain := null.NewString("https://yt.com", true) + + org.CustomDomain = newDomain + org.AssignedDomain = newDomain + + err = orgRepo.UpdateOrganisation(context.Background(), org) + require.NoError(t, err) + + dbOrg, err := orgRepo.FetchOrganisationByID(context.Background(), org.UID) + require.NoError(t, err) + + require.Equal(t, name, dbOrg.Name) + require.Equal(t, newDomain, dbOrg.CustomDomain) + require.Equal(t, newDomain, dbOrg.AssignedDomain) +} + +func TestFetchOrganisationByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := seedUser(t, db) + + orgRepo := NewOrgRepo(db) + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + Name: "new org", + OwnerID: user.UID, + CustomDomain: null.NewString("https://google.com", true), + AssignedDomain: null.NewString("https://google.com", true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + dbOrg, err := orgRepo.FetchOrganisationByID(context.Background(), org.UID) + require.NoError(t, err) + require.NotEmpty(t, dbOrg.CreatedAt) + require.NotEmpty(t, dbOrg.UpdatedAt) + + dbOrg.CreatedAt = time.Time{} + dbOrg.UpdatedAt = time.Time{} + + require.Equal(t, org, dbOrg) +} + +func TestFetchOrganisationByAssignedDomain(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := seedUser(t, db) + + orgRepo := NewOrgRepo(db) + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + Name: "new org", + OwnerID: user.UID, + CustomDomain: null.NewString("https://yt.com", true), + AssignedDomain: null.NewString("https://google.com", true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + dbOrg, err := 
orgRepo.FetchOrganisationByAssignedDomain(context.Background(), "https://google.com") + require.NoError(t, err) + require.NotEmpty(t, dbOrg.CreatedAt) + require.NotEmpty(t, dbOrg.UpdatedAt) + + dbOrg.CreatedAt = time.Time{} + dbOrg.UpdatedAt = time.Time{} + + require.Equal(t, org, dbOrg) +} + +func TestFetchOrganisationByCustomDomain(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + user := seedUser(t, db) + + orgRepo := NewOrgRepo(db) + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + Name: "new org", + OwnerID: user.UID, + CustomDomain: null.NewString("https://yt.com", true), + AssignedDomain: null.NewString("https://google.com", true), + } + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + dbOrg, err := orgRepo.FetchOrganisationByCustomDomain(context.Background(), "https://yt.com") + require.NoError(t, err) + require.NotEmpty(t, dbOrg.CreatedAt) + require.NotEmpty(t, dbOrg.UpdatedAt) + + dbOrg.CreatedAt = time.Time{} + dbOrg.UpdatedAt = time.Time{} + + require.Equal(t, org, dbOrg) +} + +func TestDeleteOrganisation(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + orgRepo := NewOrgRepo(db) + user := seedUser(t, db) + + org := &datastore.Organisation{Name: "new org", OwnerID: user.UID} + + err := orgRepo.CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + err = orgRepo.DeleteOrganisation(context.Background(), org.UID) + require.NoError(t, err) + + _, err = orgRepo.FetchOrganisationByID(context.Background(), org.UID) + require.Equal(t, datastore.ErrOrgNotFound, err) +} + +func seedUser(t *testing.T, db database.Database) *datastore.User { + user := generateUser(t) + + err := NewUserRepo(db).CreateUser(context.Background(), user) + require.NoError(t, err) + + return user +} diff --git a/database/sqlite3/portal_link.go b/database/sqlite3/portal_link.go new file mode 100644 index 0000000000..3d87f1e8c5 --- /dev/null +++ b/database/sqlite3/portal_link.go @@ -0,0 +1,493 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" +) + +var ( + ErrPortalLinkNotCreated = errors.New("portal link could not be created") + ErrPortalLinkNotUpdated = errors.New("portal link could not be updated") + ErrPortalLinkNotDeleted = errors.New("portal link could not be deleted") +) + +const ( + createPortalLink = ` + INSERT INTO portal_links (id, project_id, name, token, endpoints, owner_id, can_manage_endpoint) + VALUES ($1, $2, $3, $4, $5, $6, $7); + ` + + createPortalLinkEndpoints = ` + INSERT INTO portal_links_endpoints (portal_link_id, endpoint_id) VALUES (:portal_link_id, :endpoint_id) + ` + + updatePortalLink = ` + UPDATE portal_links + SET + name = $2, + endpoints = $3, + owner_id = $4, + can_manage_endpoint = $5, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deletePortalLinkEndpoints = ` + DELETE FROM portal_links_endpoints + WHERE portal_link_id = $1 OR endpoint_id = $2 + ` + + fetchPortalLinkById = ` + SELECT + p.id, + p.project_id, + p.name, + p.token, + p.endpoints, + COALESCE(p.can_manage_endpoint, FALSE) AS "can_manage_endpoint", + COALESCE(p.owner_id, '') AS "owner_id", + CASE + WHEN p.owner_id != '' THEN (SELECT count(id) FROM endpoints WHERE owner_id = p.owner_id) + ELSE (SELECT count(portal_link_id) FROM portal_links_endpoints WHERE portal_link_id = p.id) + END AS endpoint_count, + 
p.created_at, + p.updated_at, + ARRAY_TO_JSON(ARRAY_AGG(DISTINCT CASE WHEN e.id IS NOT NULL THEN cast(JSON_BUILD_OBJECT('uid', e.id, 'name', e.name, 'project_id', e.project_id, 'url', e.url, 'secrets', e.secrets) as jsonb) END)) AS endpoints_metadata + FROM portal_links p + LEFT JOIN portal_links_endpoints pe + ON p.id = pe.portal_link_id + LEFT JOIN endpoints e + ON e.id = pe.endpoint_id + WHERE p.id = $1 AND p.project_id = $2 AND p.deleted_at IS NULL + GROUP BY p.id + ` + + fetchPortalLinkByOwnerID = ` + SELECT + p.id, + p.project_id, + p.name, + p.token, + p.endpoints, + COALESCE(p.can_manage_endpoint, FALSE) AS "can_manage_endpoint", + COALESCE(p.owner_id, '') AS "owner_id", + CASE + WHEN p.owner_id != '' THEN (SELECT count(id) FROM endpoints WHERE owner_id = p.owner_id) + ELSE (SELECT count(portal_link_id) FROM portal_links_endpoints WHERE portal_link_id = p.id) + END AS endpoint_count, + p.created_at, + p.updated_at, + ARRAY_TO_JSON(ARRAY_AGG(DISTINCT CASE WHEN e.id IS NOT NULL THEN cast(JSON_BUILD_OBJECT('uid', e.id, 'name', e.name, 'project_id', e.project_id, 'url', e.url, 'secrets', e.secrets) as jsonb) END)) AS endpoints_metadata + FROM portal_links p + LEFT JOIN portal_links_endpoints pe + ON p.id = pe.portal_link_id + LEFT JOIN endpoints e + ON e.id = pe.endpoint_id + WHERE p.owner_id = $1 AND p.project_id = $2 AND p.deleted_at IS NULL + GROUP BY p.id + ` + + fetchPortalLinkByToken = ` + SELECT + p.id, + p.project_id, + p.name, + p.token, + p.endpoints, + COALESCE(p.can_manage_endpoint, FALSE) AS "can_manage_endpoint", + COALESCE(p.owner_id, '') AS "owner_id", + CASE + WHEN p.owner_id != '' THEN (SELECT count(id) FROM endpoints WHERE owner_id = p.owner_id) + ELSE (SELECT count(portal_link_id) FROM portal_links_endpoints WHERE portal_link_id = p.id) + END AS endpoint_count, + p.created_at, + p.updated_at, + ARRAY_TO_JSON(ARRAY_AGG(DISTINCT CASE WHEN e.id IS NOT NULL THEN cast(JSON_BUILD_OBJECT('uid', e.id, 'name', e.name, 'project_id', e.project_id, 'url', e.url, 'secrets', e.secrets) as jsonb) END)) AS endpoints_metadata + FROM portal_links p + LEFT JOIN portal_links_endpoints pe + ON p.id = pe.portal_link_id + LEFT JOIN endpoints e + ON e.id = pe.endpoint_id + WHERE p.token = $1 AND p.deleted_at IS NULL + GROUP BY p.id + ` + + countPrevPortalLinks = ` + SELECT COUNT(DISTINCT(p.id)) AS count + FROM portal_links p + LEFT JOIN portal_links_endpoints pe + ON p.id = pe.portal_link_id + LEFT JOIN endpoints e + ON e.id = pe.endpoint_id + WHERE p.deleted_at IS NULL + %s + AND p.id > :cursor GROUP BY p.id ORDER BY p.id DESC LIMIT 1` + + fetchPortalLinksPaginated = ` + SELECT + p.id, + p.project_id, + p.name, + p.token, + p.endpoints, + COALESCE(p.can_manage_endpoint, FALSE) AS "can_manage_endpoint", + COALESCE(p.owner_id, '') AS "owner_id", + CASE + WHEN p.owner_id != '' THEN (SELECT count(id) FROM endpoints WHERE owner_id = p.owner_id) + ELSE (SELECT count(portal_link_id) FROM portal_links_endpoints WHERE portal_link_id = p.id) + END AS endpoint_count, + p.created_at, + p.updated_at, + ARRAY_TO_JSON(ARRAY_AGG(DISTINCT CASE WHEN e.id IS NOT NULL THEN cast(JSON_BUILD_OBJECT('uid', e.id, 'name', e.name, 'project_id', e.project_id, 'url', e.url, 'secrets', e.secrets) as jsonb) END)) AS endpoints_metadata + FROM portal_links p + LEFT JOIN portal_links_endpoints pe + ON p.id = pe.portal_link_id + LEFT JOIN endpoints e + ON e.id = pe.endpoint_id + WHERE p.deleted_at IS NULL` + + baseFetchPortalLinksPagedForward = ` + %s + %s + AND p.id <= :cursor + GROUP BY p.id + ORDER BY p.id DESC + LIMIT 
:limit + ` + + baseFetchPortalLinksPagedBackward = ` + WITH portal_links AS ( + %s + %s + AND p.id >= :cursor + GROUP BY p.id + ORDER BY p.id ASC + LIMIT :limit + ) + + SELECT * FROM portal_links ORDER BY id DESC + ` + + basePortalLinkFilter = ` + AND (p.project_id = :project_id OR :project_id = '')` + + deletePortalLink = ` + UPDATE portal_links SET + deleted_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` +) + +type portalLinkRepo struct { + db *sqlx.DB +} + +func NewPortalLinkRepo(db database.Database) datastore.PortalLinkRepository { + return &portalLinkRepo{db: db.GetDB()} +} + +func (p *portalLinkRepo) CreatePortalLink(ctx context.Context, portal *datastore.PortalLink) error { + tx, err := p.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + r, err := tx.ExecContext(ctx, createPortalLink, + portal.UID, + portal.ProjectID, + portal.Name, + portal.Token, + portal.Endpoints, + portal.OwnerID, + portal.CanManageEndpoint, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrPortalLinkNotCreated + } + + err = p.upsertPortalLinkEndpoint(ctx, tx, portal) + if err != nil { + return err + } + + return tx.Commit() +} + +func (p *portalLinkRepo) UpdatePortalLink(ctx context.Context, projectID string, portal *datastore.PortalLink) error { + tx, err := p.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + r, err := tx.ExecContext(ctx, updatePortalLink, + portal.UID, + portal.Name, + portal.Endpoints, + portal.OwnerID, + portal.CanManageEndpoint, + ) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrPortalLinkNotUpdated + } + + err = p.upsertPortalLinkEndpoint(ctx, tx, portal) + if err != nil { + return err + } + + return tx.Commit() +} + +func (p *portalLinkRepo) FindPortalLinkByID(ctx context.Context, projectID string, id string) (*datastore.PortalLink, error) { + portalLink := &datastore.PortalLink{} + err := p.db.QueryRowxContext(ctx, fetchPortalLinkById, id, projectID).StructScan(portalLink) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrPortalLinkNotFound + } + return nil, err + } + + return portalLink, nil +} + +func (p *portalLinkRepo) FindPortalLinkByOwnerID(ctx context.Context, projectID string, ownerID string) (*datastore.PortalLink, error) { + portalLink := &datastore.PortalLink{} + err := p.db.QueryRowxContext(ctx, fetchPortalLinkByOwnerID, ownerID, projectID).StructScan(portalLink) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrPortalLinkNotFound + } + return nil, err + } + + return portalLink, nil +} + +func (p *portalLinkRepo) FindPortalLinkByToken(ctx context.Context, token string) (*datastore.PortalLink, error) { + portalLink := &datastore.PortalLink{} + err := p.db.QueryRowxContext(ctx, fetchPortalLinkByToken, token).StructScan(portalLink) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrPortalLinkNotFound + } + return nil, err + } + + return portalLink, nil +} + +func (p *portalLinkRepo) LoadPortalLinksPaged(ctx context.Context, projectID string, filter *datastore.FilterBy, pageable datastore.Pageable) ([]datastore.PortalLink, datastore.PaginationData, error) { + var err error + var args []interface{} + var query, filterQuery string + + if 
!util.IsStringEmpty(filter.EndpointID) { + filter.EndpointIDs = append(filter.EndpointIDs, filter.EndpointID) + } + + arg := map[string]interface{}{ + "project_id": projectID, + "endpoint_ids": filter.EndpointIDs, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + if pageable.Direction == datastore.Next { + query = baseFetchPortalLinksPagedForward + } else { + query = baseFetchPortalLinksPagedBackward + } + + filterQuery = basePortalLinkFilter + if len(filter.EndpointIDs) > 0 { + filterQuery += ` AND pe.endpoint_id IN (:endpoint_ids)` + } + + query = fmt.Sprintf(query, fetchPortalLinksPaginated, filterQuery) + query, args, err = sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = p.db.Rebind(query) + + rows, err := p.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + var portalLinks []datastore.PortalLink + + for rows.Next() { + var link datastore.PortalLink + + err = rows.StructScan(&link) + if err != nil { + return nil, datastore.PaginationData{}, err + } + portalLinks = append(portalLinks, link) + } + + var count datastore.PrevRowCount + if len(portalLinks) > 0 { + var countQuery string + var qargs []interface{} + first := portalLinks[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevPortalLinks, filterQuery) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery, qargs, err = sqlx.In(countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = p.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := p.db.QueryxContext(ctx, countQuery, qargs...) 
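+ // The prev-row query is LIMIT 1; a scanned count greater than zero means a previous page exists.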
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(portalLinks)) + for i := range portalLinks { + ids[i] = portalLinks[i].UID + } + + if len(portalLinks) > pageable.PerPage { + portalLinks = portalLinks[:len(portalLinks)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return portalLinks, *pagination, nil +} + +func (p *portalLinkRepo) RevokePortalLink(ctx context.Context, projectID string, id string) error { + r, err := p.db.ExecContext(ctx, deletePortalLink, id, projectID) + if err != nil { + return err + } + + rowsAffected, err := r.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrPortalLinkNotDeleted + } + + return nil +} + +func (p *portalLinkRepo) upsertPortalLinkEndpoint(ctx context.Context, tx *sqlx.Tx, portal *datastore.PortalLink) error { + var ids []interface{} + + if len(portal.Endpoints) > 0 { + for _, endpointID := range portal.Endpoints { + ids = append(ids, &PortalLinkEndpoint{PortalLinkID: portal.UID, EndpointID: endpointID}) + } + } else if !util.IsStringEmpty(portal.OwnerID) { + rows, err := p.db.QueryxContext(ctx, fetchEndpointsByOwnerId, portal.ProjectID, portal.OwnerID) + if err != nil { + return err + } + defer closeWithError(rows) + + for rows.Next() { + var endpoint datastore.Endpoint + err := rows.StructScan(&endpoint) + if err != nil { + return err + } + + ids = append(ids, &PortalLinkEndpoint{PortalLinkID: portal.UID, EndpointID: endpoint.UID}) + } + + if len(ids) == 0 { + return nil + } + } else { + return errors.New("owner_id or endpoints must be present") + } + + _, err := tx.ExecContext(ctx, deletePortalLinkEndpoints, portal.UID, nil) + if err != nil { + return err + } + + _, err = tx.NamedExecContext(ctx, createPortalLinkEndpoints, ids) + if err != nil { + return err + } + + return nil +} + +type PortalLinkEndpoint struct { + PortalLinkID string `db:"portal_link_id"` + EndpointID string `db:"endpoint_id"` +} + +type PortalLinkPaginated struct { + Count int `db:"count"` + Endpoint struct { + UID string `db:"id"` + Title string `db:"title"` + ProjectID string `db:"project_id"` + SupportEmail string `db:"support_email"` + TargetUrl string `db:"target_url"` + } `db:"endpoint"` + datastore.PortalLink +} diff --git a/database/sqlite3/portal_link_test.go b/database/sqlite3/portal_link_test.go new file mode 100644 index 0000000000..668e42e9d6 --- /dev/null +++ b/database/sqlite3/portal_link_test.go @@ -0,0 +1,240 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "errors" + "fmt" + "math" + "testing" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreatePortalLink(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + portalLinkRepo := NewPortalLinkRepo(db) + portalLink := generatePortalLink(t, db) + + require.NoError(t, portalLinkRepo.CreatePortalLink(context.Background(), portalLink)) + + newPortalLink, err := portalLinkRepo.FindPortalLinkByID(context.Background(), portalLink.ProjectID, portalLink.UID) + require.NoError(t, err) + + newPortalLink.CreatedAt = time.Time{} + newPortalLink.UpdatedAt = time.Time{} + + require.Equal(t, portalLink.Name, newPortalLink.Name) + 
require.Equal(t, portalLink.Token, newPortalLink.Token) + require.Equal(t, portalLink.ProjectID, newPortalLink.ProjectID) +} + +func Test_FindPortalLinkByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + portalLinkRepo := NewPortalLinkRepo(db) + portalLink := generatePortalLink(t, db) + ctx := context.Background() + + _, err := portalLinkRepo.FindPortalLinkByID(ctx, portalLink.ProjectID, portalLink.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrPortalLinkNotFound)) + + require.NoError(t, portalLinkRepo.CreatePortalLink(ctx, portalLink)) + + newPortalLink, err := portalLinkRepo.FindPortalLinkByID(ctx, portalLink.ProjectID, portalLink.UID) + require.NoError(t, err) + + newPortalLink.CreatedAt = time.Time{} + newPortalLink.UpdatedAt = time.Time{} + + require.Equal(t, portalLink.Name, newPortalLink.Name) + require.Equal(t, portalLink.Token, newPortalLink.Token) + require.Equal(t, portalLink.ProjectID, newPortalLink.ProjectID) +} + +func Test_FindPortalLinkByToken(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + portalLinkRepo := NewPortalLinkRepo(db) + portalLink := generatePortalLink(t, db) + ctx := context.Background() + + _, err := portalLinkRepo.FindPortalLinkByToken(ctx, portalLink.Token) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrPortalLinkNotFound)) + + require.NoError(t, portalLinkRepo.CreatePortalLink(ctx, portalLink)) + + newPortalLink, err := portalLinkRepo.FindPortalLinkByToken(ctx, portalLink.Token) + require.NoError(t, err) + + newPortalLink.CreatedAt = time.Time{} + newPortalLink.UpdatedAt = time.Time{} + + require.Equal(t, portalLink.Name, newPortalLink.Name) + require.Equal(t, portalLink.Token, newPortalLink.Token) + require.Equal(t, portalLink.ProjectID, newPortalLink.ProjectID) +} + +func Test_UpdatePortalLink(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + portalLinkRepo := NewPortalLinkRepo(db) + projectRepo := NewProjectRepo(db) + endpointRepo := NewEndpointRepo(db) + + portalLink := generatePortalLink(t, db) + ctx := context.Background() + + project, err := projectRepo.FetchProjectByID(ctx, portalLink.ProjectID) + require.NoError(t, err) + + require.NoError(t, portalLinkRepo.CreatePortalLink(ctx, portalLink)) + + portalLink.Name = "Updated-Test-Portal-Token" + endpoint := generateEndpoint(project) + + err = endpointRepo.CreateEndpoint(ctx, endpoint, project.UID) + require.NoError(t, err) + + portalLink.Endpoints = []string{endpoint.UID} + require.NoError(t, portalLinkRepo.UpdatePortalLink(ctx, portalLink.ProjectID, portalLink)) + + newPortalLink, err := portalLinkRepo.FindPortalLinkByID(ctx, portalLink.ProjectID, portalLink.UID) + require.NoError(t, err) + + total, _, err := portalLinkRepo.LoadPortalLinksPaged(ctx, project.UID, &datastore.FilterBy{EndpointIDs: []string{endpoint.UID}}, datastore.Pageable{PerPage: 10, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}) + require.NoError(t, err) + + require.Equal(t, 1, len(total)) + require.Equal(t, endpoint.UID, total[0].EndpointsMetadata[0].UID) + + newPortalLink.CreatedAt = time.Time{} + newPortalLink.UpdatedAt = time.Time{} + + require.Equal(t, portalLink.Name, newPortalLink.Name) + require.Equal(t, portalLink.Token, newPortalLink.Token) + require.Equal(t, portalLink.ProjectID, newPortalLink.ProjectID) +} + +func Test_RevokePortalLink(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + portalLinkRepo := NewPortalLinkRepo(db) + portalLink := generatePortalLink(t, db) + ctx := 
context.Background() + + require.NoError(t, portalLinkRepo.CreatePortalLink(ctx, portalLink)) + + _, err := portalLinkRepo.FindPortalLinkByID(ctx, portalLink.ProjectID, portalLink.UID) + require.NoError(t, err) + + require.NoError(t, portalLinkRepo.RevokePortalLink(ctx, portalLink.ProjectID, portalLink.UID)) + + _, err = portalLinkRepo.FindPortalLinkByID(ctx, portalLink.ProjectID, portalLink.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrPortalLinkNotFound)) +} + +func Test_LoadPortalLinksPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + expected Expected + }{ + { + name: "Load Portal Links Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Portal Links Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Portal Links Paged - 5 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + project := seedProject(t, db) + endpoint := generateEndpoint(project) + portalLinkRepo := NewPortalLinkRepo(db) + require.NoError(t, NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID)) + + for i := 0; i < tc.count; i++ { + portalLink := &datastore.PortalLink{ + UID: ulid.Make().String(), + ProjectID: project.UID, + Name: "Test-Portal-Link", + Token: ulid.Make().String(), + Endpoints: []string{endpoint.UID}, + } + require.NoError(t, portalLinkRepo.CreatePortalLink(context.Background(), portalLink)) + } + + _, pageable, err := portalLinkRepo.LoadPortalLinksPaged(context.Background(), project.UID, &datastore.FilterBy{EndpointID: endpoint.UID}, tc.pageData) + + require.NoError(t, err) + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generatePortalLink(t *testing.T, db database.Database) *datastore.PortalLink { + project := seedProject(t, db) + + endpoint := generateEndpoint(project) + err := NewEndpointRepo(db).CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + return &datastore.PortalLink{ + UID: ulid.Make().String(), + ProjectID: project.UID, + Name: "Test-Portal-Link", + Token: ulid.Make().String(), + Endpoints: []string{endpoint.UID}, + } +} diff --git a/database/sqlite3/project.go b/database/sqlite3/project.go new file mode 100644 index 0000000000..90194dd107 --- /dev/null +++ b/database/sqlite3/project.go @@ -0,0 +1,509 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/frain-dev/convoy/database/hooks" + "github.com/r3labs/diff/v3" + + "github.com/jmoiron/sqlx" + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" +) + +var ( + ErrProjectConfigNotCreated = errors.New("project config could not be created") + ErrProjectConfigNotUpdated = errors.New("project config could not be updated") + ErrProjectNotCreated = errors.New("project could not be created") + ErrProjectNotUpdated = errors.New("project could not be updated") +) + +const ( + createProject = ` + 
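-- NOTE: RETURNING requires SQLite 3.35+; the returned id is discarded by tx.ExecContext below. + 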
INSERT INTO projects (id, name, type, logo_url, organisation_id, project_configuration_id) + VALUES ($1, $2, $3, $4, $5, $6) RETURNING id; + ` + + createProjectConfiguration = ` + INSERT INTO project_configurations ( + id, search_policy, + max_payload_read_size, replay_attacks_prevention_enabled, + ratelimit_count, + ratelimit_duration, strategy_type, + strategy_duration, strategy_retry_count, + signature_header, signature_versions, disable_endpoint, + meta_events_enabled, meta_events_type, meta_events_event_type, + meta_events_url, meta_events_secret, meta_events_pub_sub,ssl_enforce_secure_endpoints, + multiple_endpoint_subscriptions + ) + VALUES + ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, + $14, $15, $16, $17, $18, $19, $20 + ); + ` + + updateProjectConfiguration = ` + UPDATE project_configurations SET + max_payload_read_size = $2, + replay_attacks_prevention_enabled = $3, + ratelimit_count = $4, + ratelimit_duration = $5, + strategy_type = $6, + strategy_duration = $7, + strategy_retry_count = $8, + signature_header = $9, + signature_versions = $10, + disable_endpoint = $11, + meta_events_enabled = $12, + meta_events_type = $13, + meta_events_event_type = $14, + meta_events_url = $15, + meta_events_secret = $16, + meta_events_pub_sub = $17, + search_policy = $18, + ssl_enforce_secure_endpoints = $19, + multiple_endpoint_subscriptions = $20, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + fetchProjectById = ` + SELECT + p.id, + p.name, + p.type, + p.retained_events, + p.logo_url, + p.organisation_id, + p.project_configuration_id, + c.search_policy AS "config.search_policy", + c.max_payload_read_size AS "config.max_payload_read_size", + c.multiple_endpoint_subscriptions AS "config.multiple_endpoint_subscriptions", + c.replay_attacks_prevention_enabled AS "config.replay_attacks_prevention_enabled", + c.ratelimit_count AS "config.ratelimit.count", + c.ratelimit_duration AS "config.ratelimit.duration", + c.strategy_type AS "config.strategy.type", + c.strategy_duration AS "config.strategy.duration", + c.strategy_retry_count AS "config.strategy.retry_count", + c.signature_header AS "config.signature.header", + c.signature_versions AS "config.signature.versions", + c.disable_endpoint AS "config.disable_endpoint", + c.ssl_enforce_secure_endpoints as "config.ssl.enforce_secure_endpoints", + c.meta_events_enabled AS "config.meta_event.is_enabled", + COALESCE(c.meta_events_type, '') AS "config.meta_event.type", + c.meta_events_event_type AS "config.meta_event.event_type", + COALESCE(c.meta_events_url, '') AS "config.meta_event.url", + COALESCE(c.meta_events_secret, '') AS "config.meta_event.secret", + c.meta_events_pub_sub AS "config.meta_event.pub_sub", + p.created_at, + p.updated_at, + p.deleted_at + FROM projects p + LEFT JOIN project_configurations c + ON p.project_configuration_id = c.id + WHERE p.id = $1 AND p.deleted_at IS NULL; +` + fetchProjects = ` + SELECT + p.id, + p.name, + p.type, + p.retained_events, + p.logo_url, + p.organisation_id, + p.project_configuration_id, + c.search_policy AS "config.search_policy", + c.max_payload_read_size AS "config.max_payload_read_size", + c.multiple_endpoint_subscriptions AS "config.multiple_endpoint_subscriptions", + c.replay_attacks_prevention_enabled AS "config.replay_attacks_prevention_enabled", + c.ratelimit_count AS "config.ratelimit.count", + c.ratelimit_duration AS "config.ratelimit.duration", + c.strategy_type AS "config.strategy.type", + c.strategy_duration AS "config.strategy.duration", + 
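-- NOTE: the "config.*" column aliases map onto the project's nested Config struct when sqlx scans the row. + 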
c.ssl_enforce_secure_endpoints as "config.ssl.enforce_secure_endpoints", + c.strategy_retry_count AS "config.strategy.retry_count", + c.signature_header AS "config.signature.header", + c.signature_versions AS "config.signature.versions", + c.meta_events_enabled AS "config.meta_event.is_enabled", + COALESCE(c.meta_events_type, '') AS "config.meta_event.type", + c.meta_events_event_type AS "config.meta_event.event_type", + COALESCE(c.meta_events_url, '') AS "config.meta_event.url", + COALESCE(c.meta_events_secret, '') AS "config.meta_event.secret", + c.meta_events_pub_sub AS "config.meta_event.pub_sub", + p.created_at, + p.updated_at, + p.deleted_at + FROM projects p + LEFT JOIN project_configurations c + ON p.project_configuration_id = c.id + WHERE (p.organisation_id = $1 OR $1 = '') AND p.deleted_at IS NULL ORDER BY p.id; + ` + + updateProjectById = ` + UPDATE projects SET + name = $2, + logo_url = $3, + retained_events = $4, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deleteProject = ` + UPDATE projects SET + deleted_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deleteProjectEndpoints = ` + UPDATE endpoints SET + deleted_at = NOW() + WHERE project_id = $1 AND deleted_at IS NULL; + ` + + deleteProjectEvents = ` + UPDATE events + SET deleted_at = NOW() + WHERE project_id = $1 AND deleted_at IS NULL; + ` + deleteProjectEndpointSubscriptions = ` + UPDATE subscriptions SET + deleted_at = NOW() + WHERE project_id = $1 AND deleted_at IS NULL; + ` + + projectStatistics = ` + SELECT + (SELECT COUNT(*) FROM subscriptions WHERE project_id = $1 AND deleted_at IS NULL) AS total_subscriptions, + (SELECT COUNT(*) FROM endpoints WHERE project_id = $1 AND deleted_at IS NULL) AS total_endpoints, + (SELECT COUNT(*) FROM sources WHERE project_id = $1 AND deleted_at IS NULL) AS total_sources, + (SELECT COUNT(*) FROM events WHERE project_id = $1 AND deleted_at IS NULL) AS messages_sent; + ` + + updateProjectEndpointStatus = ` + UPDATE endpoints SET status = ?, updated_at = NOW() + WHERE project_id = ? AND status IN (?) 
AND deleted_at IS NULL RETURNING + id, name, status, owner_id, url, + description, http_timeout, rate_limit, rate_limit_duration, + advanced_signatures, slack_webhook_url, support_email, + app_id, project_id, secrets, created_at, updated_at, + authentication_type AS "authentication.type", + authentication_type_api_key_header_name AS "authentication.api_key.header_name", + authentication_type_api_key_header_value AS "authentication.api_key.header_value"; + ` + + getProjectsWithEventsInTheInterval = ` + SELECT p.id AS id, COUNT(e.id) AS events_count + FROM projects p + LEFT JOIN events e ON p.id = e.project_id + WHERE e.created_at >= NOW() - MAKE_INTERVAL(hours := $1) + AND p.deleted_at IS NULL + GROUP BY p.id + ORDER BY events_count DESC; + ` + + countProjects = ` + SELECT COUNT(*) AS count + FROM projects + WHERE deleted_at IS NULL` +) + +type projectRepo struct { + db *sqlx.DB + hook *hooks.Hook +} + +func NewProjectRepo(db database.Database) datastore.ProjectRepository { + return &projectRepo{db: db.GetDB(), hook: db.GetHook()} +} + +func (o *projectRepo) CountProjects(ctx context.Context) (int64, error) { + var count int64 + err := o.db.GetContext(ctx, &count, countProjects) + if err != nil { + return 0, err + } + + return count, nil +} + +func (p *projectRepo) CreateProject(ctx context.Context, project *datastore.Project) error { + tx, err := p.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + rlc := project.Config.GetRateLimitConfig() + sc := project.Config.GetStrategyConfig() + sgc := project.Config.GetSignatureConfig() + me := project.Config.GetMetaEventConfig() + + configID := ulid.Make().String() + result, err := tx.ExecContext(ctx, createProjectConfiguration, + configID, + project.Config.SearchPolicy, + project.Config.MaxIngestSize, + project.Config.ReplayAttacks, + rlc.Count, + rlc.Duration, + sc.Type, + sc.Duration, + sc.RetryCount, + sgc.Header, + sgc.Versions, + project.Config.DisableEndpoint, + me.IsEnabled, + me.Type, + me.EventType, + me.URL, + me.Secret, + me.PubSub, + project.Config.SSL.EnforceSecureEndpoints, + project.Config.MultipleEndpointSubscriptions, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrProjectConfigNotCreated + } + + project.ProjectConfigID = configID + proResult, err := tx.ExecContext(ctx, createProject, project.UID, project.Name, project.Type, project.LogoURL, project.OrganisationID, project.ProjectConfigID) + if err != nil { + if strings.Contains(err.Error(), "duplicate") { + return datastore.ErrDuplicateProjectName + } + return err + } + + rowsAffected, err = proResult.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrProjectNotCreated + } + + return tx.Commit() +} + +func (p *projectRepo) LoadProjects(ctx context.Context, f *datastore.ProjectFilter) ([]*datastore.Project, error) { + rows, err := p.db.QueryxContext(ctx, fetchProjects, f.OrgID) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + projects := make([]*datastore.Project, 0) + for rows.Next() { + var proj datastore.Project + + err = rows.StructScan(&proj) + if err != nil { + return nil, err + } + + projects = append(projects, &proj) + } + + return projects, nil +} + +func (p *projectRepo) UpdateProject(ctx context.Context, project *datastore.Project) error { + pro, err := p.FetchProjectByID(ctx, project.UID) + if err != nil { + return err + } + + changelog, err := 
diff.Diff(pro, project) + if err != nil { + return err + } + + tx, err := p.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + pRes, err := tx.ExecContext(ctx, updateProjectById, project.UID, project.Name, project.LogoURL, project.RetainedEvents) + if err != nil { + return fmt.Errorf("update project err: %v", err) + } + + rowsAffected, err := pRes.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrProjectNotUpdated + } + + rlc := project.Config.GetRateLimitConfig() + sc := project.Config.GetStrategyConfig() + sgc := project.Config.GetSignatureConfig() + ssl := project.Config.GetSSLConfig() + me := project.Config.GetMetaEventConfig() + + cRes, err := tx.ExecContext(ctx, updateProjectConfiguration, + project.ProjectConfigID, + project.Config.MaxIngestSize, + project.Config.ReplayAttacks, + rlc.Count, + rlc.Duration, + sc.Type, + sc.Duration, + sc.RetryCount, + sgc.Header, + sgc.Versions, + project.Config.DisableEndpoint, + me.IsEnabled, + me.Type, + me.EventType, + me.URL, + me.Secret, + me.PubSub, + project.Config.SearchPolicy, + ssl.EnforceSecureEndpoints, + project.Config.MultipleEndpointSubscriptions, + ) + if err != nil { + return fmt.Errorf("update project config err: %v", err) + } + + rowsAffected, err = cRes.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrProjectConfigNotUpdated + } + + if !project.Config.DisableEndpoint { + status := []datastore.EndpointStatus{datastore.InactiveEndpointStatus, datastore.PendingEndpointStatus} + query, args, err := sqlx.In(updateProjectEndpointStatus, datastore.ActiveEndpointStatus, project.UID, status) + if err != nil { + return err + } + + query = p.db.Rebind(query) + rows, err := p.db.QueryxContext(ctx, query, args...) 
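+ // NOTE: this reactivation query runs on p.db while tx is still open; SQLite allows only one writer at a time, so tx.QueryxContext may be required here.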
+ if err != nil { + return err + } + defer closeWithError(rows) + + for rows.Next() { + var endpoint datastore.Endpoint + err := rows.StructScan(&endpoint) + if err != nil { + return err + } + } + } + + err = tx.Commit() + if err != nil { + return err + } + + go p.hook.Fire(datastore.ProjectUpdated, project, changelog) + return nil +} + +func (p *projectRepo) FetchProjectByID(ctx context.Context, id string) (*datastore.Project, error) { + project := &datastore.Project{} + err := p.db.GetContext(ctx, project, fetchProjectById, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrProjectNotFound + } + return nil, err + } + + return project, nil +} + +func (p *projectRepo) FillProjectsStatistics(ctx context.Context, project *datastore.Project) error { + var stats datastore.ProjectStatistics + err := p.db.GetContext(ctx, &stats, projectStatistics, project.UID) + if err != nil { + return err + } + + project.Statistics = &stats + return nil +} + +func (p *projectRepo) DeleteProject(ctx context.Context, id string) error { + tx, err := p.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + _, err = tx.ExecContext(ctx, deleteProject, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteProjectEndpoints, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteProjectEvents, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteProjectEndpointSubscriptions, id) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (p *projectRepo) GetProjectsWithEventsInTheInterval(ctx context.Context, interval int) ([]datastore.ProjectEvents, error) { + var projects []datastore.ProjectEvents + rows, err := p.db.QueryxContext(ctx, getProjectsWithEventsInTheInterval, interval) + if err != nil { + return nil, err + } + defer closeWithError(rows) + + for rows.Next() { + var proj datastore.ProjectEvents + + err = rows.StructScan(&proj) + if err != nil { + return nil, err + } + + projects = append(projects, proj) + } + + return projects, nil +} diff --git a/database/sqlite3/project_test.go b/database/sqlite3/project_test.go new file mode 100644 index 0000000000..ebd3a764f5 --- /dev/null +++ b/database/sqlite3/project_test.go @@ -0,0 +1,500 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/dchest/uniuri" + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/pkg/httpheader" + "github.com/stretchr/testify/require" + "gopkg.in/guregu/null.v4" +) + +func Test_FetchProjectByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + projectRepo := NewProjectRepo(db) + + newProject := &datastore.Project{ + UID: ulid.Make().String(), + Name: "Yet another project", + LogoURL: "s3.com/dsiuirueiy", + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + require.NoError(t, projectRepo.CreateProject(context.Background(), newProject)) + + dbProject, err := projectRepo.FetchProjectByID(context.Background(), newProject.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbProject.CreatedAt) + require.NotEmpty(t, dbProject.UpdatedAt) + + dbProject.CreatedAt = time.Time{} + dbProject.UpdatedAt = time.Time{} + for i := range 
dbProject.Config.Signature.Versions { + version := &dbProject.Config.Signature.Versions[i] + require.NotEmpty(t, version.CreatedAt) + version.CreatedAt = time.Time{} + } + + for i := range newProject.Config.Signature.Versions { + version := &newProject.Config.Signature.Versions[i] + require.NotEmpty(t, version.CreatedAt) + version.CreatedAt = time.Time{} + } + + require.Equal(t, newProject, dbProject) +} + +func TestCountProjects(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + projectRepository := NewProjectRepo(db) + org := seedOrg(t, db) + + count := 10 + for i := 0; i < count; i++ { + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: ulid.Make().String(), + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + err := projectRepository.CreateProject(context.Background(), project) + require.NoError(t, err) + } + + projectCount, err := projectRepository.CountProjects(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(count), projectCount) +} + +func Test_CreateProject(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + projectRepo := NewProjectRepo(db) + + org := seedOrg(t, db) + + const name = "test_project" + + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: name, + OrganisationID: org.UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + err := projectRepo.CreateProject(context.Background(), project) + require.NoError(t, err) + require.NotEmpty(t, project.ProjectConfigID) + + projectWithExistingName := &datastore.Project{ + UID: ulid.Make().String(), + Name: name, + OrganisationID: org.UID, + Config: &datastore.DefaultProjectConfig, + } + + // should not create project with same name + err = projectRepo.CreateProject(context.Background(), projectWithExistingName) + require.Equal(t, datastore.ErrDuplicateProjectName, err) + + // delete the existing project + err = projectRepo.DeleteProject(context.Background(), project.UID) + require.NoError(t, err) + + // can now create project with same name + err = projectRepo.CreateProject(context.Background(), projectWithExistingName) + require.NoError(t, err) + + projectInDiffOrg := &datastore.Project{ + UID: ulid.Make().String(), + Name: name, + OrganisationID: seedOrg(t, db).UID, + Config: &datastore.DefaultProjectConfig, + } + + // should create project with same name in diff org + err = projectRepo.CreateProject(context.Background(), projectInDiffOrg) + require.NoError(t, err) +} + +func Test_UpdateProject(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + projectRepo := NewProjectRepo(db) + + org := seedOrg(t, db) + + const name = "test_project" + + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: name, + OrganisationID: org.UID, + Config: &datastore.DefaultProjectConfig, + } + + err := projectRepo.CreateProject(context.Background(), project) + require.NoError(t, err) + + updatedProject := &datastore.Project{ + UID: project.UID, + Name: "convoy", + LogoURL: "https://oilvmm.com", + OrganisationID: project.OrganisationID, + ProjectConfigID: project.ProjectConfigID, // TODO(all): if I comment this line, this test never exits; weird problem + Config: &datastore.ProjectConfig{ + MaxIngestSize: 8483, + ReplayAttacks: true, + RateLimit: &datastore.RateLimitConfiguration{ + Count: 8773, + Duration: 7766, + }, + SSL: &datastore.SSLConfiguration{EnforceSecureEndpoints: false}, + Strategy: &datastore.StrategyConfiguration{ + Type: 
datastore.ExponentialStrategyProvider, + Duration: 2434, + RetryCount: 5737, + }, + Signature: &datastore.SignatureConfiguration{ + Header: "f888fbfb", + Versions: []datastore.SignatureVersion{ + { + UID: ulid.Make().String(), + Hash: "SHA512", + Encoding: datastore.HexEncoding, + CreatedAt: time.Now(), + }, + }, + }, + MetaEvent: &datastore.MetaEventConfiguration{ + IsEnabled: false, + }, + }, + RetainedEvents: 300, + } + + err = projectRepo.UpdateProject(context.Background(), updatedProject) + require.NoError(t, err) + + dbProject, err := projectRepo.FetchProjectByID(context.Background(), project.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbProject.CreatedAt) + require.NotEmpty(t, dbProject.UpdatedAt) + + dbProject.CreatedAt = time.Time{} + dbProject.UpdatedAt = time.Time{} + + for i := range dbProject.Config.Signature.Versions { + version := &dbProject.Config.Signature.Versions[i] + require.NotEmpty(t, version.CreatedAt) + version.CreatedAt = time.Time{} + } + + for i := range updatedProject.Config.Signature.Versions { + version := &updatedProject.Config.Signature.Versions[i] + require.NotEmpty(t, version.CreatedAt) + version.CreatedAt = time.Time{} + } + + require.Equal(t, updatedProject, dbProject) +} + +func Test_LoadProjects(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + projectRepo := NewProjectRepo(db) + + for i := 0; i < 3; i++ { + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: fmt.Sprintf("%s-project", ulid.Make().String()), + OrganisationID: org.UID, + Config: &datastore.DefaultProjectConfig, + } + + err := projectRepo.CreateProject(context.Background(), project) + require.NoError(t, err) + } + + for i := 0; i < 4; i++ { + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: fmt.Sprintf("%s-project", ulid.Make().String()), + OrganisationID: seedOrg(t, db).UID, + Config: &datastore.DefaultProjectConfig, + } + + err := projectRepo.CreateProject(context.Background(), project) + require.NoError(t, err) + } + + projects, err := projectRepo.LoadProjects(context.Background(), &datastore.ProjectFilter{OrgID: org.UID}) + require.NoError(t, err) + + require.True(t, len(projects) == 3) +} + +func Test_FillProjectStatistics(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + projectRepo := NewProjectRepo(db) + + project1 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "project1", + Config: &datastore.DefaultProjectConfig, + OrganisationID: org.UID, + } + + project2 := &datastore.Project{ + UID: ulid.Make().String(), + Name: "project2", + Config: &datastore.DefaultProjectConfig, + OrganisationID: org.UID, + } + + err := projectRepo.CreateProject(context.Background(), project1) + require.NoError(t, err) + + err = projectRepo.CreateProject(context.Background(), project2) + require.NoError(t, err) + + endpoint1 := &datastore.Endpoint{ + ProjectID: project1.UID, + Url: "http://google.com", + Name: "test_endpoint", + Secrets: []datastore.Secret{ + { + Value: "12345", + ExpiresAt: null.Time{}, + }, + }, + HttpTimeout: 60, + RateLimit: 3000, + Events: 0, + Status: "", + RateLimitDuration: 10, + Authentication: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + DeletedAt: null.Time{}, + } + + endpoint2 := &datastore.Endpoint{ + UID: ulid.Make().String(), + ProjectID: project2.UID, + Secrets: datastore.Secrets{ + {UID: ulid.Make().String()}, + }, + } + + endpointRepo := NewEndpointRepo(db) + err = endpointRepo.CreateEndpoint(context.Background(), endpoint1, 
project1.UID) + require.NoError(t, err) + + err = endpointRepo.CreateEndpoint(context.Background(), endpoint2, project2.UID) + require.NoError(t, err) + + source1 := &datastore.Source{ + UID: ulid.Make().String(), + ProjectID: project1.UID, + Name: "Convoy-Prod", + MaskID: uniuri.NewLen(16), + Type: datastore.HTTPSource, + Verifier: &datastore.VerifierConfig{}, + } + + err = NewSourceRepo(db).CreateSource(context.Background(), source1) + require.NoError(t, err) + + subscription := &datastore.Subscription{ + UID: ulid.Make().String(), + Name: "Subscription", + Type: datastore.SubscriptionTypeAPI, + ProjectID: project2.UID, + EndpointID: endpoint1.UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"some.event"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + } + + err = NewSubscriptionRepo(db).CreateSubscription(context.Background(), project2.UID, subscription) + require.NoError(t, err) + + err = projectRepo.FillProjectsStatistics(context.Background(), project1) + require.NoError(t, err) + + require.Equal(t, datastore.ProjectStatistics{ + MessagesSent: 0, + TotalEndpoints: 1, + TotalSources: 1, + TotalSubscriptions: 0, + }, *project1.Statistics) + + err = projectRepo.FillProjectsStatistics(context.Background(), project2) + require.NoError(t, err) + + require.Equal(t, datastore.ProjectStatistics{ + MessagesSent: 0, + TotalEndpoints: 1, + TotalSources: 0, + TotalSubscriptions: 1, + }, *project2.Statistics) +} + +func Test_DeleteProject(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + org := seedOrg(t, db) + projectRepo := NewProjectRepo(db) + + project := &datastore.Project{ + UID: ulid.Make().String(), + Name: "project", + Config: &datastore.DefaultProjectConfig, + OrganisationID: org.UID, + } + + err := projectRepo.CreateProject(context.Background(), project) + require.NoError(t, err) + + endpoint := &datastore.Endpoint{ + ProjectID: project.UID, + Url: "http://google.com", + Name: "test_endpoint", + Secrets: []datastore.Secret{ + { + Value: "12345", + ExpiresAt: null.Time{}, + }, + }, + HttpTimeout: 60, + RateLimit: 3000, + } + + endpointRepo := NewEndpointRepo(db) + err = endpointRepo.CreateEndpoint(context.Background(), endpoint, project.UID) + require.NoError(t, err) + + event := &datastore.Event{ + ProjectID: endpoint.ProjectID, + Endpoints: []string{endpoint.UID}, + Headers: httpheader.HTTPHeader{}, + Data: json.RawMessage(`{ + "userId": 1, + "id": 1, + "title": "delectus aut autem", + "completed": false + }`), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err = NewEventRepo(db).CreateEvent(context.Background(), event) + require.NoError(t, err) + + sub := &datastore.Subscription{ + Name: "test_sub", + Type: datastore.SubscriptionTypeAPI, + ProjectID: project.UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"*"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + RateLimitConfig: &datastore.DefaultRateLimitConfig, + } + + err = NewSubscriptionRepo(db).CreateSubscription(context.Background(), project.UID, sub) + require.NoError(t, err) + + err = projectRepo.DeleteProject(context.Background(), project.UID) + require.NoError(t, err) + + _, err = projectRepo.FetchProjectByID(context.Background(), project.UID) + require.Equal(t, 
datastore.ErrProjectNotFound, err) + + _, err = NewEventRepo(db).FindEventByID(context.Background(), event.ProjectID, event.UID) + require.Equal(t, datastore.ErrEventNotFound, err) + + _, err = NewEndpointRepo(db).FindEndpointByID(context.Background(), project.UID, endpoint.UID) + require.Equal(t, datastore.ErrEndpointNotFound, err) + + _, err = NewSubscriptionRepo(db).FindSubscriptionByID(context.Background(), project.UID, sub.UID) + require.Equal(t, datastore.ErrSubscriptionNotFound, err) +} + +func seedOrg(t *testing.T, db database.Database) *datastore.Organisation { + user := seedUser(t, db) + + org := &datastore.Organisation{ + UID: ulid.Make().String(), + Name: ulid.Make().String() + "-new_org", + OwnerID: user.UID, + CustomDomain: null.NewString(ulid.Make().String(), true), + AssignedDomain: null.NewString(ulid.Make().String(), true), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err := NewOrgRepo(db).CreateOrganisation(context.Background(), org) + require.NoError(t, err) + + return org +} + +func seedProject(t *testing.T, db database.Database) *datastore.Project { + p := &datastore.Project{ + UID: ulid.Make().String(), + Name: "Yet another project", + LogoURL: "s3.com/dsiuirueiy", + OrganisationID: seedOrg(t, db).UID, + Type: datastore.IncomingProject, + Config: &datastore.DefaultProjectConfig, + } + + err := NewProjectRepo(db).CreateProject(context.Background(), p) + require.NoError(t, err) + + return p +} diff --git a/database/sqlite3/source.go b/database/sqlite3/source.go new file mode 100644 index 0000000000..2635270f81 --- /dev/null +++ b/database/sqlite3/source.go @@ -0,0 +1,559 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/lib/pq" + + "github.com/oklog/ulid/v2" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" +) + +const ( + createSource = ` + INSERT INTO sources (id,source_verifier_id,name,type,mask_id,provider,is_disabled,forward_headers,project_id, + pub_sub,custom_response_body,custom_response_content_type,idempotency_keys, body_function, header_function) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15); + ` + + createSourceVerifier = ` + INSERT INTO source_verifiers ( + id,type,basic_username,basic_password, + api_key_header_name,api_key_header_value, + hmac_hash,hmac_header,hmac_secret,hmac_encoding + ) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10); + ` + + updateSourceById = ` + UPDATE sources SET + name= $2, + type=$3, + mask_id=$4, + provider = $5, + is_disabled=$6, + forward_headers=$7, + project_id =$8, + pub_sub= $9, + custom_response_body = $10, + custom_response_content_type = $11, + idempotency_keys = $12, + body_function = $13, + header_function = $14, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL ; + ` + + updateSourceVerifierById = ` + UPDATE source_verifiers SET + type=$2, + basic_username=$3, + basic_password=$4, + api_key_header_name=$5, + api_key_header_value=$6, + hmac_hash=$7, + hmac_header=$8, + hmac_secret=$9, + hmac_encoding=$10, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + baseFetchSource = ` + SELECT + s.id, + s.name, + s.type, + s.pub_sub, + s.mask_id, + s.provider, + s.is_disabled, + s.forward_headers, + s.idempotency_keys, + s.project_id, + s.body_function, + s.header_function, + COALESCE(s.source_verifier_id, '') AS source_verifier_id, + COALESCE(s.custom_response_body, '') AS "custom_response.body", + 
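-- COALESCE defaults keep StructScan happy when the LEFT JOIN on source_verifiers finds no row. + 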
COALESCE(s.custom_response_content_type, '') AS "custom_response.content_type", + COALESCE(sv.type, '') AS "verifier.type", + COALESCE(sv.basic_username, '') AS "verifier.basic_auth.username", + COALESCE(sv.basic_password, '') AS "verifier.basic_auth.password", + COALESCE(sv.api_key_header_name, '') AS "verifier.api_key.header_name", + COALESCE(sv.api_key_header_value, '') AS "verifier.api_key.header_value", + COALESCE(sv.hmac_hash, '') AS "verifier.hmac.hash", + COALESCE(sv.hmac_header, '') AS "verifier.hmac.header", + COALESCE(sv.hmac_secret, '') AS "verifier.hmac.secret", + COALESCE(sv.hmac_encoding, '') AS "verifier.hmac.encoding", + s.created_at, + s.updated_at + FROM sources AS s + LEFT JOIN source_verifiers sv ON s.source_verifier_id = sv.id + WHERE s.deleted_at IS NULL + ` + + fetchPubSubSources = ` + SELECT + id, + name, + type, + pub_sub, + mask_id, + provider, + is_disabled, + forward_headers, + idempotency_keys, + body_function, + header_function, + project_id, + created_at, + updated_at + FROM sources + WHERE type = '%s' AND project_id IN (:project_ids) AND deleted_at IS NULL + AND (id <= :cursor OR :cursor = '') + ORDER BY id DESC + LIMIT :limit + ` + + deleteSource = ` + UPDATE sources SET + deleted_at = NOW() + WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + deleteSourceVerifier = ` + UPDATE source_verifiers SET + deleted_at = NOW() + WHERE id = $1 AND deleted_at IS NULL; + ` + + deleteSourceSubscription = ` + UPDATE subscriptions SET + deleted_at = NOW() + WHERE source_id = $1 AND project_id = $2 AND deleted_at IS NULL; + ` + + fetchSourcesPagedFilter = ` + AND (s.type = :type OR :type = '') + AND (s.provider = :provider OR :provider = '') + AND s.name ILIKE :query + AND s.project_id = :project_id + ` + + fetchSourcesPagedForward = ` + %s + %s + AND s.id <= :cursor + GROUP BY s.id, sv.id + ORDER BY s.id DESC + LIMIT :limit + ` + + fetchSourcesPagedBackward = ` + WITH sources AS ( + %s + %s + AND s.id >= :cursor + GROUP BY s.id, sv.id + ORDER BY s.id ASC + LIMIT :limit + ) + + SELECT * FROM sources ORDER BY id DESC + ` + + countPrevSources = ` + SELECT COUNT(DISTINCT(s.id)) AS count + FROM sources s + WHERE s.deleted_at IS NULL + %s + AND s.id > :cursor GROUP BY s.id ORDER BY s.id DESC LIMIT 1` +) + +var ( + fetchSource = baseFetchSource + ` AND %s = $1;` + fetchSourceByName = baseFetchSource + ` AND %s = $1 AND %s = $2;` +) + +var ( + ErrSourceNotCreated = errors.New("source could not be created") + ErrSourceVerifierNotCreated = errors.New("source verifier could not be created") + ErrSourceVerifierNotUpdated = errors.New("source verifier could not be updated") + ErrSourceNotUpdated = errors.New("source could not be updated") +) + +type sourceRepo struct { + db *sqlx.DB +} + +func NewSourceRepo(db database.Database) datastore.SourceRepository { + return &sourceRepo{db: db.GetDB()} +} + +func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source) error { + var sourceVerifierID *string + tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + var ( + hmac datastore.HMac + basic datastore.BasicAuth + apiKey datastore.ApiKey + ) + + switch source.Verifier.Type { + case datastore.APIKeyVerifier: + apiKey = *source.Verifier.ApiKey + case datastore.BasicAuthVerifier: + basic = *source.Verifier.BasicAuth + case datastore.HMacVerifier: + hmac = *source.Verifier.HMac + } + + if !util.IsStringEmpty(string(source.Verifier.Type)) { + id := ulid.Make().String() + sourceVerifierID = &id + + 
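// Create the verifier row first; the source row will reference it via source_verifier_id. + 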
result2, err := tx.ExecContext( + ctx, createSourceVerifier, sourceVerifierID, source.Verifier.Type, basic.UserName, basic.Password, + apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, + ) + if err != nil { + return err + } + + rowsAffected, err := result2.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSourceVerifierNotCreated + } + } + + if !util.IsStringEmpty(string(source.Verifier.Type)) { + source.VerifierID = *sourceVerifierID + } + + result1, err := tx.ExecContext( + ctx, createSource, source.UID, sourceVerifierID, source.Name, source.Type, source.MaskID, + source.Provider, source.IsDisabled, pq.Array(source.ForwardHeaders), source.ProjectID, + source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType, + source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, + ) + if err != nil { + return err + } + + rowsAffected, err := result1.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSourceNotCreated + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source *datastore.Source) error { + tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + result, err := tx.ExecContext( + ctx, updateSourceById, source.UID, source.Name, source.Type, source.MaskID, + source.Provider, source.IsDisabled, source.ForwardHeaders, projectID, + source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType, + source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected < 1 { + return ErrSourceNotUpdated + } + + var ( + hmac datastore.HMac + basic datastore.BasicAuth + apiKey datastore.ApiKey + ) + + switch source.Verifier.Type { + case datastore.APIKeyVerifier: + apiKey = *source.Verifier.ApiKey + case datastore.BasicAuthVerifier: + basic = *source.Verifier.BasicAuth + case datastore.HMacVerifier: + hmac = *source.Verifier.HMac + } + + if !util.IsStringEmpty(string(source.Verifier.Type)) { + result2, err := tx.ExecContext( + ctx, updateSourceVerifierById, source.VerifierID, source.Verifier.Type, basic.UserName, basic.Password, + apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, + ) + if err != nil { + return err + } + + rowsAffected, err = result2.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSourceVerifierNotUpdated + } + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (s *sourceRepo) FindSourceByID(ctx context.Context, projectId string, id string) (*datastore.Source, error) { + source := &datastore.Source{} + err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.id"), id).StructScan(source) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSourceNotFound + } + return nil, err + } + + return source, nil +} + +func (s *sourceRepo) FindSourceByName(ctx context.Context, projectID string, name string) (*datastore.Source, error) { + source := &datastore.Source{} + err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSourceByName, "s.project_id", "s.name"), projectID, name).StructScan(source) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSourceNotFound + } + return 
nil, err + } + + return source, nil +} + +func (s *sourceRepo) FindSourceByMaskID(ctx context.Context, maskID string) (*datastore.Source, error) { + source := &datastore.Source{} + err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.mask_id"), maskID).StructScan(source) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSourceNotFound + } + return nil, err + } + + return source, nil +} + +func (s *sourceRepo) DeleteSourceByID(ctx context.Context, projectId string, id, sourceVerifierId string) error { + tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + defer rollbackTx(tx) + + _, err = tx.ExecContext(ctx, deleteSourceVerifier, sourceVerifierId) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteSource, id, projectId) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, deleteSourceSubscription, id, projectId) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, filter *datastore.SourceFilter, pageable datastore.Pageable) ([]datastore.Source, datastore.PaginationData, error) { + arg := map[string]interface{}{ + "type": filter.Type, + "provider": filter.Provider, + "project_id": projectID, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + "query": "%" + filter.Query + "%", + } + + var query string + if pageable.Direction == datastore.Next { + query = fetchSourcesPagedForward + } else { + query = fetchSourcesPagedBackward + } + + query = fmt.Sprintf(query, baseFetchSource, fetchSourcesPagedFilter) + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = s.db.Rebind(query) + + rows, err := s.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + sources := make([]datastore.Source, 0) + for rows.Next() { + var source datastore.Source + err = rows.StructScan(&source) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + sources = append(sources, source) + } + + var count datastore.PrevRowCount + if len(sources) > 0 { + var countQuery string + var qargs []interface{} + first := sources[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevSources, fetchSourcesPagedFilter) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = s.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := s.db.QueryxContext(ctx, countQuery, qargs...) 
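+ // As with portal links, the prev-row query only signals whether anything precedes the first item of this page.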
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(sources)) + for i := range sources { + ids[i] = sources[i].UID + } + + if len(sources) > pageable.PerPage { + sources = sources[:len(sources)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return sources, *pagination, nil +} + +func (s *sourceRepo) LoadPubSubSourcesByProjectIDs(ctx context.Context, projectIDs []string, pageable datastore.Pageable) ([]datastore.Source, datastore.PaginationData, error) { + arg := map[string]interface{}{ + "project_ids": projectIDs, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + } + + query := fmt.Sprintf(fetchPubSubSources, datastore.PubSubSource) + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = s.db.Rebind(query) + + rows, err := s.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + sources := make([]datastore.Source, 0) + for rows.Next() { + var source datastore.Source + err = rows.StructScan(&source) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + sources = append(sources, source) + } + + // Bypass pagination.Build here since we're only dealing with forward paging here + var hasNext bool + var cursor string + if len(sources) > pageable.PerPage { + cursor = sources[len(sources)-1].UID + sources = sources[:len(sources)-1] + hasNext = true + } + + pagination := &datastore.PaginationData{ + PerPage: int64(pageable.PerPage), + HasNextPage: hasNext, + NextPageCursor: cursor, + } + + return sources, *pagination, nil +} diff --git a/database/sqlite3/source_test.go b/database/sqlite3/source_test.go new file mode 100644 index 0000000000..b3da30912e --- /dev/null +++ b/database/sqlite3/source_test.go @@ -0,0 +1,288 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/dchest/uniuri" + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreateSource(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + source := generateSource(t, db) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) + + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} + + require.Equal(t, source, newSource) +} + +func Test_FindSourceByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + source := generateSource(t, db) + + _, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) + + 
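+ // Zero the DB-managed timestamps so the equality assertion below only
+ // compares fields the test controls.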
newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} + + require.Equal(t, source, newSource) +} + +func Test_FindSourceByName(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + source := generateSource(t, db) + + _, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) + + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + newSource, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) + require.NoError(t, err) + + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} + + require.Equal(t, source, newSource) +} + +func Test_FindSourceByMaskID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + source := generateSource(t, db) + + _, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) + + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + newSource, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) + require.NoError(t, err) + + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} + + require.Equal(t, source, newSource) +} + +func Test_UpdateSource(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + source := generateSource(t, db) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + name := "Convoy-Dev" + source.Name = name + source.IsDisabled = true + source.CustomResponse = datastore.CustomResponse{ + Body: "/ref/", + ContentType: "application/json", + } + require.NoError(t, sourceRepo.UpdateSource(context.Background(), source.ProjectID, source)) + + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) + + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} + + require.Equal(t, source, newSource) +} + +func Test_DeleteSource(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + subRepo := NewSubscriptionRepo(db) + source := generateSource(t, db) + + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + + sub := &datastore.Subscription{ + Name: "test_sub", + Type: datastore.SubscriptionTypeAPI, + ProjectID: source.ProjectID, + SourceID: source.UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"*"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + RateLimitConfig: &datastore.DefaultRateLimitConfig, + } + + err := subRepo.CreateSubscription(context.Background(), source.ProjectID, sub) + require.NoError(t, err) + + _, err = sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) + + require.NoError(t, sourceRepo.DeleteSourceByID(context.Background(), source.ProjectID, source.UID, source.VerifierID)) + + _, err = sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + + _, err = subRepo.FindSubscriptionByID(context.Background(), source.ProjectID, sub.UID) + require.Error(t, 
err) + require.True(t, errors.Is(err, datastore.ErrSubscriptionNotFound)) +} + +func Test_LoadSourcesPaged(t *testing.T) { + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + pageData datastore.Pageable + count int + expected Expected + }{ + { + name: "Load Sources Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Sources Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Sources Paged - 5 records", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + sourceRepo := NewSourceRepo(db) + project := seedProject(t, db) + + for i := 0; i < tc.count; i++ { + source := &datastore.Source{ + UID: ulid.Make().String(), + ProjectID: project.UID, + Name: "Convoy-Prod", + MaskID: uniuri.NewLen(16), + Type: datastore.HTTPSource, + Verifier: &datastore.VerifierConfig{ + Type: datastore.HMacVerifier, + HMac: &datastore.HMac{ + Header: "X-Paystack-Signature", + Hash: "SHA512", + Secret: "Paystack Secret", + }, + }, + } + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + } + + _, pageable, err := sourceRepo.LoadSourcesPaged(context.Background(), project.UID, &datastore.SourceFilter{}, tc.pageData) + + require.NoError(t, err) + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + }) + } +} + +func generateSource(t *testing.T, db database.Database) *datastore.Source { + project := seedProject(t, db) + + return &datastore.Source{ + UID: ulid.Make().String(), + ProjectID: project.UID, + Name: "Convoy-Prod", + MaskID: uniuri.NewLen(16), + CustomResponse: datastore.CustomResponse{ + Body: "/dover/", + ContentType: "text/plain", + }, + Type: datastore.HTTPSource, + Verifier: &datastore.VerifierConfig{ + Type: datastore.HMacVerifier, + HMac: &datastore.HMac{ + Header: "X-Paystack-Signature", + Hash: "SHA512", + Secret: "Paystack Secret", + }, + ApiKey: &datastore.ApiKey{}, + BasicAuth: &datastore.BasicAuth{}, + }, + } +} + +func seedSource(t *testing.T, db database.Database) *datastore.Source { + source := generateSource(t, db) + + require.NoError(t, NewSourceRepo(db).CreateSource(context.Background(), source)) + return source +} diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 85e141926a..1bacc629a6 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -15,6 +15,10 @@ import ( const pkgName = "sqlite3" +type DbCtxKey string + +const TransactionCtx DbCtxKey = "transaction" + type Sqlite struct { dbx *sqlx.DB hook *hooks.Hook @@ -82,3 +86,20 @@ func closeWithError(closer io.Closer) { fmt.Printf("%v, an error occurred while closing the client", err) } } + +func GetTx(ctx context.Context, db *sqlx.DB) (*sqlx.Tx, bool, error) { + isWrapped := false + + wrappedTx, ok := ctx.Value(TransactionCtx).(*sqlx.Tx) + if ok && wrappedTx != nil { + isWrapped = true + return wrappedTx, isWrapped, nil + } + + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return nil, isWrapped, err + } + + return tx, isWrapped, nil +} diff --git a/database/sqlite3/sqlite_test.go 
b/database/sqlite3/sqlite_test.go new file mode 100644 index 0000000000..5724da0ed6 --- /dev/null +++ b/database/sqlite3/sqlite_test.go @@ -0,0 +1,78 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "fmt" + "github.com/frain-dev/convoy/internal/pkg/migrator" + "os" + "sync" + "testing" + + "github.com/frain-dev/convoy/pkg/log" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/database/hooks" + "github.com/frain-dev/convoy/datastore" + + "github.com/stretchr/testify/require" +) + +var ( + once = sync.Once{} + _db *Sqlite +) + +func getDB(t *testing.T) (database.Database, func()) { + once.Do(func() { + var err error + + dbHooks := hooks.Init() + dbHooks.RegisterHook(datastore.EndpointCreated, func(data interface{}, changelog interface{}) {}) + + _db, err = NewDB("file::memory:?cache=shared", log.NewLogger(os.Stdout)) + require.NoError(t, err) + + // run migrations + m := migrator.New(_db, "sqlite3") + err = m.Up() + require.NoError(t, err) + }) + + return _db, func() { + require.NoError(t, _db.truncateTables()) + } +} + +func (s *Sqlite) truncateTables() error { + tables := []string{ + "event_deliveries", + "events", + "api_keys", + "subscriptions", + "source_verifiers", + "sources", + "configurations", + "devices", + "portal_links", + "organisation_invites", + "applications", + "endpoints", + "projects", + "project_configurations", + "organisation_members", + "organisations", + "users", + } + + for _, table := range tables { + _, err := s.dbx.ExecContext(context.Background(), fmt.Sprintf("delete from %s;", table)) + if err != nil { + return err + } + } + + return nil +} diff --git a/database/sqlite3/subscription.go b/database/sqlite3/subscription.go new file mode 100644 index 0000000000..e727b517de --- /dev/null +++ b/database/sqlite3/subscription.go @@ -0,0 +1,844 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/pkg/compare" + "github.com/frain-dev/convoy/pkg/flatten" + "github.com/frain-dev/convoy/util" + "github.com/jmoiron/sqlx" +) + +const ( + createSubscription = ` + INSERT INTO subscriptions ( + id,name,type, + project_id,endpoint_id,device_id, + source_id,alert_config_count,alert_config_threshold, + retry_config_type,retry_config_duration, + retry_config_retry_count,filter_config_event_types, + filter_config_filter_headers,filter_config_filter_body, + filter_config_filter_is_flattened, + rate_limit_config_count,rate_limit_config_duration,function + ) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19); + ` + + updateSubscription = ` + UPDATE subscriptions SET + name=$3, + endpoint_id=$4, + source_id=$5, + alert_config_count=$6, + alert_config_threshold=$7, + retry_config_type=$8, + retry_config_duration=$9, + retry_config_retry_count=$10, + filter_config_event_types=$11, + filter_config_filter_headers=$12, + filter_config_filter_body=$13, + filter_config_filter_is_flattened=$14, + rate_limit_config_count=$15, + rate_limit_config_duration=$16, + function=$17, + updated_at=now() + WHERE id = $1 AND project_id = $2 + AND deleted_at IS NULL; + ` + + baseFetchSubscription = ` + SELECT + s.id,s.name,s.type, + s.project_id, + s.created_at, + s.updated_at, s.function, + + COALESCE(s.endpoint_id,'') AS "endpoint_id", + COALESCE(s.device_id,'') AS "device_id", + COALESCE(s.source_id,'') AS "source_id", + + s.alert_config_count AS 
"alert_config.count", + s.alert_config_threshold AS "alert_config.threshold", + s.retry_config_type AS "retry_config.type", + s.retry_config_duration AS "retry_config.duration", + s.retry_config_retry_count AS "retry_config.retry_count", + s.filter_config_event_types AS "filter_config.event_types", + s.filter_config_filter_headers AS "filter_config.filter.headers", + s.filter_config_filter_body AS "filter_config.filter.body", + s.filter_config_filter_is_flattened AS "filter_config.filter.is_flattened", + s.rate_limit_config_count AS "rate_limit_config.count", + s.rate_limit_config_duration AS "rate_limit_config.duration", + + COALESCE(em.secrets,'[]') AS "endpoint_metadata.secrets", + COALESCE(em.id,'') AS "endpoint_metadata.id", + COALESCE(em.name,'') AS "endpoint_metadata.name", + COALESCE(em.project_id,'') AS "endpoint_metadata.project_id", + COALESCE(em.support_email,'') AS "endpoint_metadata.support_email", + COALESCE(em.url,'') AS "endpoint_metadata.url", + COALESCE(em.status, '') AS "endpoint_metadata.status", + COALESCE(em.owner_id, '') AS "endpoint_metadata.owner_id", + + COALESCE(d.id,'') AS "device_metadata.id", + COALESCE(d.status,'') AS "device_metadata.status", + COALESCE(d.host_name,'') AS "device_metadata.host_name", + + COALESCE(sm.id,'') AS "source_metadata.id", + COALESCE(sm.name,'') AS "source_metadata.name", + COALESCE(sm.type,'') AS "source_metadata.type", + COALESCE(sm.mask_id,'') AS "source_metadata.mask_id", + COALESCE(sm.project_id,'') AS "source_metadata.project_id", + COALESCE(sm.is_disabled,FALSE) AS "source_metadata.is_disabled", + + COALESCE(sv.type, '') AS "source_metadata.verifier.type", + COALESCE(sv.basic_username, '') AS "source_metadata.verifier.basic_auth.username", + COALESCE(sv.basic_password, '') AS "source_metadata.verifier.basic_auth.password", + COALESCE(sv.api_key_header_name, '') AS "source_metadata.verifier.api_key.header_name", + COALESCE(sv.api_key_header_value, '') AS "source_metadata.verifier.api_key.header_value", + COALESCE(sv.hmac_hash, '') AS "source_metadata.verifier.hmac.hash", + COALESCE(sv.hmac_header, '') AS "source_metadata.verifier.hmac.header", + COALESCE(sv.hmac_secret, '') AS "source_metadata.verifier.hmac.secret", + COALESCE(sv.hmac_encoding, '') AS "source_metadata.verifier.hmac.encoding" + + FROM subscriptions s + LEFT JOIN endpoints em ON s.endpoint_id = em.id + LEFT JOIN sources sm ON s.source_id = sm.id + LEFT JOIN source_verifiers sv ON sv.id = sm.source_verifier_id + LEFT JOIN devices d ON s.device_id = d.id + WHERE s.deleted_at IS NULL ` + + fetchSubscriptionsForBroadcast = ` + select id, type, project_id, endpoint_id, function, + filter_config_event_types AS "filter_config.event_types", + filter_config_filter_headers AS "filter_config.filter.headers", + filter_config_filter_body AS "filter_config.filter.body", + filter_config_filter_is_flattened AS "filter_config.filter.is_flattened" + from subscriptions + where (ARRAY[$4] <@ filter_config_event_types OR ARRAY['*'] <@ filter_config_event_types) + AND id > $1 + AND project_id = $2 + AND deleted_at is null + ORDER BY id LIMIT $3` + + loadAllSubscriptionsConfiguration = ` + select name, id, type, project_id, endpoint_id, function, updated_at, + filter_config_event_types AS "filter_config.event_types", + filter_config_filter_headers AS "filter_config.filter.headers", + filter_config_filter_body AS "filter_config.filter.body", + filter_config_filter_is_flattened AS "filter_config.filter.is_flattened" + from subscriptions + where id > ? + AND project_id IN (?) 
+ AND deleted_at is null + ORDER BY id LIMIT ?` + + fetchUpdatedSubscriptions = ` + select name, id, type, project_id, endpoint_id, function, updated_at, + filter_config_event_types AS "filter_config.event_types", + filter_config_filter_headers AS "filter_config.filter.headers", + filter_config_filter_body AS "filter_config.filter.body", + filter_config_filter_is_flattened AS "filter_config.filter.is_flattened" + from subscriptions + where updated_at > ? + AND id > ? + AND project_id IN (?) + AND deleted_at is null + ORDER BY id LIMIT ?` + + countDeletedSubscriptions = ` + select COUNT(id) from subscriptions + where (deleted_at IS NOT NULL AND deleted_at > ?) + AND project_id IN (?)` + + countUpdatedSubscriptions = ` + SELECT COUNT(*) + FROM ( + SELECT DISTINCT id + FROM subscriptions + WHERE deleted_at IS NULL + AND updated_at > ? + AND project_id IN (?) + ) AS distinct_ids` + + fetchDeletedSubscriptions = ` + select id,deleted_at, project_id, + filter_config_event_types AS "filter_config.event_types" + from subscriptions + where (deleted_at IS NOT NULL AND deleted_at > ?) + AND id > ? + AND project_id IN (?) + ORDER BY id LIMIT ?` + + baseFetchSubscriptionsPagedForward = ` + %s + %s + AND s.id <= :cursor + GROUP BY s.id, em.id, sm.id, sv.id, d.id + ORDER BY s.id DESC + LIMIT :limit + ` + + baseFetchSubscriptionsPagedBackward = ` + WITH subscriptions AS ( + %s + %s + AND s.id >= :cursor + GROUP BY s.id, em.id, sm.id, sv.id, d.id + ORDER BY s.id ASC + LIMIT :limit + ) + + SELECT * FROM subscriptions ORDER BY id DESC + ` + + countProjectSubscriptions = ` + SELECT COUNT(s.id) AS count + FROM subscriptions s + WHERE s.deleted_at IS NULL + AND s.project_id IN (?)` + + countEndpointSubscriptions = ` + SELECT COUNT(s.id) AS count + FROM subscriptions s + WHERE s.deleted_at IS NULL + AND s.project_id = $1 AND s.endpoint_id = $2` + + countPrevSubscriptions = ` + SELECT COUNT(DISTINCT(s.id)) AS count + FROM subscriptions s + WHERE s.deleted_at IS NULL + %s + AND s.id > :cursor GROUP BY s.id ORDER BY s.id DESC LIMIT 1` + + fetchSubscriptionByID = baseFetchSubscription + ` AND %s = $1 AND %s = $2;` + + fetchSubscriptionByDeviceID = ` + SELECT + s.id,s.name,s.type, + s.project_id, + s.created_at, + s.updated_at, s.function, + + COALESCE(s.endpoint_id,'') AS "endpoint_id", + COALESCE(s.device_id,'') AS "device_id", + COALESCE(s.source_id,'') AS "source_id", + + s.alert_config_count AS "alert_config.count", + s.alert_config_threshold AS "alert_config.threshold", + s.retry_config_type AS "retry_config.type", + s.retry_config_duration AS "retry_config.duration", + s.retry_config_retry_count AS "retry_config.retry_count", + s.filter_config_event_types AS "filter_config.event_types", + s.filter_config_filter_headers AS "filter_config.filter.headers", + s.filter_config_filter_body AS "filter_config.filter.body", + s.rate_limit_config_count AS "rate_limit_config.count", + s.rate_limit_config_duration AS "rate_limit_config.duration", + + COALESCE(d.id,'') AS "device_metadata.id", + COALESCE(d.status,'') AS "device_metadata.status", + COALESCE(d.host_name,'') AS "device_metadata.host_name" + + FROM subscriptions s + LEFT JOIN devices d ON s.device_id = d.id + WHERE s.device_id = $1 AND s.project_id = $2 AND s.type = $3` + + fetchCLISubscriptions = baseFetchSubscription + `AND %s = $1 AND %s = $2` + + deleteSubscriptions = ` + UPDATE subscriptions SET + deleted_at = NOW() + WHERE id = $1 AND project_id = $2; + ` +) + +var ( + ErrSubscriptionNotCreated = errors.New("subscription could not be created") + 
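+ // ErrSubscriptionNotUpdated is returned when an update matches no rows.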
ErrSubscriptionNotUpdated = errors.New("subscription could not be updated") + ErrSubscriptionNotDeleted = errors.New("subscription could not be deleted") +) + +type subscriptionRepo struct { + db *sqlx.DB +} + +func NewSubscriptionRepo(db database.Database) datastore.SubscriptionRepository { + return &subscriptionRepo{db: db.GetDB()} +} + +func (s *subscriptionRepo) FetchUpdatedSubscriptions(ctx context.Context, projectIDs []string, t time.Time, pageSize int64) ([]datastore.Subscription, error) { + return s.fetchChangedSubscriptionConfig(ctx, countUpdatedSubscriptions, fetchUpdatedSubscriptions, projectIDs, t, pageSize) +} + +func (s *subscriptionRepo) FetchDeletedSubscriptions(ctx context.Context, projectIDs []string, t time.Time, pageSize int64) ([]datastore.Subscription, error) { + return s.fetchChangedSubscriptionConfig(ctx, countDeletedSubscriptions, fetchDeletedSubscriptions, projectIDs, t, pageSize) +} + +func (s *subscriptionRepo) LoadAllSubscriptionConfig(ctx context.Context, projectIDs []string, pageSize int64) ([]datastore.Subscription, error) { + if len(projectIDs) == 0 { + return []datastore.Subscription{}, nil + } + + query, args, err := sqlx.In(countProjectSubscriptions, projectIDs) + if err != nil { + return nil, err + } + + var subCount int64 + err = s.db.GetContext(ctx, &subCount, s.db.Rebind(query), args...) + if err != nil { + return nil, err + } + + if subCount == 0 { + return []datastore.Subscription{}, nil + } + + subs := make([]datastore.Subscription, subCount) + cursor := "0" + var rows *sqlx.Rows // reuse the mem + counter := 0 + numBatches := int64(math.Ceil(float64(subCount) / float64(pageSize))) + + for i := int64(0); i < numBatches; i++ { + query, args, err = sqlx.In(loadAllSubscriptionsConfiguration, cursor, projectIDs, pageSize) + if err != nil { + return nil, err + } + + rows, err = s.db.QueryxContext(ctx, s.db.Rebind(query), args...) + if err != nil { + return nil, err + } + + // using func to avoid calling defer in a loop, that can easily fill up function stack and cause a crash + func() { + defer closeWithError(rows) + for rows.Next() { + sub := datastore.Subscription{} + if err = rows.StructScan(&sub); err != nil { + return + } + + nullifyEmptyConfig(&sub) + subs[counter] = sub + counter++ + } + + if counter > 0 { + cursor = subs[counter-1].UID + } + }() + + if err != nil { + return nil, err + } + + } + + return subs[:counter], nil +} + +func (s *subscriptionRepo) FetchSubscriptionsForBroadcast(ctx context.Context, projectID string, eventType string, pageSize int) ([]datastore.Subscription, error) { + var _subs []datastore.Subscription + cursor := "0" + + for { + rows, err := s.db.QueryxContext(ctx, fetchSubscriptionsForBroadcast, cursor, projectID, pageSize, eventType) + if err != nil { + return nil, err + } + + subscriptions, err := scanSubscriptions(rows) + if err != nil { + return nil, err + } + + if len(subscriptions) == 0 { + break + } + + _subs = append(_subs, subscriptions...) + cursor = subscriptions[len(subscriptions)-1].UID + } + + return _subs, nil +} + +func (s *subscriptionRepo) fetchChangedSubscriptionConfig(ctx context.Context, countQuery, query string, projectIDs []string, t time.Time, pageSize int64) ([]datastore.Subscription, error) { + if len(projectIDs) == 0 { + return []datastore.Subscription{}, nil + } + + q, args, err := sqlx.In(countQuery, t, projectIDs) + if err != nil { + return nil, err + } + + var subCount int64 + err = s.db.GetContext(ctx, &subCount, s.db.Rebind(q), args...) 
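+ // subCount sizes the preallocated result slice; the loop below then pages
+ // through the matching rows pageSize at a time, using the last-seen id as
+ // the cursor for the next batch (e.g. subCount=10, pageSize=4 gives
+ // numBatches = ceil(10/4) = 3 batches of sizes 4, 4 and 2).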
+ if err != nil { + return nil, err + } + + if subCount == 0 { + return []datastore.Subscription{}, nil + } + + subs := make([]datastore.Subscription, subCount) + cursor := "0" + var rows *sqlx.Rows // reuse the mem + counter := 0 + numBatches := int64(math.Ceil(float64(subCount) / float64(pageSize))) + + for i := int64(0); i < numBatches; i++ { + q, args, err = sqlx.In(query, t, cursor, projectIDs, pageSize) + if err != nil { + return nil, err + } + + rows, err = s.db.QueryxContext(ctx, s.db.Rebind(q), args...) + if err != nil { + return nil, err + } + + // using func to avoid calling defer in a loop, that can easily fill up function stack and cause a crash + func() { + defer closeWithError(rows) + for rows.Next() { + sub := datastore.Subscription{} + if err = rows.StructScan(&sub); err != nil { + return + } + + nullifyEmptyConfig(&sub) + subs[counter] = sub + counter++ + } + + if counter > 0 { + cursor = subs[counter-1].UID + } + }() + + if err != nil { + return nil, err + } + } + + return subs[:counter], nil +} + +func (s *subscriptionRepo) CreateSubscription(ctx context.Context, projectID string, subscription *datastore.Subscription) error { + if projectID != subscription.ProjectID { + return datastore.ErrNotAuthorisedToAccessDocument + } + + ac := subscription.GetAlertConfig() + rc := subscription.GetRetryConfig() + fc := subscription.GetFilterConfig() + rlc := subscription.GetRateLimitConfig() + + var endpointID, sourceID, deviceID *string + if !util.IsStringEmpty(subscription.EndpointID) { + endpointID = &subscription.EndpointID + } + + if !util.IsStringEmpty(subscription.SourceID) { + sourceID = &subscription.SourceID + } + + if !util.IsStringEmpty(subscription.DeviceID) { + deviceID = &subscription.DeviceID + } + + err := fc.Filter.Body.Flatten() + if err != nil { + return fmt.Errorf("failed to flatten body filter: %v", err) + } + + err = fc.Filter.Headers.Flatten() + if err != nil { + return fmt.Errorf("failed to flatten header filter: %v", err) + } + + fc.Filter.IsFlattened = true // this is just a flag so we can identify old records + + result, err := s.db.ExecContext( + ctx, createSubscription, subscription.UID, + subscription.Name, subscription.Type, subscription.ProjectID, + endpointID, deviceID, sourceID, + ac.Count, ac.Threshold, rc.Type, rc.Duration, rc.RetryCount, + fc.EventTypes, fc.Filter.Headers, fc.Filter.Body, fc.Filter.IsFlattened, + rlc.Count, rlc.Duration, subscription.Function, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSubscriptionNotCreated + } + + _subscription := &datastore.Subscription{} + err = s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), subscription.UID, projectID).StructScan(_subscription) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return datastore.ErrSubscriptionNotFound + } + return err + } + + nullifyEmptyConfig(_subscription) + *subscription = *_subscription + + return nil +} + +func (s *subscriptionRepo) UpdateSubscription(ctx context.Context, projectID string, subscription *datastore.Subscription) error { + ac := subscription.GetAlertConfig() + rc := subscription.GetRetryConfig() + fc := subscription.GetFilterConfig() + rlc := subscription.GetRateLimitConfig() + + var sourceID *string + if !util.IsStringEmpty(subscription.SourceID) { + sourceID = &subscription.SourceID + } + + err := fc.Filter.Body.Flatten() + if err != nil { + return fmt.Errorf("failed to flatten body filter: 
%v", err) + } + + err = fc.Filter.Headers.Flatten() + if err != nil { + return fmt.Errorf("failed to flatten header filter: %v", err) + } + + fc.Filter.IsFlattened = true // this is just a flag so we can identify old records + + result, err := s.db.ExecContext( + ctx, updateSubscription, subscription.UID, projectID, + subscription.Name, subscription.EndpointID, sourceID, + ac.Count, ac.Threshold, rc.Type, rc.Duration, rc.RetryCount, + fc.EventTypes, fc.Filter.Headers, fc.Filter.Body, fc.Filter.IsFlattened, + rlc.Count, rlc.Duration, subscription.Function, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSubscriptionNotUpdated + } + + _subscription := &datastore.Subscription{} + err = s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), subscription.UID, projectID).StructScan(_subscription) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return datastore.ErrSubscriptionNotFound + } + return err + } + + nullifyEmptyConfig(_subscription) + *subscription = *_subscription + + return nil +} + +func (s *subscriptionRepo) LoadSubscriptionsPaged(ctx context.Context, projectID string, filter *datastore.FilterBy, pageable datastore.Pageable) ([]datastore.Subscription, datastore.PaginationData, error) { + var rows *sqlx.Rows + var err error + + arg := map[string]interface{}{ + "project_id": projectID, + "endpoint_ids": filter.EndpointIDs, + "limit": pageable.Limit(), + "cursor": pageable.Cursor(), + "name": fmt.Sprintf("%%%s%%", filter.SubscriptionName), + } + + var query, filterQuery string + if pageable.Direction == datastore.Next { + query = baseFetchSubscriptionsPagedForward + } else { + query = baseFetchSubscriptionsPagedBackward + } + + filterQuery = ` AND s.project_id = :project_id` + if len(filter.EndpointIDs) > 0 { + filterQuery += ` AND s.endpoint_id IN (:endpoint_ids)` + } + + if !util.IsStringEmpty(filter.SubscriptionName) { + filterQuery += ` AND s.name LIKE :name` + } + + query = fmt.Sprintf(query, baseFetchSubscription, filterQuery) + + query, args, err := sqlx.Named(query, arg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query, args, err = sqlx.In(query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + query = s.db.Rebind(query) + + rows, err = s.db.QueryxContext(ctx, query, args...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + subscriptions, err := scanSubscriptions(rows) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + var count datastore.PrevRowCount + if len(subscriptions) > 0 { + var countQuery string + var qargs []interface{} + first := subscriptions[0] + qarg := arg + qarg["cursor"] = first.UID + + cq := fmt.Sprintf(countPrevSubscriptions, filterQuery) + countQuery, qargs, err = sqlx.Named(cq, qarg) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery, qargs, err = sqlx.In(countQuery, qargs...) + if err != nil { + return nil, datastore.PaginationData{}, err + } + + countQuery = s.db.Rebind(countQuery) + + // count the row number before the first row + rows, err := s.db.QueryxContext(ctx, countQuery, qargs...) 
+ if err != nil { + return nil, datastore.PaginationData{}, err + } + defer closeWithError(rows) + + if rows.Next() { + err = rows.StructScan(&count) + if err != nil { + return nil, datastore.PaginationData{}, err + } + } + } + + ids := make([]string, len(subscriptions)) + for i := range subscriptions { + ids[i] = subscriptions[i].UID + } + + if len(subscriptions) > pageable.PerPage { + subscriptions = subscriptions[:len(subscriptions)-1] + } + + pagination := &datastore.PaginationData{PrevRowCount: count} + pagination = pagination.Build(pageable, ids) + + return subscriptions, *pagination, nil +} + +func (s *subscriptionRepo) DeleteSubscription(ctx context.Context, projectID string, subscription *datastore.Subscription) error { + result, err := s.db.ExecContext(ctx, deleteSubscriptions, subscription.UID, projectID) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrSubscriptionNotDeleted + } + + return nil +} + +func (s *subscriptionRepo) FindSubscriptionByID(ctx context.Context, projectID string, subscriptionID string) (*datastore.Subscription, error) { + subscription := &datastore.Subscription{} + err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), subscriptionID, projectID).StructScan(subscription) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSubscriptionNotFound + } + return nil, err + } + + nullifyEmptyConfig(subscription) + + return subscription, nil +} + +func (s *subscriptionRepo) FindSubscriptionsBySourceID(ctx context.Context, projectID string, sourceID string) ([]datastore.Subscription, error) { + rows, err := s.db.QueryxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.project_id", "s.source_id"), projectID, sourceID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSubscriptionNotFound + } + + return nil, err + } + + return scanSubscriptions(rows) +} + +func (s *subscriptionRepo) FindSubscriptionsByEndpointID(ctx context.Context, projectId string, endpointID string) ([]datastore.Subscription, error) { + rows, err := s.db.QueryxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.project_id", "s.endpoint_id"), projectId, endpointID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSubscriptionNotFound + } + + return nil, err + } + + return scanSubscriptions(rows) +} + +func (s *subscriptionRepo) FindSubscriptionByDeviceID(ctx context.Context, projectId string, deviceID string, subscriptionType datastore.SubscriptionType) (*datastore.Subscription, error) { + subscription := &datastore.Subscription{} + err := s.db.QueryRowxContext(ctx, fetchSubscriptionByDeviceID, deviceID, projectId, subscriptionType).StructScan(subscription) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrSubscriptionNotFound + } + + return nil, err + } + + nullifyEmptyConfig(subscription) + + return subscription, nil +} + +func (s *subscriptionRepo) FindCLISubscriptions(ctx context.Context, projectID string) ([]datastore.Subscription, error) { + subscriptions, err := s.db.QueryxContext(ctx, fmt.Sprintf(fetchCLISubscriptions, "s.project_id", "s.type"), projectID, datastore.SubscriptionTypeCLI) + if err != nil { + return nil, err + } + + return scanSubscriptions(subscriptions) +} + +func (s *subscriptionRepo) CountEndpointSubscriptions(ctx context.Context, projectID, endpointID string) (int64, error) { + var count int64 + + err := 
s.db.GetContext(ctx, &count, countEndpointSubscriptions, projectID, endpointID) + if err != nil { + return 0, err + } + + return count, nil +} + +func (s *subscriptionRepo) TestSubscriptionFilter(_ context.Context, payload, filter interface{}, isFlattened bool) (bool, error) { + if payload == nil || filter == nil { + return true, nil + } + + p, err := flatten.Flatten(payload) + if err != nil { + return false, err + } + + if !isFlattened { + filter, err = flatten.Flatten(filter) + if err != nil { + return false, err + } + } + + // The filter must be of type flatten.M, because flatten.Flatten always returns that type, + // so whether pre-flattened or not, this must hold true + v, ok := filter.(flatten.M) + if !ok { + return false, fmt.Errorf("unknown type %T for filter", filter) + } + return compare.Compare(p, v) +} + +func (s *subscriptionRepo) CompareFlattenedPayload(_ context.Context, payload, filter flatten.M, isFlattened bool) (bool, error) { + if payload == nil || filter == nil { + return true, nil + } + + if !isFlattened { + var err error + filter, err = flatten.Flatten(filter) + if err != nil { + return false, err + } + } + + return compare.Compare(payload, filter) +} + +var ( + emptyAlertConfig = datastore.AlertConfiguration{} + emptyRetryConfig = datastore.RetryConfiguration{} + emptyRateLimitConfig = datastore.RateLimitConfiguration{} +) + +func nullifyEmptyConfig(sub *datastore.Subscription) { + if sub.AlertConfig != nil && *sub.AlertConfig == emptyAlertConfig { + sub.AlertConfig = nil + } + + if sub.RetryConfig != nil && *sub.RetryConfig == emptyRetryConfig { + sub.RetryConfig = nil + } + + if sub.RateLimitConfig != nil && *sub.RateLimitConfig == emptyRateLimitConfig { + sub.RateLimitConfig = nil + } +} + +func scanSubscriptions(rows *sqlx.Rows) ([]datastore.Subscription, error) { + subscriptions := make([]datastore.Subscription, 0) + var err error + defer closeWithError(rows) + + for rows.Next() { + sub := datastore.Subscription{} + err = rows.StructScan(&sub) + if err != nil { + return nil, err + } + nullifyEmptyConfig(&sub) + + subscriptions = append(subscriptions, sub) + } + + return subscriptions, nil +} diff --git a/database/sqlite3/subscription_test.go b/database/sqlite3/subscription_test.go new file mode 100644 index 0000000000..356ab77b6c --- /dev/null +++ b/database/sqlite3/subscription_test.go @@ -0,0 +1,526 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "fmt" + "math" + "testing" + "time" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func generateSubscription(project *datastore.Project, source *datastore.Source, endpoint *datastore.Endpoint, device *datastore.Device) *datastore.Subscription { + uid := ulid.Make().String() + return &datastore.Subscription{ + UID: uid, + Name: "Subscription-" + uid, + Type: datastore.SubscriptionTypeAPI, + ProjectID: project.UID, + SourceID: source.UID, + EndpointID: endpoint.UID, + DeviceID: device.UID, + AlertConfig: &datastore.AlertConfiguration{ + Count: 10, + Threshold: "1m", + }, + RetryConfig: &datastore.RetryConfiguration{ + Type: "linear", + Duration: 3, + RetryCount: 10, + }, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"some.event"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + } +} + +func seedSubscription(t *testing.T, db database.Database, project *datastore.Project, source 
*datastore.Source, endpoint *datastore.Endpoint, device *datastore.Device) *datastore.Subscription { + s := generateSubscription(project, source, endpoint, device) + require.NoError(t, NewSubscriptionRepo(db).CreateSubscription(context.Background(), project.UID, s)) + return s +} + +func Test_LoadSubscriptionsPaged(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + device := seedDevice(t, db) + subMap := map[string]*datastore.Subscription{} + var newSub *datastore.Subscription + for i := 0; i < 100; i++ { + newSub = generateSubscription(project, source, endpoint, device) + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub + } + + type Expected struct { + paginationData datastore.PaginationData + } + + tests := []struct { + name string + EndpointIDs []string + SubscriptionName string + pageData datastore.Pageable + expected Expected + }{ + { + name: "Load Subscriptions Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + + { + name: "Load Subscriptions Paged - 1 record - filter by name", + SubscriptionName: newSub.Name, + pageData: datastore.Pageable{PerPage: 1, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 1, + }, + }, + }, + + { + name: "Load Subscriptions Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, + }, + }, + + { + name: "Load Subscriptions Paged - 0 records", + pageData: datastore.Pageable{PerPage: 10, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 10, + }, + }, + }, + + { + name: "Load Subscriptions Paged with Endpoint ID - 1 record", + EndpointIDs: []string{endpoint.UID}, + pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + subs, pageable, err := subRepo.LoadSubscriptionsPaged(context.Background(), project.UID, &datastore.FilterBy{EndpointIDs: tc.EndpointIDs, SubscriptionName: tc.SubscriptionName}, tc.pageData) + require.NoError(t, err) + + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + + require.Equal(t, tc.expected.paginationData.PerPage, int64(len(subs))) + + for _, dbSub := range subs { + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, 
source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) + } + }) + } +} + +func Test_DeleteSubscription(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + + err := subRepo.CreateSubscription(context.Background(), project.UID, newSub) + require.NoError(t, err) + + // delete the sub + err = subRepo.DeleteSubscription(context.Background(), project.UID, newSub) + require.NoError(t, err) + + // Fetch sub again + _, err = subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.Equal(t, err, datastore.ErrSubscriptionNotFound) +} + +func Test_CreateSubscription(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) + + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, newSub.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, newSub.UID) +} + +func Test_CountEndpointSubscriptions(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + + newSub1 := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub1.ProjectID, newSub1)) + + newSub2 := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub2.ProjectID, newSub2)) + + count, err := subRepo.CountEndpointSubscriptions(context.Background(), newSub1.ProjectID, endpoint.UID) + require.NoError(t, err) + + require.Equal(t, int64(2), count) +} + +func Test_UpdateSubscription(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) + + update := &datastore.Subscription{ + UID: newSub.UID, + Name: "tyne&wear", + ProjectID: newSub.ProjectID, + Type: newSub.Type, + SourceID: seedSource(t, db).UID, + EndpointID: seedEndpoint(t, db).UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: newSub.FilterConfig, + RateLimitConfig: &datastore.DefaultRateLimitConfig, + } + + err := subRepo.UpdateSubscription(context.Background(), project.UID, update) + require.NoError(t, err) + + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, 
newSub.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, update.UID) + require.Equal(t, dbSub.Name, update.Name) +} + +func Test_FindSubscriptionByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + + // Fetch sub again + _, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.Error(t, err) + require.EqualError(t, err, datastore.ErrSubscriptionNotFound.Error()) + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + + // Fetch sub again + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.NoError(t, err) + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, newSub.UID) +} + +func Test_FindSubscriptionsBySourceID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + subMap := map[string]*datastore.Subscription{} + for i := 0; i < 5; i++ { + var newSub *datastore.Subscription + if i == 3 { + newSub = generateSubscription(project, seedSource(t, db), endpoint, &datastore.Device{}) + } else { + newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) + } + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub + } + + // Fetch sub again + dbSubs, err := subRepo.FindSubscriptionsBySourceID(context.Background(), project.UID, source.UID) + require.NoError(t, err) + require.Equal(t, 4, len(dbSubs)) + + for _, dbSub := range dbSubs { + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + + 
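+ // Detach the joined endpoint/source/device metadata before comparing the
+ // subscription row itself against the seeded record.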
dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) + } +} + +func Test_FindSubscriptionByEndpointID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + subMap := map[string]*datastore.Subscription{} + for i := 0; i < 8; i++ { + var newSub *datastore.Subscription + if i == 3 { + newSub = generateSubscription(project, source, seedEndpoint(t, db), &datastore.Device{}) + } else { + newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) + } + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub + } + + // Fetch sub again + dbSubs, err := subRepo.FindSubscriptionsByEndpointID(context.Background(), project.UID, endpoint.UID) + require.NoError(t, err) + require.Equal(t, 7, len(dbSubs)) + + for _, dbSub := range dbSubs { + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) + } +} + +func Test_FindSubscriptionByDeviceID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + device := seedDevice(t, db) + newSub := generateSubscription(project, source, endpoint, device) + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + + // Fetch sub again + dbSub, err := subRepo.FindSubscriptionByDeviceID(context.Background(), project.UID, device.UID, newSub.Type) + require.NoError(t, err) + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Nil(t, dbSub.Endpoint) + require.Nil(t, dbSub.Source) + + require.Equal(t, device.UID, dbSub.Device.UID) + require.Equal(t, device.HostName, dbSub.Device.HostName) + + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, newSub.UID) +} + +func Test_FindCLISubscriptions(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + for i := 0; i < 8; i++ { + newSub := &datastore.Subscription{ + UID: ulid.Make().String(), + Name: "Subscription", + Type: datastore.SubscriptionTypeCLI, + ProjectID: project.UID, + SourceID: source.UID, + EndpointID: endpoint.UID, + AlertConfig: &datastore.AlertConfiguration{ + Count: 10, + Threshold: "1m", + }, + RetryConfig: &datastore.RetryConfiguration{ + Type: "linear", + Duration: 3, + RetryCount: 10, + }, + 
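+ // Header/body filter schemas are left empty here; only EventTypes is
+ // populated for these CLI subscriptions.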
FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"some.event"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + } + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + } + + // Fetch sub again + dbSubs, err := subRepo.FindCLISubscriptions(context.Background(), project.UID) + require.NoError(t, err) + require.Equal(t, 8, len(dbSubs)) +} + +func seedDevice(t *testing.T, db database.Database) *datastore.Device { + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + d := &datastore.Device{ + UID: ulid.Make().String(), + ProjectID: project.UID, + EndpointID: endpoint.UID, + HostName: "host1", + Status: datastore.DeviceStatusOnline, + } + + err := NewDeviceRepo(db).CreateDevice(context.Background(), d) + require.NoError(t, err) + return d +} diff --git a/database/sqlite3/users.go b/database/sqlite3/users.go new file mode 100644 index 0000000000..768e2ac97d --- /dev/null +++ b/database/sqlite3/users.go @@ -0,0 +1,177 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/frain-dev/convoy/database" + "github.com/frain-dev/convoy/datastore" + "github.com/jmoiron/sqlx" +) + +const ( + createUser = ` + INSERT INTO users ( + id,first_name,last_name,email,password, + email_verified,reset_password_token, email_verification_token, + reset_password_expires_at,email_verification_expires_at, auth_type) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) + ` + + updateUser = ` + UPDATE users SET + first_name = $2, + last_name=$3, + email=$4, + password=$5, + email_verified=$6, + reset_password_token=$7, + email_verification_token=$8, + reset_password_expires_at=$9, + email_verification_expires_at=$10 + WHERE id = $1 AND deleted_at IS NULL; + ` + + fetchUsers = ` + SELECT * FROM users + WHERE deleted_at IS NULL + ` + + countUsers = ` + SELECT COUNT(*) AS count + FROM users + WHERE deleted_at IS NULL` +) + +var ( + ErrUserNotCreated = errors.New("user could not be created") + ErrUserNotUpdated = errors.New("user could not be updated") +) + +type userRepo struct { + db *sqlx.DB +} + +func NewUserRepo(db database.Database) datastore.UserRepository { + return &userRepo{db: db.GetDB()} +} + +func (u *userRepo) CreateUser(ctx context.Context, user *datastore.User) error { + result, err := u.db.ExecContext(ctx, + createUser, + user.UID, + user.FirstName, + user.LastName, + user.Email, + user.Password, + user.EmailVerified, + user.ResetPasswordToken, + user.EmailVerificationToken, + user.ResetPasswordExpiresAt, + user.EmailVerificationExpiresAt, + user.AuthType, + ) + if err != nil { + if strings.Contains(err.Error(), "duplicate") { + return datastore.ErrDuplicateEmail + } + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrUserNotCreated + } + + return nil +} + +func (u *userRepo) UpdateUser(ctx context.Context, user *datastore.User) error { + result, err := u.db.Exec( + updateUser, user.UID, user.FirstName, user.LastName, user.Email, user.Password, user.EmailVerified, user.ResetPasswordToken, + user.EmailVerificationToken, user.ResetPasswordExpiresAt, user.EmailVerificationExpiresAt, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + + if rowsAffected < 1 { + return ErrUserNotUpdated + } + + return nil +} + +func (u *userRepo) FindUserByEmail(ctx context.Context, email string) 
(*datastore.User, error) { + user := &datastore.User{} + err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND email = $1;", fetchUsers), email).StructScan(user) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrUserNotFound + } + return nil, err + } + + return user, nil +} + +func (u *userRepo) FindUserByID(ctx context.Context, id string) (*datastore.User, error) { + user := &datastore.User{} + err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND id = $1;", fetchUsers), id).StructScan(user) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrUserNotFound + } + return nil, err + } + + return user, nil +} + +func (u *userRepo) FindUserByToken(ctx context.Context, token string) (*datastore.User, error) { + user := &datastore.User{} + err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND reset_password_token = $1;", fetchUsers), token).StructScan(user) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrUserNotFound + } + return nil, err + } + + return user, nil +} + +func (u *userRepo) FindUserByEmailVerificationToken(ctx context.Context, token string) (*datastore.User, error) { + user := &datastore.User{} + err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND email_verification_token = $1;", fetchUsers), token).StructScan(user) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, datastore.ErrUserNotFound + } + return nil, err + } + + return user, nil +} + +func (o *userRepo) CountUsers(ctx context.Context) (int64, error) { + var count int64 + err := o.db.GetContext(ctx, &count, countUsers) + if err != nil { + return 0, err + } + + return count, nil +} diff --git a/database/sqlite3/users_test.go b/database/sqlite3/users_test.go new file mode 100644 index 0000000000..afb95f24e2 --- /dev/null +++ b/database/sqlite3/users_test.go @@ -0,0 +1,292 @@ +//go:build integration +// +build integration + +package sqlite3 + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/frain-dev/convoy/datastore" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" +) + +func Test_CreateUser(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + tt := []struct { + name string + users []datastore.User + isDuplicateEmail bool + }{ + { + name: "create user", + users: []datastore.User{ + { + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), + }, + }, + }, + { + name: "cannot create user with existing email", + isDuplicateEmail: true, + users: []datastore.User{ + { + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: "test@test.com", + EmailVerified: true, + Password: "dvsdvdkhjskuis", + ResetPasswordToken: "dvsdvdkhjskuis", + EmailVerificationToken: "v878678768686868", + ResetPasswordExpiresAt: time.Now(), + EmailVerificationExpiresAt: time.Now(), + }, + { + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: "test@test.com", + }, + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + userRepo := NewUserRepo(db) + + for i, user := range tc.users { + if i == 0 { + require.NoError(t, userRepo.CreateUser(context.Background(), &user)) + } + + user := &datastore.User{ + UID: user.UID, + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + } + + if i > 0 && tc.isDuplicateEmail { + err := userRepo.CreateUser(context.Background(), user) + require.Error(t, err) + require.ErrorIs(t, err, 
datastore.ErrDuplicateEmail) + } + } + }) + } +} + +func TestCountUsers(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + count := 10 + + for i := 0; i < count; i++ { + u := &datastore.User{ + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), + } + + err := userRepo.CreateUser(context.Background(), u) + require.NoError(t, err) + } + + userCount, err := userRepo.CountUsers(context.Background()) + + require.NoError(t, err) + require.Equal(t, int64(count), userCount) +} + +func Test_FindUserByEmail(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + + user := generateUser(t) + + _, err := userRepo.FindUserByEmail(context.Background(), user.Email) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrUserNotFound)) + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + + newUser, err := userRepo.FindUserByEmail(context.Background(), user.Email) + require.NoError(t, err) + + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) + + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + + require.Equal(t, user, newUser) +} + +func Test_FindUserByID(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + + user := generateUser(t) + + _, err := userRepo.FindUserByID(context.Background(), user.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrUserNotFound)) + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + newUser, err := userRepo.FindUserByID(context.Background(), user.UID) + require.NoError(t, err) + + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) + + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + + require.Equal(t, user, newUser) +} + +func Test_FindUserByToken(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + + user := generateUser(t) + token := "fd7fidyfhdjhfdjhfjdh" + + user.ResetPasswordToken = token + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + newUser, err := userRepo.FindUserByToken(context.Background(), token) + require.NoError(t, err) + + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) + + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), 
newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + + require.Equal(t, user, newUser) +} + +func Test_FindUserByEmailVerificationToken(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + + user := generateUser(t) + token := "fd7fidyfhdjhfdjhfjdh" + + user.EmailVerificationToken = token + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + newUser, err := userRepo.FindUserByEmailVerificationToken(context.Background(), token) + require.NoError(t, err) + + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) + + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + + require.Equal(t, user, newUser) +} + +func Test_UpdateUser(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + userRepo := NewUserRepo(db) + user := generateUser(t) + + require.NoError(t, userRepo.CreateUser(context.Background(), user)) + + updatedUser := &datastore.User{ + UID: user.UID, + FirstName: fmt.Sprintf("test%s", ulid.Make().String()), + LastName: fmt.Sprintf("test%s", ulid.Make().String()), + Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), + EmailVerified: !user.EmailVerified, + Password: ulid.Make().String(), + ResetPasswordToken: fmt.Sprintf("%s-reset-password-token", ulid.Make().String()), + EmailVerificationToken: fmt.Sprintf("%s-verification-token", ulid.Make().String()), + ResetPasswordExpiresAt: time.Now().Add(time.Hour).UTC(), + EmailVerificationExpiresAt: time.Now().Add(time.Hour).UTC(), + } + + require.NoError(t, userRepo.UpdateUser(context.Background(), updatedUser)) + + dbUser, err := userRepo.FindUserByID(context.Background(), user.UID) + require.NoError(t, err) + + require.Equal(t, dbUser.UID, updatedUser.UID) + + dbUser.CreatedAt = time.Time{} + dbUser.UpdatedAt = time.Time{} + + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), dbUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), dbUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + + updatedUser.EmailVerificationExpiresAt, updatedUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + dbUser.EmailVerificationExpiresAt, dbUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + + require.Equal(t, updatedUser, dbUser) +} + +func generateUser(t *testing.T) *datastore.User { + return &datastore.User{ + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), + EmailVerified: true, + Password: "dvsdvdkhjskuis", + ResetPasswordToken: "dvsdvdkhjskuis", + EmailVerificationToken: "v878678768686868", + ResetPasswordExpiresAt: time.Now(), + EmailVerificationExpiresAt: time.Now(), + } +} From ff89e066de5e875b40505995552121c07b48136d Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Thu, 14 Nov 2024 15:08:41 +0100 Subject: [PATCH 4/7] chore: add schema constraints --- 
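Note: this commit swaps the composite unique constraints of the form unique (col, deleted_at) for partial unique indexes filtered on deleted_at IS NULL. In SQLite, NULL values are considered distinct from one another in unique indexes, so the old constraints never actually stopped two live rows (both with deleted_at NULL) from sharing the same key; a partial index enforces uniqueness among live rows only, while soft-deleted rows keep their old values. A minimal sketch of the pattern on a hypothetical table t, not part of this migration:

    create table if not exists t
    (
        id         TEXT not null primary key,
        email      TEXT not null,
        deleted_at TEXT
    );

    -- uniqueness is enforced only while the row is live
    CREATE UNIQUE INDEX if not exists idx_t_email_live
        ON t(email)
        WHERE deleted_at IS NULL;

    -- soft-deleting the row frees the value for reuse
    UPDATE t SET deleted_at = strftime('%Y-%m-%dT%H:%M:%fZ') WHERE id = 'some-id';

The same pattern is applied below to users, api_keys, organisation_invites, organisation_members, portal_links, sources, and the new per-project name indexes on endpoints, event_types, projects, and subscriptions.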
sql/sqlite3/1731413863.sql | 196 ++++++++++++++++++++----------------- 1 file changed, 107 insertions(+), 89 deletions(-) diff --git a/sql/sqlite3/1731413863.sql b/sql/sqlite3/1731413863.sql index 7d6ca8eb77..19ccdd9005 100644 --- a/sql/sqlite3/1731413863.sql +++ b/sql/sqlite3/1731413863.sql @@ -1,4 +1,5 @@ -- +migrate Up +-- configurations create table if not exists configurations ( id TEXT not null, @@ -27,6 +28,7 @@ create table if not exists configurations cb_consecutive_failure_threshold INTEGER default 10 not null ); +-- event endpoints create table if not exists events_endpoints ( event_id TEXT not null, @@ -47,12 +49,14 @@ create index if not exists events_endpoints_temp_event_id_idx create unique index if not exists idx_uq_constraint_events_endpoints_event_id_endpoint_id on events_endpoints (event_id, endpoint_id); +-- migrations create table if not exists gorp_migrations ( id TEXT not null primary key, applied_at TEXT ); +-- project configurations create table if not exists project_configurations ( id TEXT not null primary key, @@ -80,6 +84,7 @@ create table if not exists project_configurations deleted_at TEXT ); +-- source verifiers create table if not exists source_verifiers ( id TEXT not null primary key, @@ -98,6 +103,7 @@ create table if not exists source_verifiers deleted_at TEXT ); +-- token bucket create table if not exists token_bucket ( key TEXT not null primary key, @@ -108,6 +114,7 @@ create table if not exists token_bucket expires_at TEXT not null ); +-- users create table if not exists users ( id TEXT not null primary key, @@ -123,11 +130,14 @@ create table if not exists users auth_type TEXT default 'local' not null, created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), - deleted_at TEXT, - constraint users_email_key - unique (email, deleted_at) + deleted_at TEXT ); +CREATE UNIQUE INDEX if not exists idx_unique_email_deleted_at + ON users(email) + WHERE deleted_at IS NULL; + +-- organisations create table if not exists organisations ( id TEXT not null primary key, @@ -141,10 +151,11 @@ create table if not exists organisations FOREIGN KEY(owner_id) REFERENCES users(id) ); -create unique index if not exists organisations_custom_domain +create unique index if not exists idx_organisations_custom_domain_deleted_at on organisations (custom_domain, assigned_domain) where (deleted_at IS NULL); +--projects create table if not exists projects ( id TEXT not null primary key, @@ -162,7 +173,11 @@ create table if not exists projects FOREIGN KEY(project_configuration_id) REFERENCES project_configurations(id) ); --- todo(raymond): deprecate me +create unique index if not exists idx_name_organisation_id_deleted_at + on projects (organisation_id, name) + where (deleted_at IS NULL); + +-- applications todo(raymond): deprecate me create table if not exists applications ( id TEXT not null primary key, @@ -176,7 +191,7 @@ create table if not exists applications FOREIGN KEY(project_id) REFERENCES projects(id) ); --- todo(raymond): deprecate me +-- devices todo(raymond): deprecate me create table if not exists devices ( id TEXT not null primary key, @@ -190,6 +205,7 @@ create table if not exists devices FOREIGN KEY(project_id) REFERENCES projects(id) ); +-- endpoints create table if not exists endpoints ( id TEXT not null primary key, @@ -216,6 +232,23 @@ create table if not exists endpoints FOREIGN KEY(project_id) REFERENCES projects(id) ); +CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at + 
ON endpoints(name, project_id) + WHERE deleted_at IS NULL; + +create index if not exists idx_endpoints_name_key + on endpoints (name); + +create index if not exists idx_endpoints_app_id_key + on endpoints (app_id); + +create index if not exists idx_endpoints_owner_id_key + on endpoints (owner_id); + +create index if not exists idx_endpoints_project_id_key + on endpoints (project_id); + +-- api keys create table if not exists api_keys ( id TEXT not null primary key, @@ -232,24 +265,19 @@ create table if not exists api_keys updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), expires_at TEXT, deleted_at TEXT, - constraint api_keys_mask_id_key unique (mask_id, deleted_at), FOREIGN KEY(role_project) REFERENCES projects(id), FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), FOREIGN KEY(user_id) REFERENCES users(id) ); +CREATE UNIQUE INDEX if not exists idx_mask_id_deleted_at + ON api_keys(mask_id) + WHERE deleted_at IS NULL; + create index if not exists idx_api_keys_mask_id on api_keys (mask_id); -create index if not exists idx_endpoints_app_id_key - on endpoints (app_id); - -create index if not exists idx_endpoints_owner_id_key - on endpoints (owner_id); - -create index if not exists idx_endpoints_project_id_key - on endpoints (project_id); - +-- event types create table if not exists event_types ( id TEXT not null primary key, @@ -263,6 +291,10 @@ create table if not exists event_types FOREIGN KEY(project_id) REFERENCES projects(id) ); +CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at + ON event_types(name, project_id) + WHERE deprecated_at IS NULL; + create index if not exists idx_event_types_category on event_types (category); @@ -285,6 +317,8 @@ create index if not exists idx_event_types_name_not_deprecated on event_types (name) where (deprecated_at IS NULL); + +-- jobs create table if not exists jobs ( id TEXT not null primary key, @@ -300,6 +334,7 @@ create table if not exists jobs FOREIGN KEY(project_id) REFERENCES projects(id) ); +-- meta events create table if not exists meta_events ( id TEXT not null primary key, @@ -314,6 +349,7 @@ create table if not exists meta_events FOREIGN KEY(project_id) REFERENCES projects(id) ); +-- organisation invites create table if not exists organisation_invites ( id TEXT not null primary key, @@ -328,18 +364,22 @@ create table if not exists organisation_invites updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), expires_at TEXT not null, deleted_at TEXT, - constraint organisation_invites_token_key unique (token, deleted_at), FOREIGN KEY(role_project) REFERENCES projects(id), FOREIGN KEY(organisation_id) REFERENCES organisations(id), FOREIGN KEY(role_endpoint) REFERENCES endpoints(id) ); +CREATE UNIQUE INDEX if not exists idx_token_organisation_id_deleted_at + ON organisation_invites(token, organisation_id) + WHERE deleted_at IS NULL; + create index if not exists idx_organisation_invites_token_key on organisation_invites (token); create unique index if not exists organisation_invites_invitee_email on organisation_invites (organisation_id, invitee_email, deleted_at); +-- organisation members create table if not exists organisation_members ( id TEXT not null primary key, @@ -351,14 +391,16 @@ create table if not exists organisation_members created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT, - constraint organisation_members_user_id_org_id_key - unique (organisation_id, user_id, deleted_at), FOREIGN 
KEY(role_project) REFERENCES projects(id), FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), FOREIGN KEY(organisation_id) REFERENCES organisations(id), FOREIGN KEY(user_id) REFERENCES users(id) ); +CREATE UNIQUE INDEX if not exists idx_organisation_id_user_id_deleted_at + ON organisation_members(organisation_id, user_id) + WHERE deleted_at IS NULL; + create index if not exists idx_organisation_members_deleted_at_key on organisation_members (deleted_at); @@ -380,11 +422,13 @@ create table if not exists portal_links deleted_at TEXT, owner_id TEXT, can_manage_endpoint BOOLEAN default false, - constraint portal_links_token - unique (token, deleted_at), FOREIGN KEY(project_id) REFERENCES projects(id) ); +CREATE UNIQUE INDEX if not exists idx_token_deleted_at + ON portal_links(token) + WHERE deleted_at IS NULL; + create index if not exists idx_portal_links_owner_id_key on portal_links (owner_id); @@ -408,6 +452,8 @@ create index if not exists idx_portal_links_endpoints_enpdoint_id create index if not exists idx_portal_links_endpoints_portal_link_id on portal_links_endpoints (portal_link_id); + +-- sources create table if not exists sources ( id TEXT not null primary key, @@ -428,11 +474,24 @@ create table if not exists sources header_function TEXT, created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), - constraint sources_mask_id unique (mask_id, deleted_at), FOREIGN KEY(project_id) REFERENCES projects(id), FOREIGN KEY(source_verifier_id) REFERENCES source_verifiers(id) ); +CREATE UNIQUE INDEX if not exists idx_mask_id_project_id_deleted_at + ON sources(mask_id, project_id) + WHERE deleted_at IS NULL; + +create index if not exists idx_sources_mask_id + on sources (mask_id); + +create index if not exists idx_sources_project_id + on sources (project_id); + +create index if not exists idx_sources_source_verifier_id + on sources (source_verifier_id); + +-- events create table if not exists events ( id TEXT not null primary key, @@ -484,54 +543,7 @@ create index if not exists idx_project_id_on_not_deleted on events (project_id) where (deleted_at IS NULL); -create table if not exists events_search -( - id TEXT not null primary key, - event_type TEXT not null, - endpoints TEXT, - project_id TEXT not null, - source_id TEXT, - headers TEXT, - raw TEXT not null, - data TEXT not null, - url_query_params TEXT, - idempotency_key TEXT, - is_duplicate_event BOOLEAN default false, - search_token TEXT, - created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), - updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), - deleted_at TEXT, - FOREIGN KEY(source_id) REFERENCES sources(id), - FOREIGN KEY(project_id) REFERENCES projects(id) -); - -create index if not exists idx_events_search_created_at_key - on events_search (created_at); - -create index if not exists idx_events_search_deleted_at_key - on events_search (deleted_at); - -create index if not exists idx_events_search_project_id_deleted_at_key - on events_search (project_id, deleted_at); - -create index if not exists idx_events_search_project_id_key - on events_search (project_id); - -create index if not exists idx_events_search_source_id_key - on events_search (source_id); - -create index if not exists idx_events_search_token_key - on events_search (search_token); - -create index if not exists idx_sources_mask_id - on sources (mask_id); - -create index if not exists idx_sources_project_id - on sources (project_id); - -create index if not exists 
idx_sources_source_verifier_id - on sources (source_verifier_id); - +-- subscriptions create table if not exists subscriptions ( id TEXT not null primary key, @@ -562,6 +574,30 @@ create table if not exists subscriptions FOREIGN KEY(project_id) REFERENCES projects(id) ); +CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at + ON subscriptions(name, project_id) + WHERE deleted_at IS NULL; + +create index if not exists idx_subscriptions_filter_config_event_types_key + on subscriptions (filter_config_event_types); + +create index if not exists idx_subscriptions_id_deleted_at + on subscriptions (id, deleted_at) + where (deleted_at IS NOT NULL); + +create index if not exists idx_subscriptions_name_key + on subscriptions (name) + where (deleted_at IS NULL); + +create index if not exists idx_subscriptions_updated_at + on subscriptions (updated_at) + where (deleted_at IS NULL); + +create index if not exists idx_subscriptions_updated_at_id_project_id + on subscriptions (updated_at, id, project_id) + where (deleted_at IS NULL); + +-- event deliveries create table if not exists event_deliveries ( id TEXT not null primary key, @@ -591,6 +627,7 @@ create table if not exists event_deliveries FOREIGN KEY(project_id) REFERENCES projects(id) ); +-- delivery attempts create table if not exists delivery_attempts ( id TEXT not null primary key, @@ -664,25 +701,6 @@ create index if not exists idx_event_deliveries_status create index if not exists idx_event_deliveries_status_key on event_deliveries (status); -create index if not exists idx_subscriptions_filter_config_event_types_key - on subscriptions (filter_config_event_types); - -create index if not exists idx_subscriptions_id_deleted_at - on subscriptions (id, deleted_at) - where (deleted_at IS NOT NULL); - -create index if not exists idx_subscriptions_name_key - on subscriptions (name) - where (deleted_at IS NULL); - -create index if not exists idx_subscriptions_updated_at - on subscriptions (updated_at) - where (deleted_at IS NULL); - -create index if not exists idx_subscriptions_updated_at_id_project_id - on subscriptions (updated_at, id, project_id) - where (deleted_at IS NULL); - -- +migrate Down drop table if exists configurations; drop table if exists events_endpoints; From aee13bd693861671047443acbc98a54350d38153 Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Thu, 14 Nov 2024 15:09:51 +0100 Subject: [PATCH 5/7] feat: fixed user repo --- database/sqlite3/sqlite3.go | 43 +++- database/sqlite3/sqlite_test.go | 2 +- database/sqlite3/users.go | 99 ++++++--- database/sqlite3/users_test.go | 354 ++++++++++++++++---------------- 4 files changed, 289 insertions(+), 209 deletions(-) diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 1bacc629a6..723fbd97e2 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -7,17 +7,21 @@ import ( "fmt" "github.com/frain-dev/convoy/database/hooks" "github.com/frain-dev/convoy/pkg/log" + "gopkg.in/guregu/null.v4" "io" + "time" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" ) -const pkgName = "sqlite3" - type DbCtxKey string -const TransactionCtx DbCtxKey = "transaction" +const ( + pkgName = "sqlite3" + Rfc3339Milli = "2006-01-02T15:04:05.000Z" + TransactionCtx DbCtxKey = "transaction" +) type Sqlite struct { dbx *sqlx.DB @@ -103,3 +107,36 @@ func GetTx(ctx context.Context, db *sqlx.DB) (*sqlx.Tx, bool, error) { return tx, isWrapped, nil } + +func timeAsString(t time.Time) string { + return t.Format(Rfc3339Milli) +} + +func nullTimeAsString(t 
null.Time) *string { + strVal := "" + if t.Valid { + strVal = t.Time.Format(Rfc3339Milli) + return &strVal + } + return &strVal +} + +func asTime(ts string) time.Time { + t, err := time.Parse(Rfc3339Milli, ts) + if err != nil { + return time.Now() + } + return t +} + +func asNullTime(ts *string) null.Time { + if ts == nil { + return null.NewTime(time.Time{}, false) + } + + t, err := time.Parse(Rfc3339Milli, *ts) + if err != nil { + return null.NewTime(time.Now(), false) + } + return null.NewTime(t, true) +} diff --git a/database/sqlite3/sqlite_test.go b/database/sqlite3/sqlite_test.go index 5724da0ed6..651b5becec 100644 --- a/database/sqlite3/sqlite_test.go +++ b/database/sqlite3/sqlite_test.go @@ -32,7 +32,7 @@ func getDB(t *testing.T) (database.Database, func()) { dbHooks := hooks.Init() dbHooks.RegisterHook(datastore.EndpointCreated, func(data interface{}, changelog interface{}) {}) - _db, err = NewDB("file::memory:?cache=shared", log.NewLogger(os.Stdout)) + _db, err = NewDB("test.db?cache=shared", log.NewLogger(os.Stdout)) require.NoError(t, err) // run migrations diff --git a/database/sqlite3/users.go b/database/sqlite3/users.go index 768e2ac97d..bf66153338 100644 --- a/database/sqlite3/users.go +++ b/database/sqlite3/users.go @@ -5,34 +5,35 @@ import ( "database/sql" "errors" "fmt" - "strings" - "github.com/frain-dev/convoy/database" "github.com/frain-dev/convoy/datastore" "github.com/jmoiron/sqlx" + "strings" + "time" ) const ( createUser = ` INSERT INTO users ( - id,first_name,last_name,email,password, - email_verified,reset_password_token, email_verification_token, - reset_password_expires_at,email_verification_expires_at, auth_type) + id,first_name,last_name,email,password, email_verified, + reset_password_token, email_verification_token, reset_password_expires_at, + email_verification_expires_at, auth_type) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) ` updateUser = ` UPDATE users SET - first_name = $2, - last_name=$3, - email=$4, - password=$5, - email_verified=$6, - reset_password_token=$7, - email_verification_token=$8, - reset_password_expires_at=$9, - email_verification_expires_at=$10 - WHERE id = $1 AND deleted_at IS NULL; + first_name=$1, + last_name=$2, + email=$3, + password=$4, + email_verified=$5, + reset_password_token=$6, + email_verification_token=$7, + reset_password_expires_at=$8, + email_verification_expires_at=$9, + updated_at=$10 + WHERE id = $11 AND deleted_at IS NULL; ` fetchUsers = ` @@ -75,7 +76,7 @@ func (u *userRepo) CreateUser(ctx context.Context, user *datastore.User) error { user.AuthType, ) if err != nil { - if strings.Contains(err.Error(), "duplicate") { + if strings.Contains(err.Error(), "constraint") { return datastore.ErrDuplicateEmail } return err @@ -94,9 +95,9 @@ func (u *userRepo) CreateUser(ctx context.Context, user *datastore.User) error { } func (u *userRepo) UpdateUser(ctx context.Context, user *datastore.User) error { - result, err := u.db.Exec( - updateUser, user.UID, user.FirstName, user.LastName, user.Email, user.Password, user.EmailVerified, user.ResetPasswordToken, - user.EmailVerificationToken, user.ResetPasswordExpiresAt, user.EmailVerificationExpiresAt, + result, err := u.db.ExecContext(ctx, + updateUser, user.FirstName, user.LastName, user.Email, user.Password, user.EmailVerified, user.ResetPasswordToken, + user.EmailVerificationToken, user.ResetPasswordExpiresAt, user.EmailVerificationExpiresAt, time.Now(), user.UID, ) if err != nil { return err @@ -115,7 +116,7 @@ func (u *userRepo) UpdateUser(ctx context.Context, user 
*datastore.User) error { } func (u *userRepo) FindUserByEmail(ctx context.Context, email string) (*datastore.User, error) { - user := &datastore.User{} + user := &dbUser{} err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND email = $1;", fetchUsers), email).StructScan(user) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -124,11 +125,11 @@ func (u *userRepo) FindUserByEmail(ctx context.Context, email string) (*datastor return nil, err } - return user, nil + return user.toDatastoreUser(), nil } func (u *userRepo) FindUserByID(ctx context.Context, id string) (*datastore.User, error) { - user := &datastore.User{} + user := &dbUser{} err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND id = $1;", fetchUsers), id).StructScan(user) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -137,11 +138,11 @@ func (u *userRepo) FindUserByID(ctx context.Context, id string) (*datastore.User return nil, err } - return user, nil + return user.toDatastoreUser(), nil } func (u *userRepo) FindUserByToken(ctx context.Context, token string) (*datastore.User, error) { - user := &datastore.User{} + user := &dbUser{} err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND reset_password_token = $1;", fetchUsers), token).StructScan(user) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -150,11 +151,11 @@ func (u *userRepo) FindUserByToken(ctx context.Context, token string) (*datastor return nil, err } - return user, nil + return user.toDatastoreUser(), nil } func (u *userRepo) FindUserByEmailVerificationToken(ctx context.Context, token string) (*datastore.User, error) { - user := &datastore.User{} + user := &dbUser{} err := u.db.QueryRowxContext(ctx, fmt.Sprintf("%s AND email_verification_token = $1;", fetchUsers), token).StructScan(user) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -163,15 +164,51 @@ func (u *userRepo) FindUserByEmailVerificationToken(ctx context.Context, token s return nil, err } - return user, nil + return user.toDatastoreUser(), nil } -func (o *userRepo) CountUsers(ctx context.Context) (int64, error) { - var count int64 - err := o.db.GetContext(ctx, &count, countUsers) +func (u *userRepo) CountUsers(ctx context.Context) (int64, error) { + var userCount int64 + err := u.db.GetContext(ctx, &userCount, countUsers) if err != nil { return 0, err } - return count, nil + return userCount, nil +} + +type dbUser struct { + UID string `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string `db:"email"` + EmailVerified bool `db:"email_verified"` + Password string `db:"password"` + ResetPasswordToken string `db:"reset_password_token"` + EmailVerificationToken string `db:"email_verification_token"` + CreatedAt string `db:"created_at,omitempty"` + UpdatedAt string `db:"updated_at,omitempty"` + DeletedAt *string `db:"deleted_at"` + ResetPasswordExpiresAt string `db:"reset_password_expires_at,omitempty"` + EmailVerificationExpiresAt string `db:"email_verification_expires_at,omitempty"` + AuthType string `db:"auth_type"` +} + +func (uu *dbUser) toDatastoreUser() *datastore.User { + return &datastore.User{ + UID: uu.UID, + FirstName: uu.FirstName, + LastName: uu.LastName, + Email: uu.Email, + AuthType: uu.AuthType, + EmailVerified: uu.EmailVerified, + Password: uu.Password, + ResetPasswordToken: uu.ResetPasswordToken, + EmailVerificationToken: uu.EmailVerificationToken, + CreatedAt: asTime(uu.CreatedAt), + UpdatedAt: asTime(uu.UpdatedAt), + DeletedAt: asNullTime(uu.DeletedAt), + ResetPasswordExpiresAt: asTime(uu.ResetPasswordExpiresAt), + 
EmailVerificationExpiresAt: asTime(uu.EmailVerificationExpiresAt), + } } diff --git a/database/sqlite3/users_test.go b/database/sqlite3/users_test.go index afb95f24e2..aedb271a79 100644 --- a/database/sqlite3/users_test.go +++ b/database/sqlite3/users_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "gopkg.in/guregu/null.v4" "testing" "time" @@ -15,31 +16,39 @@ import ( "github.com/stretchr/testify/require" ) -func Test_CreateUser(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() - - tt := []struct { - name string - users []datastore.User - isDuplicateEmail bool - }{ - { - name: "create user", - users: []datastore.User{ - { +func TestUsers(t *testing.T) { + t.Run("Test_CreateUser", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + tt := []struct { + name string + user datastore.User + fn func(datastore.UserRepository) + isDuplicateEmail bool + }{ + { + name: "create user", + user: datastore.User{ UID: ulid.Make().String(), FirstName: "test", LastName: "test", Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), }, }, - }, - { - name: "cannot create user with existing email", - isDuplicateEmail: true, - users: []datastore.User{ - { + { + name: "cannot create user with existing email", + isDuplicateEmail: true, + fn: func(ur datastore.UserRepository) { + err := ur.CreateUser(context.Background(), &datastore.User{ + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: "test@test.com", + }) + require.NoError(t, err) + }, + user: datastore.User{ UID: ulid.Make().String(), FirstName: "test", LastName: "test", @@ -51,232 +60,229 @@ func Test_CreateUser(t *testing.T) { ResetPasswordExpiresAt: time.Now(), EmailVerificationExpiresAt: time.Now(), }, - { - UID: ulid.Make().String(), - FirstName: "test", - LastName: "test", - Email: "test@test.com", - }, }, - }, - } + } - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - userRepo := NewUserRepo(db) + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + repo := NewUserRepo(db) - for i, user := range tc.users { - if i == 0 { - require.NoError(t, userRepo.CreateUser(context.Background(), &user)) + if tc.fn != nil { + tc.fn(repo) } - user := &datastore.User{ - UID: user.UID, - FirstName: user.FirstName, - LastName: user.LastName, - Email: user.Email, + u := &datastore.User{ + UID: tc.user.UID, + FirstName: tc.user.FirstName, + LastName: tc.user.LastName, + Email: tc.user.Email, } - if i > 0 && tc.isDuplicateEmail { - err := userRepo.CreateUser(context.Background(), user) + if tc.isDuplicateEmail { + err := repo.CreateUser(context.Background(), u) require.Error(t, err) require.ErrorIs(t, err, datastore.ErrDuplicateEmail) } - } - }) - } -} + }) + } + }) -func TestCountUsers(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("TestCountUsers", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) - count := 10 + repo := NewUserRepo(db) + c := 10 - for i := 0; i < count; i++ { - u := &datastore.User{ - UID: ulid.Make().String(), - FirstName: "test", - LastName: "test", - Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), - } + for i := 0; i < c; i++ { + u := &datastore.User{ + UID: ulid.Make().String(), + FirstName: "test", + LastName: "test", + Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), + } - err := userRepo.CreateUser(context.Background(), u) - require.NoError(t, err) - } + err := repo.CreateUser(context.Background(), u) + require.NoError(t, err) + } - userCount, err := 
userRepo.CountUsers(context.Background()) + userCount, err := repo.CountUsers(context.Background()) - require.NoError(t, err) - require.Equal(t, int64(count), userCount) -} + require.NoError(t, err) + require.Equal(t, int64(c), userCount) + }) -func Test_FindUserByEmail(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("Test_FindUserByEmail", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) + repo := NewUserRepo(db) - user := generateUser(t) + user := generateUser(t) - _, err := userRepo.FindUserByEmail(context.Background(), user.Email) - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrUserNotFound)) + _, err := repo.FindUserByEmail(context.Background(), user.Email) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrUserNotFound)) - require.NoError(t, userRepo.CreateUser(context.Background(), user)) + require.NoError(t, repo.CreateUser(context.Background(), user)) - newUser, err := userRepo.FindUserByEmail(context.Background(), user.Email) - require.NoError(t, err) + newUser, err := repo.FindUserByEmail(context.Background(), user.Email) + require.NoError(t, err) - require.NotEmpty(t, newUser.CreatedAt) - require.NotEmpty(t, newUser.UpdatedAt) + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) - newUser.CreatedAt = time.Time{} - newUser.UpdatedAt = time.Time{} + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} - require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) - require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) - user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} - newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} - require.Equal(t, user, newUser) -} + require.Equal(t, user, newUser) + }) -func Test_FindUserByID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("Test_FindUserByID", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) + repo := NewUserRepo(db) - user := generateUser(t) + user := generateUser(t) - _, err := userRepo.FindUserByID(context.Background(), user.UID) - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrUserNotFound)) + _, err := repo.FindUserByID(context.Background(), user.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrUserNotFound)) - require.NoError(t, userRepo.CreateUser(context.Background(), user)) - newUser, err := userRepo.FindUserByID(context.Background(), user.UID) - require.NoError(t, err) + require.NoError(t, repo.CreateUser(context.Background(), user)) + newUser, err := repo.FindUserByID(context.Background(), user.UID) + require.NoError(t, err) - require.NotEmpty(t, newUser.CreatedAt) - require.NotEmpty(t, newUser.UpdatedAt) + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) - newUser.CreatedAt = time.Time{} - newUser.UpdatedAt = time.Time{} 
+ newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + newUser.DeletedAt = null.NewTime(time.Time{}, false) - require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) - require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) - user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} - newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} - require.Equal(t, user, newUser) -} + require.Equal(t, user, newUser) + }) -func Test_FindUserByToken(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("Test_FindUserByToken", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) + repo := NewUserRepo(db) - user := generateUser(t) - token := "fd7fidyfhdjhfdjhfjdh" + user := generateUser(t) + token := "fd7fidyfhdjhfdjhfjdh" - user.ResetPasswordToken = token + user.ResetPasswordToken = token - require.NoError(t, userRepo.CreateUser(context.Background(), user)) - newUser, err := userRepo.FindUserByToken(context.Background(), token) - require.NoError(t, err) + require.NoError(t, repo.CreateUser(context.Background(), user)) + newUser, err := repo.FindUserByToken(context.Background(), token) + require.NoError(t, err) - require.NotEmpty(t, newUser.CreatedAt) - require.NotEmpty(t, newUser.UpdatedAt) + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) - newUser.CreatedAt = time.Time{} - newUser.UpdatedAt = time.Time{} + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + newUser.DeletedAt = null.NewTime(time.Time{}, false) - require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) - require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) - user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} - newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} - require.Equal(t, user, newUser) -} + require.Equal(t, user, newUser) + }) -func Test_FindUserByEmailVerificationToken(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("Test_FindUserByEmailVerificationToken", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) + repo := NewUserRepo(db) - user := generateUser(t) - token := "fd7fidyfhdjhfdjhfjdh" + user := generateUser(t) + token := "fd7fidyfhdjhfdjhfjdh" - user.EmailVerificationToken = token + user.EmailVerificationToken = 
token - require.NoError(t, userRepo.CreateUser(context.Background(), user)) - newUser, err := userRepo.FindUserByEmailVerificationToken(context.Background(), token) - require.NoError(t, err) + require.NoError(t, repo.CreateUser(context.Background(), user)) + newUser, err := repo.FindUserByEmailVerificationToken(context.Background(), token) + require.NoError(t, err) - require.NotEmpty(t, newUser.CreatedAt) - require.NotEmpty(t, newUser.UpdatedAt) + require.NotEmpty(t, newUser.CreatedAt) + require.NotEmpty(t, newUser.UpdatedAt) - newUser.CreatedAt = time.Time{} - newUser.UpdatedAt = time.Time{} + newUser.CreatedAt = time.Time{} + newUser.UpdatedAt = time.Time{} + newUser.DeletedAt = null.NewTime(time.Time{}, false) - require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) - require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), newUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), newUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) - user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} - newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + user.EmailVerificationExpiresAt, user.ResetPasswordExpiresAt = time.Time{}, time.Time{} + newUser.EmailVerificationExpiresAt, newUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} - require.Equal(t, user, newUser) -} + require.Equal(t, user, newUser) + }) -func Test_UpdateUser(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("Test_UpdateUser", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - userRepo := NewUserRepo(db) - user := generateUser(t) + repo := NewUserRepo(db) + user := generateUser(t) - require.NoError(t, userRepo.CreateUser(context.Background(), user)) + require.NoError(t, repo.CreateUser(context.Background(), user)) - updatedUser := &datastore.User{ - UID: user.UID, - FirstName: fmt.Sprintf("test%s", ulid.Make().String()), - LastName: fmt.Sprintf("test%s", ulid.Make().String()), - Email: fmt.Sprintf("%s@test.com", ulid.Make().String()), - EmailVerified: !user.EmailVerified, - Password: ulid.Make().String(), - ResetPasswordToken: fmt.Sprintf("%s-reset-password-token", ulid.Make().String()), - EmailVerificationToken: fmt.Sprintf("%s-verification-token", ulid.Make().String()), - ResetPasswordExpiresAt: time.Now().Add(time.Hour).UTC(), - EmailVerificationExpiresAt: time.Now().Add(time.Hour).UTC(), - } + updatedUser := &datastore.User{ + UID: user.UID, + FirstName: fmt.Sprintf("test%s", ulid.Make().String()), + LastName: fmt.Sprintf("test%s", ulid.Make().String()), + Email: user.Email, + EmailVerified: !user.EmailVerified, + Password: ulid.Make().String(), + ResetPasswordToken: fmt.Sprintf("%s-reset-password-token", ulid.Make().String()), + EmailVerificationToken: fmt.Sprintf("%s-verification-token", ulid.Make().String()), + ResetPasswordExpiresAt: time.Now().Add(time.Hour).UTC(), + EmailVerificationExpiresAt: time.Now().Add(time.Hour).UTC(), + } - require.NoError(t, userRepo.UpdateUser(context.Background(), updatedUser)) + require.NoError(t, repo.UpdateUser(context.Background(), updatedUser)) - dbUser, err := userRepo.FindUserByID(context.Background(), user.UID) - require.NoError(t, err) + userByID, err := repo.FindUserByID(context.Background(), user.UID) + require.NoError(t, 
err) - require.Equal(t, dbUser.UID, updatedUser.UID) + require.Equal(t, userByID.UID, updatedUser.UID) - dbUser.CreatedAt = time.Time{} - dbUser.UpdatedAt = time.Time{} + userByID.CreatedAt = time.Time{} + userByID.UpdatedAt = time.Time{} + userByID.DeletedAt = null.NewTime(time.Time{}, false) - require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), dbUser.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) - require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), dbUser.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.EmailVerificationExpiresAt.Unix(), userByID.EmailVerificationExpiresAt.Unix(), float64(time.Hour)) + require.InDelta(t, user.ResetPasswordExpiresAt.Unix(), userByID.ResetPasswordExpiresAt.Unix(), float64(time.Hour)) - updatedUser.EmailVerificationExpiresAt, updatedUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} - dbUser.EmailVerificationExpiresAt, dbUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + updatedUser.EmailVerificationExpiresAt, updatedUser.ResetPasswordExpiresAt = time.Time{}, time.Time{} + userByID.EmailVerificationExpiresAt, userByID.ResetPasswordExpiresAt = time.Time{}, time.Time{} - require.Equal(t, updatedUser, dbUser) + require.Equal(t, updatedUser, userByID) + }) } func generateUser(t *testing.T) *datastore.User { + t.Helper() return &datastore.User{ UID: ulid.Make().String(), FirstName: "test", From 0ec67a908f67b3e737a2dd15ab77c43d2e3b614f Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Thu, 14 Nov 2024 15:53:19 +0100 Subject: [PATCH 6/7] feat: fixed subscription repo --- database/sqlite3/subscription.go | 124 +++-- database/sqlite3/subscription_test.go | 687 +++++++++++++------------- sql/sqlite3/1731413863.sql | 50 +- 3 files changed, 454 insertions(+), 407 deletions(-) diff --git a/database/sqlite3/subscription.go b/database/sqlite3/subscription.go index e727b517de..699153da45 100644 --- a/database/sqlite3/subscription.go +++ b/database/sqlite3/subscription.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "gopkg.in/guregu/null.v4" "math" "time" @@ -33,23 +34,23 @@ const ( updateSubscription = ` UPDATE subscriptions SET - name=$3, - endpoint_id=$4, - source_id=$5, - alert_config_count=$6, - alert_config_threshold=$7, - retry_config_type=$8, - retry_config_duration=$9, - retry_config_retry_count=$10, - filter_config_event_types=$11, - filter_config_filter_headers=$12, - filter_config_filter_body=$13, - filter_config_filter_is_flattened=$14, - rate_limit_config_count=$15, - rate_limit_config_duration=$16, - function=$17, - updated_at=now() - WHERE id = $1 AND project_id = $2 + name=$1, + endpoint_id=$2, + source_id=$3, + alert_config_count=$4, + alert_config_threshold=$5, + retry_config_type=$6, + retry_config_duration=$7, + retry_config_retry_count=$8, + filter_config_event_types=$9, + filter_config_filter_headers=$10, + filter_config_filter_body=$11, + filter_config_filter_is_flattened=$12, + rate_limit_config_count=$13, + rate_limit_config_duration=$14, + function=$15, + updated_at=$16 + WHERE id = $17 AND project_id = $18 AND deleted_at IS NULL; ` @@ -252,8 +253,8 @@ const ( deleteSubscriptions = ` UPDATE subscriptions SET - deleted_at = NOW() - WHERE id = $1 AND project_id = $2; + deleted_at = $1 + WHERE id = $2 AND project_id = $3; ` ) @@ -320,13 +321,13 @@ func (s *subscriptionRepo) LoadAllSubscriptionConfig(ctx context.Context, projec func() { defer closeWithError(rows) for rows.Next() { - sub := datastore.Subscription{} + sub := dbSubscription{} if err = rows.StructScan(&sub); err != nil 
{ return } nullifyEmptyConfig(&sub) - subs[counter] = sub + subs[counter] = *sub.toDatastoreSubscription() counter++ } @@ -411,13 +412,13 @@ func (s *subscriptionRepo) fetchChangedSubscriptionConfig(ctx context.Context, c func() { defer closeWithError(rows) for rows.Next() { - sub := datastore.Subscription{} + sub := dbSubscription{} if err = rows.StructScan(&sub); err != nil { return } nullifyEmptyConfig(&sub) - subs[counter] = sub + subs[counter] = *sub.toDatastoreSubscription() counter++ } @@ -490,7 +491,7 @@ func (s *subscriptionRepo) CreateSubscription(ctx context.Context, projectID str return ErrSubscriptionNotCreated } - _subscription := &datastore.Subscription{} + _subscription := &dbSubscription{} err = s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), subscription.UID, projectID).StructScan(_subscription) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -500,7 +501,7 @@ func (s *subscriptionRepo) CreateSubscription(ctx context.Context, projectID str } nullifyEmptyConfig(_subscription) - *subscription = *_subscription + *subscription = *_subscription.toDatastoreSubscription() return nil } @@ -529,11 +530,10 @@ func (s *subscriptionRepo) UpdateSubscription(ctx context.Context, projectID str fc.Filter.IsFlattened = true // this is just a flag so we can identify old records result, err := s.db.ExecContext( - ctx, updateSubscription, subscription.UID, projectID, - subscription.Name, subscription.EndpointID, sourceID, - ac.Count, ac.Threshold, rc.Type, rc.Duration, rc.RetryCount, - fc.EventTypes, fc.Filter.Headers, fc.Filter.Body, fc.Filter.IsFlattened, - rlc.Count, rlc.Duration, subscription.Function, + ctx, updateSubscription, subscription.Name, subscription.EndpointID, sourceID, + ac.Count, ac.Threshold, rc.Type, rc.Duration, rc.RetryCount, fc.EventTypes, + fc.Filter.Headers, fc.Filter.Body, fc.Filter.IsFlattened, rlc.Count, + rlc.Duration, subscription.Function, time.Now(), subscription.UID, projectID, ) if err != nil { return err @@ -548,7 +548,7 @@ func (s *subscriptionRepo) UpdateSubscription(ctx context.Context, projectID str return ErrSubscriptionNotUpdated } - _subscription := &datastore.Subscription{} + _subscription := &dbSubscription{} err = s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), subscription.UID, projectID).StructScan(_subscription) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -558,7 +558,7 @@ func (s *subscriptionRepo) UpdateSubscription(ctx context.Context, projectID str } nullifyEmptyConfig(_subscription) - *subscription = *_subscription + *subscription = *_subscription.toDatastoreSubscription() return nil } @@ -668,7 +668,7 @@ func (s *subscriptionRepo) LoadSubscriptionsPaged(ctx context.Context, projectID } func (s *subscriptionRepo) DeleteSubscription(ctx context.Context, projectID string, subscription *datastore.Subscription) error { - result, err := s.db.ExecContext(ctx, deleteSubscriptions, subscription.UID, projectID) + result, err := s.db.ExecContext(ctx, deleteSubscriptions, time.Now(), subscription.UID, projectID) if err != nil { return err } @@ -686,7 +686,7 @@ func (s *subscriptionRepo) DeleteSubscription(ctx context.Context, projectID str } func (s *subscriptionRepo) FindSubscriptionByID(ctx context.Context, projectID string, subscriptionID string) (*datastore.Subscription, error) { - subscription := &datastore.Subscription{} + subscription := &dbSubscription{} err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSubscriptionByID, "s.id", "s.project_id"), 
subscriptionID, projectID).StructScan(subscription) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -697,7 +697,7 @@ func (s *subscriptionRepo) FindSubscriptionByID(ctx context.Context, projectID s nullifyEmptyConfig(subscription) - return subscription, nil + return subscription.toDatastoreSubscription(), nil } func (s *subscriptionRepo) FindSubscriptionsBySourceID(ctx context.Context, projectID string, sourceID string) ([]datastore.Subscription, error) { @@ -727,7 +727,7 @@ func (s *subscriptionRepo) FindSubscriptionsByEndpointID(ctx context.Context, pr } func (s *subscriptionRepo) FindSubscriptionByDeviceID(ctx context.Context, projectId string, deviceID string, subscriptionType datastore.SubscriptionType) (*datastore.Subscription, error) { - subscription := &datastore.Subscription{} + subscription := &dbSubscription{} err := s.db.QueryRowxContext(ctx, fetchSubscriptionByDeviceID, deviceID, projectId, subscriptionType).StructScan(subscription) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -739,7 +739,7 @@ func (s *subscriptionRepo) FindSubscriptionByDeviceID(ctx context.Context, proje nullifyEmptyConfig(subscription) - return subscription, nil + return subscription.toDatastoreSubscription(), nil } func (s *subscriptionRepo) FindCLISubscriptions(ctx context.Context, projectID string) ([]datastore.Subscription, error) { @@ -810,7 +810,7 @@ var ( emptyRateLimitConfig = datastore.RateLimitConfiguration{} ) -func nullifyEmptyConfig(sub *datastore.Subscription) { +func nullifyEmptyConfig(sub *dbSubscription) { if sub.AlertConfig != nil && *sub.AlertConfig == emptyAlertConfig { sub.AlertConfig = nil } @@ -830,15 +830,59 @@ func scanSubscriptions(rows *sqlx.Rows) ([]datastore.Subscription, error) { defer closeWithError(rows) for rows.Next() { - sub := datastore.Subscription{} + sub := dbSubscription{} err = rows.StructScan(&sub) if err != nil { return nil, err } nullifyEmptyConfig(&sub) - subscriptions = append(subscriptions, sub) + subscriptions = append(subscriptions, *sub.toDatastoreSubscription()) } return subscriptions, nil } + +type dbSubscription struct { + UID string `db:"id"` + Name string `db:"name"` + Type datastore.SubscriptionType `db:"type"` + ProjectID string `db:"project_id"` + SourceID string `db:"source_id"` + EndpointID string `db:"endpoint_id"` + DeviceID string `db:"device_id"` + Function null.String `db:"function"` + Source *datastore.Source `db:"source_metadata"` + Endpoint *datastore.Endpoint `db:"endpoint_metadata"` + Device *datastore.Device `db:"device_metadata"` + AlertConfig *datastore.AlertConfiguration `db:"alert_config"` + RetryConfig *datastore.RetryConfiguration `db:"retry_config"` + FilterConfig *datastore.FilterConfiguration `db:"filter_config"` + RateLimitConfig *datastore.RateLimitConfiguration `db:"rate_limit_config"` + CreatedAt string `db:"created_at"` + UpdatedAt string `db:"updated_at"` + DeletedAt *string `db:"deleted_at"` +} + +func (ss *dbSubscription) toDatastoreSubscription() *datastore.Subscription { + return &datastore.Subscription{ + UID: ss.UID, + Name: ss.Name, + Type: ss.Type, + ProjectID: ss.ProjectID, + SourceID: ss.SourceID, + EndpointID: ss.EndpointID, + DeviceID: ss.DeviceID, + Function: ss.Function, + Source: ss.Source, + Endpoint: ss.Endpoint, + Device: ss.Device, + AlertConfig: ss.AlertConfig, + RetryConfig: ss.RetryConfig, + FilterConfig: ss.FilterConfig, + RateLimitConfig: ss.RateLimitConfig, + CreatedAt: asTime(ss.CreatedAt), + UpdatedAt: asTime(ss.UpdatedAt), + DeletedAt: asNullTime(ss.DeletedAt), + } +} diff --git 
a/database/sqlite3/subscription_test.go b/database/sqlite3/subscription_test.go index 356ab77b6c..8541b2b8aa 100644 --- a/database/sqlite3/subscription_test.go +++ b/database/sqlite3/subscription_test.go @@ -51,310 +51,260 @@ func seedSubscription(t *testing.T, db database.Database, project *datastore.Pro return s } -func Test_LoadSubscriptionsPaged(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() - - subRepo := NewSubscriptionRepo(db) - - source := seedSource(t, db) - project := seedProject(t, db) - endpoint := seedEndpoint(t, db) - device := seedDevice(t, db) - subMap := map[string]*datastore.Subscription{} - var newSub *datastore.Subscription - for i := 0; i < 100; i++ { - newSub = generateSubscription(project, source, endpoint, device) - require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) - subMap[newSub.UID] = newSub - } +func TestSubscription(t *testing.T) { + t.Run("LoadSubscriptionsPaged", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + device := seedDevice(t, db) + subMap := map[string]*datastore.Subscription{} + var newSub *datastore.Subscription + for i := 0; i < 100; i++ { + newSub = generateSubscription(project, source, endpoint, device) + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub + } - type Expected struct { - paginationData datastore.PaginationData - } + type Expected struct { + paginationData datastore.PaginationData + } - tests := []struct { - name string - EndpointIDs []string - SubscriptionName string - pageData datastore.Pageable - expected Expected - }{ - { - name: "Load Subscriptions Paged - 10 records", - pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 3, + tests := []struct { + name string + EndpointIDs []string + SubscriptionName string + pageData datastore.Pageable + expected Expected + }{ + { + name: "Load Subscriptions Paged - 10 records", + pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, }, }, - }, - { - name: "Load Subscriptions Paged - 1 record - filter by name", - SubscriptionName: newSub.Name, - pageData: datastore.Pageable{PerPage: 1, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 1, + { + name: "Load Subscriptions Paged - 1 record - filter by name", + SubscriptionName: newSub.Name, + pageData: datastore.Pageable{PerPage: 1, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 1, + }, }, }, - }, - { - name: "Load Subscriptions Paged - 12 records", - pageData: datastore.Pageable{PerPage: 4, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 4, + { + name: "Load Subscriptions Paged - 12 records", + pageData: datastore.Pageable{PerPage: 4, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 4, + }, }, }, - }, 
- { - name: "Load Subscriptions Paged - 0 records", - pageData: datastore.Pageable{PerPage: 10, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 10, + { + name: "Load Subscriptions Paged - 0 records", + pageData: datastore.Pageable{PerPage: 10, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 10, + }, }, }, - }, - { - name: "Load Subscriptions Paged with Endpoint ID - 1 record", - EndpointIDs: []string{endpoint.UID}, - pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 3, + { + name: "Load Subscriptions Paged with Endpoint ID - 1 record", + EndpointIDs: []string{endpoint.UID}, + pageData: datastore.Pageable{PerPage: 3, Direction: datastore.Next, NextCursor: fmt.Sprintf("%d", math.MaxInt)}, + expected: Expected{ + paginationData: datastore.PaginationData{ + PerPage: 3, + }, }, }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - subs, pageable, err := subRepo.LoadSubscriptionsPaged(context.Background(), project.UID, &datastore.FilterBy{EndpointIDs: tc.EndpointIDs, SubscriptionName: tc.SubscriptionName}, tc.pageData) - require.NoError(t, err) - - require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) - - require.Equal(t, tc.expected.paginationData.PerPage, int64(len(subs))) - - for _, dbSub := range subs { - - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) - - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - - require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) - require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) - require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) - require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) - - require.Equal(t, dbSub.Source.UID, source.UID) - require.Equal(t, dbSub.Source.Name, source.Name) - require.Equal(t, dbSub.Source.Type, source.Type) - require.Equal(t, dbSub.Source.MaskID, source.MaskID) - require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) - require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) - - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + } - require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) - } - }) - } -} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + subs, pageable, err := subRepo.LoadSubscriptionsPaged(context.Background(), project.UID, &datastore.FilterBy{EndpointIDs: tc.EndpointIDs, SubscriptionName: tc.SubscriptionName}, tc.pageData) + require.NoError(t, err) -func Test_DeleteSubscription(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) - subRepo := NewSubscriptionRepo(db) + require.Equal(t, tc.expected.paginationData.PerPage, int64(len(subs))) - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) + for _, dbSub := range subs { - newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) - err := subRepo.CreateSubscription(context.Background(), project.UID, newSub) - require.NoError(t, err) + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - // delete the sub - err = 
subRepo.DeleteSubscription(context.Background(), project.UID, newSub) - require.NoError(t, err) + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) - // Fetch sub again - _, err = subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) - require.Equal(t, err, datastore.ErrSubscriptionNotFound) -} + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) -func Test_CreateSubscription(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil - subRepo := NewSubscriptionRepo(db) + require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) + } + }) + } + }) - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) + t.Run("DeleteSubscription", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) - require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) + subRepo := NewSubscriptionRepo(db) - dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, newSub.UID) - require.NoError(t, err) + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + err := subRepo.CreateSubscription(context.Background(), project.UID, newSub) + require.NoError(t, err) - require.Equal(t, dbSub.UID, newSub.UID) -} + // delete the sub + err = subRepo.DeleteSubscription(context.Background(), project.UID, newSub) + require.NoError(t, err) -func Test_CountEndpointSubscriptions(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + // Fetch sub again + _, err = subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.Equal(t, err, datastore.ErrSubscriptionNotFound) + }) - subRepo := NewSubscriptionRepo(db) + t.Run("CreateSubscription", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) + subRepo := NewSubscriptionRepo(db) - newSub1 := generateSubscription(project, source, endpoint, &datastore.Device{}) - require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub1.ProjectID, newSub1)) + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) - newSub2 := generateSubscription(project, source, endpoint, &datastore.Device{}) - require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub2.ProjectID, newSub2)) + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) - count, err := subRepo.CountEndpointSubscriptions(context.Background(), newSub1.ProjectID, 
endpoint.UID) - require.NoError(t, err) + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, newSub.UID) + require.NoError(t, err) - require.Equal(t, int64(2), count) -} + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) -func Test_UpdateSubscription(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil - subRepo := NewSubscriptionRepo(db) + require.Equal(t, dbSub.UID, newSub.UID) + }) - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) + t.Run("CountEndpointSubscriptions", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) - require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) - - update := &datastore.Subscription{ - UID: newSub.UID, - Name: "tyne&wear", - ProjectID: newSub.ProjectID, - Type: newSub.Type, - SourceID: seedSource(t, db).UID, - EndpointID: seedEndpoint(t, db).UID, - AlertConfig: &datastore.DefaultAlertConfig, - RetryConfig: &datastore.DefaultRetryConfig, - FilterConfig: newSub.FilterConfig, - RateLimitConfig: &datastore.DefaultRateLimitConfig, - } + subRepo := NewSubscriptionRepo(db) - err := subRepo.UpdateSubscription(context.Background(), project.UID, update) - require.NoError(t, err) + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) - dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, newSub.UID) - require.NoError(t, err) + newSub1 := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub1.ProjectID, newSub1)) - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) + newSub2 := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub2.ProjectID, newSub2)) - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + count, err := subRepo.CountEndpointSubscriptions(context.Background(), newSub1.ProjectID, endpoint.UID) + require.NoError(t, err) - require.Equal(t, dbSub.UID, update.UID) - require.Equal(t, dbSub.Name, update.Name) -} + require.Equal(t, int64(2), count) + }) -func Test_FindSubscriptionByID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("UpdateSubscription", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - subRepo := NewSubscriptionRepo(db) + subRepo := NewSubscriptionRepo(db) - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) - newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), newSub.ProjectID, newSub)) - // Fetch sub again - _, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) - require.Error(t, err) - require.EqualError(t, err, datastore.ErrSubscriptionNotFound.Error()) + update := &datastore.Subscription{ + UID: newSub.UID, + Name: "tyne&wear", + ProjectID: 
newSub.ProjectID, + Type: newSub.Type, + SourceID: seedSource(t, db).UID, + EndpointID: seedEndpoint(t, db).UID, + AlertConfig: &datastore.DefaultAlertConfig, + RetryConfig: &datastore.DefaultRetryConfig, + FilterConfig: newSub.FilterConfig, + RateLimitConfig: &datastore.DefaultRateLimitConfig, + } - require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + err := subRepo.UpdateSubscription(context.Background(), project.UID, update) + require.NoError(t, err) - // Fetch sub again - dbSub, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) - require.NoError(t, err) + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), newSub.ProjectID, newSub.UID) + require.NoError(t, err) - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) - require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) - require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) - require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil - require.Equal(t, dbSub.Source.UID, source.UID) - require.Equal(t, dbSub.Source.Name, source.Name) - require.Equal(t, dbSub.Source.Type, source.Type) - require.Equal(t, dbSub.Source.MaskID, source.MaskID) - require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) - require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + require.Equal(t, dbSub.UID, update.UID) + require.Equal(t, dbSub.Name, update.Name) + }) - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + t.Run("FindSubscriptionByID", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - require.Equal(t, dbSub.UID, newSub.UID) -} + subRepo := NewSubscriptionRepo(db) -func Test_FindSubscriptionsBySourceID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) - subRepo := NewSubscriptionRepo(db) + newSub := generateSubscription(project, source, endpoint, &datastore.Device{}) - source := seedSource(t, db) - project := seedProject(t, db) - endpoint := seedEndpoint(t, db) - - subMap := map[string]*datastore.Subscription{} - for i := 0; i < 5; i++ { - var newSub *datastore.Subscription - if i == 3 { - newSub = generateSubscription(project, seedSource(t, db), endpoint, &datastore.Device{}) - } else { - newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) - } + // Fetch sub again + _, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.Error(t, err) + require.EqualError(t, err, datastore.ErrSubscriptionNotFound.Error()) require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) - subMap[newSub.UID] = newSub - } - // Fetch sub again - dbSubs, err := subRepo.FindSubscriptionsBySourceID(context.Background(), project.UID, source.UID) - require.NoError(t, err) - require.Equal(t, 4, len(dbSubs)) - - for _, dbSub := range dbSubs { + // Fetch sub again + dbSub, err := subRepo.FindSubscriptionByID(context.Background(), project.UID, newSub.UID) + require.NoError(t, err) require.NotEmpty(t, dbSub.CreatedAt) require.NotEmpty(t, dbSub.UpdatedAt) @@ -374,138 +324,190 @@ func Test_FindSubscriptionsBySourceID(t 
*testing.T) { dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil - require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) - } -} + require.Equal(t, dbSub.UID, newSub.UID) + }) -func Test_FindSubscriptionByEndpointID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + t.Run("FindSubscriptionsBySourceID", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - subRepo := NewSubscriptionRepo(db) + subRepo := NewSubscriptionRepo(db) - source := seedSource(t, db) - project := seedProject(t, db) - endpoint := seedEndpoint(t, db) + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) - subMap := map[string]*datastore.Subscription{} - for i := 0; i < 8; i++ { - var newSub *datastore.Subscription - if i == 3 { - newSub = generateSubscription(project, source, seedEndpoint(t, db), &datastore.Device{}) - } else { - newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) + subMap := map[string]*datastore.Subscription{} + for i := 0; i < 5; i++ { + var newSub *datastore.Subscription + if i == 3 { + newSub = generateSubscription(project, seedSource(t, db), endpoint, &datastore.Device{}) + } else { + newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) + } + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub } - require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) - subMap[newSub.UID] = newSub - } + // Fetch sub again + dbSubs, err := subRepo.FindSubscriptionsBySourceID(context.Background(), project.UID, source.UID) + require.NoError(t, err) + require.Equal(t, 4, len(dbSubs)) - // Fetch sub again - dbSubs, err := subRepo.FindSubscriptionsByEndpointID(context.Background(), project.UID, endpoint.UID) - require.NoError(t, err) - require.Equal(t, 7, len(dbSubs)) + for _, dbSub := range dbSubs { - for _, dbSub := range dbSubs { + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) - require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) - require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) - require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) - require.Equal(t, dbSub.Source.UID, source.UID) - require.Equal(t, dbSub.Source.Name, source.Name) - require.Equal(t, dbSub.Source.Type, source.Type) - require.Equal(t, dbSub.Source.MaskID, source.MaskID) - require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) - require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + require.Equal(t, dbSub.UID, 
subMap[dbSub.UID].UID) + } + }) - require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) - } -} + t.Run("FindSubscriptionByEndpointID", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() -func Test_FindSubscriptionByDeviceID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + subRepo := NewSubscriptionRepo(db) - subRepo := NewSubscriptionRepo(db) + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) - project := seedProject(t, db) - source := seedSource(t, db) - endpoint := seedEndpoint(t, db) - device := seedDevice(t, db) - newSub := generateSubscription(project, source, endpoint, device) + subMap := map[string]*datastore.Subscription{} + for i := 0; i < 8; i++ { + var newSub *datastore.Subscription + if i == 3 { + newSub = generateSubscription(project, source, seedEndpoint(t, db), &datastore.Device{}) + } else { + newSub = generateSubscription(project, source, endpoint, &datastore.Device{}) + } - require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + subMap[newSub.UID] = newSub + } - // Fetch sub again - dbSub, err := subRepo.FindSubscriptionByDeviceID(context.Background(), project.UID, device.UID, newSub.Type) - require.NoError(t, err) + // Fetch sub again + dbSubs, err := subRepo.FindSubscriptionsByEndpointID(context.Background(), project.UID, endpoint.UID) + require.NoError(t, err) + require.Equal(t, 7, len(dbSubs)) - require.NotEmpty(t, dbSub.CreatedAt) - require.NotEmpty(t, dbSub.UpdatedAt) + for _, dbSub := range dbSubs { - dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} - require.Nil(t, dbSub.Endpoint) - require.Nil(t, dbSub.Source) + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) - require.Equal(t, device.UID, dbSub.Device.UID) - require.Equal(t, device.HostName, dbSub.Device.HostName) + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Equal(t, dbSub.Endpoint.UID, endpoint.UID) + require.Equal(t, dbSub.Endpoint.Name, endpoint.Name) + require.Equal(t, dbSub.Endpoint.ProjectID, endpoint.ProjectID) + require.Equal(t, dbSub.Endpoint.SupportEmail, endpoint.SupportEmail) - dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + require.Equal(t, dbSub.Source.UID, source.UID) + require.Equal(t, dbSub.Source.Name, source.Name) + require.Equal(t, dbSub.Source.Type, source.Type) + require.Equal(t, dbSub.Source.MaskID, source.MaskID) + require.Equal(t, dbSub.Source.ProjectID, source.ProjectID) + require.Equal(t, dbSub.Source.IsDisabled, source.IsDisabled) - require.Equal(t, dbSub.UID, newSub.UID) -} + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil -func Test_FindCLISubscriptions(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + require.Equal(t, dbSub.UID, subMap[dbSub.UID].UID) + } + }) - subRepo := NewSubscriptionRepo(db) + t.Run("FindSubscriptionByDeviceID", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() - source := seedSource(t, db) - project := seedProject(t, db) - endpoint := seedEndpoint(t, db) + subRepo := NewSubscriptionRepo(db) - for i := 0; i < 8; i++ { - newSub := &datastore.Subscription{ - UID: ulid.Make().String(), - Name: "Subscription", - Type: datastore.SubscriptionTypeCLI, - ProjectID: project.UID, - SourceID: source.UID, - EndpointID: endpoint.UID, - AlertConfig: &datastore.AlertConfiguration{ - Count: 10, - Threshold: "1m", - }, - RetryConfig: &datastore.RetryConfiguration{ - 
Type: "linear", - Duration: 3, - RetryCount: 10, - }, - FilterConfig: &datastore.FilterConfiguration{ - EventTypes: []string{"some.event"}, - Filter: datastore.FilterSchema{ - Headers: datastore.M{}, - Body: datastore.M{}, - }, - }, - } + project := seedProject(t, db) + source := seedSource(t, db) + endpoint := seedEndpoint(t, db) + device := seedDevice(t, db) + newSub := generateSubscription(project, source, endpoint, device) require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) - } - // Fetch sub again - dbSubs, err := subRepo.FindCLISubscriptions(context.Background(), project.UID) - require.NoError(t, err) - require.Equal(t, 8, len(dbSubs)) + // Fetch sub again + dbSub, err := subRepo.FindSubscriptionByDeviceID(context.Background(), project.UID, device.UID, newSub.Type) + require.NoError(t, err) + + require.NotEmpty(t, dbSub.CreatedAt) + require.NotEmpty(t, dbSub.UpdatedAt) + + dbSub.CreatedAt, dbSub.UpdatedAt = time.Time{}, time.Time{} + require.Nil(t, dbSub.Endpoint) + require.Nil(t, dbSub.Source) + + require.Equal(t, device.UID, dbSub.Device.UID) + require.Equal(t, device.HostName, dbSub.Device.HostName) + + dbSub.Source, dbSub.Endpoint, dbSub.Device = nil, nil, nil + + require.Equal(t, dbSub.UID, newSub.UID) + }) + + t.Run("FindCLISubscriptions", func(t *testing.T) { + db, closeFn := getDB(t) + defer closeFn() + + subRepo := NewSubscriptionRepo(db) + + source := seedSource(t, db) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) + + for i := 0; i < 8; i++ { + newSub := &datastore.Subscription{ + UID: ulid.Make().String(), + Name: "Subscription", + Type: datastore.SubscriptionTypeCLI, + ProjectID: project.UID, + SourceID: source.UID, + EndpointID: endpoint.UID, + AlertConfig: &datastore.AlertConfiguration{ + Count: 10, + Threshold: "1m", + }, + RetryConfig: &datastore.RetryConfiguration{ + Type: "linear", + Duration: 3, + RetryCount: 10, + }, + FilterConfig: &datastore.FilterConfiguration{ + EventTypes: []string{"some.event"}, + Filter: datastore.FilterSchema{ + Headers: datastore.M{}, + Body: datastore.M{}, + }, + }, + } + + require.NoError(t, subRepo.CreateSubscription(context.Background(), project.UID, newSub)) + } + + // Fetch sub again + dbSubs, err := subRepo.FindCLISubscriptions(context.Background(), project.UID) + require.NoError(t, err) + require.Equal(t, 8, len(dbSubs)) + }) } func seedDevice(t *testing.T, db database.Database) *datastore.Device { @@ -513,8 +515,9 @@ func seedDevice(t *testing.T, db database.Database) *datastore.Device { endpoint := seedEndpoint(t, db) d := &datastore.Device{ - UID: ulid.Make().String(), - ProjectID: project.UID, + UID: ulid.Make().String(), + ProjectID: project.UID, + EndpointID: endpoint.UID, HostName: "host1", Status: datastore.DeviceStatusOnline, diff --git a/sql/sqlite3/1731413863.sql b/sql/sqlite3/1731413863.sql index 19ccdd9005..f3280b21a3 100644 --- a/sql/sqlite3/1731413863.sql +++ b/sql/sqlite3/1731413863.sql @@ -26,7 +26,7 @@ create table if not exists configurations cb_success_threshold INTEGER default 1 not null, cb_observability_window INTEGER default 30 not null, cb_consecutive_failure_threshold INTEGER default 10 not null -); +) strict; -- event endpoints create table if not exists events_endpoints @@ -35,7 +35,7 @@ create table if not exists events_endpoints endpoint_id TEXT not null, FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), FOREIGN KEY(event_id) REFERENCES events(id) -); +) strict; create index if not exists events_endpoints_temp_endpoint_id_idx on 
events_endpoints (endpoint_id); @@ -54,7 +54,7 @@ create table if not exists gorp_migrations ( id TEXT not null primary key, applied_at TEXT -); +) strict; -- project configurations create table if not exists project_configurations @@ -82,7 +82,7 @@ create table if not exists project_configurations created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT -); +) strict; -- source verifiers create table if not exists source_verifiers @@ -101,7 +101,7 @@ create table if not exists source_verifiers created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT -); +) strict; -- token bucket create table if not exists token_bucket @@ -112,7 +112,7 @@ create table if not exists token_bucket created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), expires_at TEXT not null -); +) strict; -- users create table if not exists users @@ -131,7 +131,7 @@ create table if not exists users created_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT -); +) strict; CREATE UNIQUE INDEX if not exists idx_unique_email_deleted_at ON users(email) @@ -149,7 +149,7 @@ create table if not exists organisations updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT, FOREIGN KEY(owner_id) REFERENCES users(id) -); +) strict; create unique index if not exists idx_organisations_custom_domain_deleted_at on organisations (custom_domain, assigned_domain) @@ -171,7 +171,7 @@ create table if not exists projects constraint name_org_id_key unique (name, organisation_id, deleted_at), FOREIGN KEY(organisation_id) REFERENCES organisations(id), FOREIGN KEY(project_configuration_id) REFERENCES project_configurations(id) -); +) strict; create unique index if not exists idx_name_organisation_id_deleted_at on projects (organisation_id, name) @@ -189,7 +189,7 @@ create table if not exists applications updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; -- devices todo(raymond): deprecate me create table if not exists devices @@ -203,7 +203,7 @@ create table if not exists devices last_seen_at TEXT not null, deleted_at TEXT, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; -- endpoints create table if not exists endpoints @@ -230,7 +230,7 @@ create table if not exists endpoints name TEXT not null, url TEXT not null, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at ON endpoints(name, project_id) @@ -268,7 +268,7 @@ create table if not exists api_keys FOREIGN KEY(role_project) REFERENCES projects(id), FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), FOREIGN KEY(user_id) REFERENCES users(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_mask_id_deleted_at ON api_keys(mask_id) @@ -289,7 +289,7 @@ create table if not exists event_types updated_at TEXT default (strftime('%Y-%m-%dT%H:%M:%fZ')) not null, deprecated_at TEXT, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at ON event_types(name, project_id) @@ -332,7 +332,7 @@ create table if not exists jobs updated_at TEXT not null default 
(strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; -- meta events create table if not exists meta_events @@ -347,7 +347,7 @@ create table if not exists meta_events updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), deleted_at TEXT, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; -- organisation invites create table if not exists organisation_invites @@ -367,7 +367,7 @@ create table if not exists organisation_invites FOREIGN KEY(role_project) REFERENCES projects(id), FOREIGN KEY(organisation_id) REFERENCES organisations(id), FOREIGN KEY(role_endpoint) REFERENCES endpoints(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_token_organisation_id_deleted_at ON organisation_invites(token, organisation_id) @@ -395,7 +395,7 @@ create table if not exists organisation_members FOREIGN KEY(role_endpoint) REFERENCES endpoints(id), FOREIGN KEY(organisation_id) REFERENCES organisations(id), FOREIGN KEY(user_id) REFERENCES users(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_organisation_id_user_id_deleted_at ON organisation_members(organisation_id, user_id) @@ -423,7 +423,7 @@ create table if not exists portal_links owner_id TEXT, can_manage_endpoint BOOLEAN default false, FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_token_deleted_at ON portal_links(token) @@ -444,7 +444,7 @@ create table if not exists portal_links_endpoints endpoint_id TEXT not null, FOREIGN KEY(portal_link_id) REFERENCES portal_links(id), FOREIGN KEY(endpoint_id) REFERENCES endpoints(id) -); +) strict; create index if not exists idx_portal_links_endpoints_enpdoint_id on portal_links_endpoints (endpoint_id); @@ -476,7 +476,7 @@ create table if not exists sources updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), FOREIGN KEY(project_id) REFERENCES projects(id), FOREIGN KEY(source_verifier_id) REFERENCES source_verifiers(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_mask_id_project_id_deleted_at ON sources(mask_id, project_id) @@ -513,7 +513,7 @@ create table if not exists events updated_at TEXT not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), FOREIGN KEY(source_id) REFERENCES sources(id), FOREIGN KEY(project_id) REFERENCES projects(name) -); +) strict; create index if not exists idx_events_created_at_key on events (created_at); @@ -572,7 +572,7 @@ create table if not exists subscriptions FOREIGN KEY(device_id) REFERENCES devices(id), FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; CREATE UNIQUE INDEX if not exists idx_name_project_id_deleted_at ON subscriptions(name, project_id) @@ -625,7 +625,7 @@ create table if not exists event_deliveries FOREIGN KEY(event_id) REFERENCES events(id), FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; -- delivery attempts create table if not exists delivery_attempts @@ -650,7 +650,7 @@ create table if not exists delivery_attempts FOREIGN KEY(event_delivery_id) REFERENCES event_deliveries(id), FOREIGN KEY(endpoint_id) REFERENCES endpoints(id), FOREIGN KEY(project_id) REFERENCES projects(id) -); +) strict; create index if not exists idx_delivery_attempts_created_at on delivery_attempts (created_at); From ad3563fe2fc06f79bbd54cb9c8305c15145d29de Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Mon, 18 Nov 2024 12:04:13 +0100 Subject: [PATCH 7/7] chore: in-progress add sqlite source 
impl --- database/sqlite3/endpoint.go | 46 +++++ database/sqlite3/source.go | 154 +++++++++----- database/sqlite3/source_test.go | 342 ++++++++++++++----------------- database/sqlite3/sqlite3.go | 9 + database/sqlite3/subscription.go | 19 +- 5 files changed, 327 insertions(+), 243 deletions(-) diff --git a/database/sqlite3/endpoint.go b/database/sqlite3/endpoint.go index c0a230cd41..1e20de276c 100644 --- a/database/sqlite3/endpoint.go +++ b/database/sqlite3/endpoint.go @@ -495,3 +495,49 @@ type EndpointSecret struct { Endpoint datastore.Endpoint `json:"endpoint"` Secret datastore.Secret `db:"secret"` } + +type dbEndpoint struct { + UID string `db:"id"` + Name string `db:"name"` + Status datastore.EndpointStatus `db:"status"` + OwnerID string `db:"owner_id"` + Url string `db:"url"` + Description string `db:"description"` + HttpTimeout uint64 `db:"http_timeout"` + RateLimit int `db:"rate_limit"` + RateLimitDuration uint64 `db:"rate_limit_duration"` + AdvancedSignatures bool `db:"advanced_signatures"` + SlackWebhookURL string `db:"slack_webhook_url"` + SupportEmail string `db:"support_email"` + AppID string `db:"app_id"` + ProjectID string `db:"project_id"` + Secrets datastore.Secrets `db:"secrets"` + Authentication *datastore.EndpointAuthentication `db:"authentication"` + CreatedAt string `db:"created_at"` + UpdatedAt string `db:"updated_at"` + DeletedAt *string `db:"deleted_at"` +} + +func (e *dbEndpoint) toDatastoreEndpoint() *datastore.Endpoint { + return &datastore.Endpoint{ + UID: e.UID, + Name: e.Name, + Status: e.Status, + OwnerID: e.OwnerID, + Url: e.Url, + Description: e.Description, + HttpTimeout: e.HttpTimeout, + RateLimit: e.RateLimit, + RateLimitDuration: e.RateLimitDuration, + AdvancedSignatures: e.AdvancedSignatures, + SlackWebhookURL: e.SlackWebhookURL, + SupportEmail: e.SupportEmail, + AppID: e.AppID, + ProjectID: e.ProjectID, + Secrets: e.Secrets, + Authentication: e.Authentication, + CreatedAt: asTime(e.CreatedAt), + UpdatedAt: asTime(e.UpdatedAt), + DeletedAt: asNullTime(e.DeletedAt), + } +} diff --git a/database/sqlite3/source.go b/database/sqlite3/source.go index 2635270f81..84ea8c6538 100644 --- a/database/sqlite3/source.go +++ b/database/sqlite3/source.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/lib/pq" + "time" "github.com/oklog/ulid/v2" @@ -33,6 +34,7 @@ const ( updateSourceById = ` UPDATE sources SET + updated_at = $1, name= $2, type=$3, mask_id=$4, @@ -45,13 +47,13 @@ const ( custom_response_content_type = $11, idempotency_keys = $12, body_function = $13, - header_function = $14, - updated_at = NOW() - WHERE id = $1 AND deleted_at IS NULL ; + header_function = $14 + WHERE id = $15 AND deleted_at IS NULL ; ` updateSourceVerifierById = ` UPDATE source_verifiers SET + updated_at = $1, type=$2, basic_username=$3, basic_password=$4, @@ -60,9 +62,8 @@ const ( hmac_hash=$7, hmac_header=$8, hmac_secret=$9, - hmac_encoding=$10, - updated_at = NOW() - WHERE id = $1 AND deleted_at IS NULL; + hmac_encoding=$10 + WHERE id = $11 AND deleted_at IS NULL; ` baseFetchSource = ` @@ -123,26 +124,26 @@ const ( deleteSource = ` UPDATE sources SET - deleted_at = NOW() - WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL; + deleted_at = $1 + WHERE id = $2 AND project_id = $3 AND deleted_at IS NULL; ` deleteSourceVerifier = ` UPDATE source_verifiers SET - deleted_at = NOW() - WHERE id = $1 AND deleted_at IS NULL; + deleted_at = $1 + WHERE id = $2 AND deleted_at IS NULL; ` deleteSourceSubscription = ` UPDATE subscriptions SET - deleted_at = NOW() - WHERE source_id = $1 AND 
project_id = $2 AND deleted_at IS NULL; + deleted_at = $1 + WHERE source_id = $2 AND project_id = $3 AND deleted_at IS NULL; ` fetchSourcesPagedFilter = ` AND (s.type = :type OR :type = '') AND (s.provider = :provider OR :provider = '') - AND s.name ILIKE :query + AND s.name LIKE :query AND s.project_id = :project_id ` @@ -197,7 +198,6 @@ func NewSourceRepo(db database.Database) datastore.SourceRepository { } func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source) error { - var sourceVerifierID *string tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{}) if err != nil { return err @@ -221,10 +221,9 @@ func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source) if !util.IsStringEmpty(string(source.Verifier.Type)) { id := ulid.Make().String() - sourceVerifierID = &id result2, err := tx.ExecContext( - ctx, createSourceVerifier, sourceVerifierID, source.Verifier.Type, basic.UserName, basic.Password, + ctx, createSourceVerifier, id, source.Verifier.Type, basic.UserName, basic.Password, apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, ) if err != nil { @@ -239,14 +238,12 @@ func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source) if rowsAffected < 1 { return ErrSourceVerifierNotCreated } - } - if !util.IsStringEmpty(string(source.Verifier.Type)) { - source.VerifierID = *sourceVerifierID + source.VerifierID = id } result1, err := tx.ExecContext( - ctx, createSource, source.UID, sourceVerifierID, source.Name, source.Type, source.MaskID, + ctx, createSource, source.UID, source.VerifierID, source.Name, source.Type, source.MaskID, source.Provider, source.IsDisabled, pq.Array(source.ForwardHeaders), source.ProjectID, source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType, source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, @@ -280,10 +277,10 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source defer rollbackTx(tx) result, err := tx.ExecContext( - ctx, updateSourceById, source.UID, source.Name, source.Type, source.MaskID, + ctx, updateSourceById, time.Now(), source.Name, source.Type, source.MaskID, source.Provider, source.IsDisabled, source.ForwardHeaders, projectID, source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType, - source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, + source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, source.UID, ) if err != nil { return err @@ -314,8 +311,9 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source if !util.IsStringEmpty(string(source.Verifier.Type)) { result2, err := tx.ExecContext( - ctx, updateSourceVerifierById, source.VerifierID, source.Verifier.Type, basic.UserName, basic.Password, - apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, + ctx, updateSourceVerifierById, time.Now(), source.Verifier.Type, + basic.UserName, basic.Password, apiKey.HeaderName, apiKey.HeaderValue, + hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, source.VerifierID, ) if err != nil { return err @@ -340,7 +338,7 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source } func (s *sourceRepo) FindSourceByID(ctx context.Context, projectId string, id string) (*datastore.Source, error) { - source := &datastore.Source{} + source := &dbSource{} err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.id"), id).StructScan(source) if err != nil { if errors.Is(err, 
sql.ErrNoRows) { @@ -349,11 +347,11 @@ func (s *sourceRepo) FindSourceByID(ctx context.Context, projectId string, id st return nil, err } - return source, nil + return source.toDatastoreSource(), nil } func (s *sourceRepo) FindSourceByName(ctx context.Context, projectID string, name string) (*datastore.Source, error) { - source := &datastore.Source{} + source := &dbSource{} err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSourceByName, "s.project_id", "s.name"), projectID, name).StructScan(source) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -362,11 +360,11 @@ func (s *sourceRepo) FindSourceByName(ctx context.Context, projectID string, nam return nil, err } - return source, nil + return source.toDatastoreSource(), nil } func (s *sourceRepo) FindSourceByMaskID(ctx context.Context, maskID string) (*datastore.Source, error) { - source := &datastore.Source{} + source := &dbSource{} err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.mask_id"), maskID).StructScan(source) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -375,7 +373,7 @@ func (s *sourceRepo) FindSourceByMaskID(ctx context.Context, maskID string) (*da return nil, err } - return source, nil + return source.toDatastoreSource(), nil } func (s *sourceRepo) DeleteSourceByID(ctx context.Context, projectId string, id, sourceVerifierId string) error { @@ -385,17 +383,17 @@ func (s *sourceRepo) DeleteSourceByID(ctx context.Context, projectId string, id, } defer rollbackTx(tx) - _, err = tx.ExecContext(ctx, deleteSourceVerifier, sourceVerifierId) + _, err = tx.ExecContext(ctx, deleteSourceVerifier, time.Now(), sourceVerifierId) if err != nil { return err } - _, err = tx.ExecContext(ctx, deleteSource, id, projectId) + _, err = tx.ExecContext(ctx, deleteSource, time.Now(), id, projectId) if err != nil { return err } - _, err = tx.ExecContext(ctx, deleteSourceSubscription, id, projectId) + _, err = tx.ExecContext(ctx, deleteSourceSubscription, time.Now(), id, projectId) if err != nil { return err } @@ -447,16 +445,16 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil sources := make([]datastore.Source, 0) for rows.Next() { - var source datastore.Source + source := dbSource{} err = rows.StructScan(&source) if err != nil { return nil, datastore.PaginationData{}, err } - sources = append(sources, source) + sources = append(sources, *source.toDatastoreSource()) } - var count datastore.PrevRowCount + var rowCount datastore.PrevRowCount if len(sources) > 0 { var countQuery string var qargs []interface{} @@ -473,16 +471,16 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil countQuery = s.db.Rebind(countQuery) // count the row number before the first row - rows, err := s.db.QueryxContext(ctx, countQuery, qargs...) - if err != nil { - return nil, datastore.PaginationData{}, err + resRows, innerErr := s.db.QueryxContext(ctx, countQuery, qargs...) 
+ if innerErr != nil { + return nil, datastore.PaginationData{}, innerErr } - defer closeWithError(rows) + defer closeWithError(resRows) - if rows.Next() { - err = rows.StructScan(&count) - if err != nil { - return nil, datastore.PaginationData{}, err + if resRows.Next() { + innerErr = resRows.StructScan(&rowCount) + if innerErr != nil { + return nil, datastore.PaginationData{}, innerErr } } } @@ -496,7 +494,7 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil sources = sources[:len(sources)-1] } - pagination := &datastore.PaginationData{PrevRowCount: count} + pagination := &datastore.PaginationData{PrevRowCount: rowCount} pagination = pagination.Build(pageable, ids) return sources, *pagination, nil @@ -531,13 +529,13 @@ func (s *sourceRepo) LoadPubSubSourcesByProjectIDs(ctx context.Context, projectI sources := make([]datastore.Source, 0) for rows.Next() { - var source datastore.Source + source := dbSource{} err = rows.StructScan(&source) if err != nil { return nil, datastore.PaginationData{}, err } - sources = append(sources, source) + sources = append(sources, *source.toDatastoreSource()) } // Bypass pagination.Build here since we're only dealing with forward paging here @@ -557,3 +555,65 @@ func (s *sourceRepo) LoadPubSubSourcesByProjectIDs(ctx context.Context, projectI return sources, *pagination, nil } + +type dbSource struct { + UID string `db:"id"` + Name string `db:"name"` + Type string `db:"type"` + Provider string `db:"provider"` + MaskID string `db:"mask_id"` + ProjectID string `db:"project_id"` + IsDisabled bool `db:"is_disabled"` + ForwardHeaders *string `db:"forward_headers"` + PubSub *datastore.PubSubConfig `db:"pub_sub"` + VerifierID string `db:"source_verifier_id"` + Verifier *datastore.VerifierConfig `db:"verifier"` + CustomResponse datastore.CustomResponse `db:"custom_response"` + IdempotencyKeys *string `db:"idempotency_keys"` + BodyFunction *string `db:"body_function"` + HeaderFunction *string `db:"header_function"` + CreatedAt string `db:"created_at"` + UpdatedAt string `db:"updated_at"` + DeletedAt *string `db:"deleted_at"` +} + +func (s *dbSource) toDatastoreSource() *datastore.Source { + return &datastore.Source{ + UID: s.UID, + Name: s.Name, + Type: datastore.SourceType(s.Type), + Provider: datastore.SourceProvider(s.Provider), + MaskID: s.MaskID, + ProjectID: s.ProjectID, + IsDisabled: s.IsDisabled, + ForwardHeaders: asStringArray(s.ForwardHeaders), + PubSub: s.PubSub, + VerifierID: s.VerifierID, + Verifier: s.Verifier, + CustomResponse: s.CustomResponse, + IdempotencyKeys: asStringArray(s.IdempotencyKeys), + BodyFunction: s.BodyFunction, + HeaderFunction: s.HeaderFunction, + CreatedAt: asTime(s.CreatedAt), + UpdatedAt: asTime(s.UpdatedAt), + DeletedAt: asNullTime(s.DeletedAt), + } +} + +func scanSources(rows *sqlx.Rows) ([]datastore.Source, error) { + sources := make([]datastore.Source, 0) + var err error + defer closeWithError(rows) + + for rows.Next() { + source := dbSource{} + err = rows.StructScan(&source) + if err != nil { + return nil, err + } + + sources = append(sources, *source.toDatastoreSource()) + } + + return sources, nil +} diff --git a/database/sqlite3/source_test.go b/database/sqlite3/source_test.go index b3da30912e..c8cbf99fae 100644 --- a/database/sqlite3/source_test.go +++ b/database/sqlite3/source_test.go @@ -16,205 +16,185 @@ import ( "github.com/stretchr/testify/require" ) -func Test_CreateSource(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() - - sourceRepo := NewSourceRepo(db) - source := 
generateSource(t, db) - - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - - newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - require.NoError(t, err) - - newSource.CreatedAt = time.Time{} - newSource.UpdatedAt = time.Time{} - - require.Equal(t, source, newSource) -} - -func Test_FindSourceByID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() - - sourceRepo := NewSourceRepo(db) - source := generateSource(t, db) - - _, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - - newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - require.NoError(t, err) - - newSource.CreatedAt = time.Time{} - newSource.UpdatedAt = time.Time{} - - require.Equal(t, source, newSource) -} - -func Test_FindSourceByName(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() - - sourceRepo := NewSourceRepo(db) - source := generateSource(t, db) - - _, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) - - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - - newSource, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) - require.NoError(t, err) - - newSource.CreatedAt = time.Time{} - newSource.UpdatedAt = time.Time{} +func TestSourceRepo(t *testing.T) { + tests := []struct { + name string + testFunc func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) + }{ + { + name: "Create Source", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - require.Equal(t, source, newSource) -} + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) -func Test_FindSourceByMaskID(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} - sourceRepo := NewSourceRepo(db) - source := generateSource(t, db) + require.Equal(t, source, newSource) + }, + }, + { + name: "Find Source By ID", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { + _, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - _, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} - newSource, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) - require.NoError(t, err) + require.Equal(t, source, newSource) + }, + }, + { + name: "Find Source By Name", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source 
*datastore.Source) { + _, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - newSource.CreatedAt = time.Time{} - newSource.UpdatedAt = time.Time{} + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - require.Equal(t, source, newSource) -} + newSource, err := sourceRepo.FindSourceByName(context.Background(), source.ProjectID, source.Name) + require.NoError(t, err) -func Test_UpdateSource(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} - sourceRepo := NewSourceRepo(db) - source := generateSource(t, db) + require.Equal(t, source, newSource) + }, + }, + { + name: "Find Source By Mask ID", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { + _, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - name := "Convoy-Dev" - source.Name = name - source.IsDisabled = true - source.CustomResponse = datastore.CustomResponse{ - Body: "/ref/", - ContentType: "application/json", - } - require.NoError(t, sourceRepo.UpdateSource(context.Background(), source.ProjectID, source)) + newSource, err := sourceRepo.FindSourceByMaskID(context.Background(), source.MaskID) + require.NoError(t, err) - newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - require.NoError(t, err) + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} - newSource.CreatedAt = time.Time{} - newSource.UpdatedAt = time.Time{} + require.Equal(t, source, newSource) + }, + }, + { + name: "Update Source", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - require.Equal(t, source, newSource) -} + name := "Convoy-Dev" + source.Name = name + source.IsDisabled = true + source.CustomResponse = datastore.CustomResponse{ + Body: "/ref/", + ContentType: "application/json", + } + require.NoError(t, sourceRepo.UpdateSource(context.Background(), source.ProjectID, source)) -func Test_DeleteSource(t *testing.T) { - db, closeFn := getDB(t) - defer closeFn() + newSource, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) - sourceRepo := NewSourceRepo(db) - subRepo := NewSubscriptionRepo(db) - source := generateSource(t, db) + newSource.CreatedAt = time.Time{} + newSource.UpdatedAt = time.Time{} - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - - sub := &datastore.Subscription{ - Name: "test_sub", - Type: datastore.SubscriptionTypeAPI, - ProjectID: source.ProjectID, - SourceID: source.UID, - AlertConfig: &datastore.DefaultAlertConfig, - RetryConfig: &datastore.DefaultRetryConfig, - FilterConfig: &datastore.FilterConfiguration{ - EventTypes: []string{"*"}, - Filter: datastore.FilterSchema{ - Headers: datastore.M{}, - Body: datastore.M{}, + require.Equal(t, source, newSource) }, }, - RateLimitConfig: &datastore.DefaultRateLimitConfig, - } + { + name: "Delete Source", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { 
+ db, _ := getDB(t) - err := subRepo.CreateSubscription(context.Background(), source.ProjectID, sub) - require.NoError(t, err) + subRepo := NewSubscriptionRepo(db) + require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - _, err = sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - require.NoError(t, err) + project := seedProject(t, db) + endpoint := seedEndpoint(t, db) - require.NoError(t, sourceRepo.DeleteSourceByID(context.Background(), source.ProjectID, source.UID, source.VerifierID)) + sub := generateSubscription(project, source, endpoint, &datastore.Device{}) + require.NoError(t, subRepo.CreateSubscription(context.Background(), sub.ProjectID, sub)) - _, err = sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) + _, err := sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.NoError(t, err) - _, err = subRepo.FindSubscriptionByID(context.Background(), source.ProjectID, sub.UID) - require.Error(t, err) - require.True(t, errors.Is(err, datastore.ErrSubscriptionNotFound)) -} + require.NoError(t, sourceRepo.DeleteSourceByID(context.Background(), source.ProjectID, source.UID, source.VerifierID)) -func Test_LoadSourcesPaged(t *testing.T) { - type Expected struct { - paginationData datastore.PaginationData - } + _, err = sourceRepo.FindSourceByID(context.Background(), source.ProjectID, source.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSourceNotFound)) - tests := []struct { - name string - pageData datastore.Pageable - count int - expected Expected - }{ - { - name: "Load Sources Paged - 10 records", - pageData: datastore.Pageable{PerPage: 3}, - count: 10, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 3, - }, + _, err = subRepo.FindSubscriptionByID(context.Background(), source.ProjectID, sub.UID) + require.Error(t, err) + require.True(t, errors.Is(err, datastore.ErrSubscriptionNotFound)) }, }, - { - name: "Load Sources Paged - 12 records", - pageData: datastore.Pageable{PerPage: 4}, - count: 12, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 4, - }, - }, - }, + name: "Load Sources Paged", + testFunc: func(t *testing.T, sourceRepo datastore.SourceRepository, source *datastore.Source) { + pagingTests := []struct { + name string + pageData datastore.Pageable + count int + perPage int64 + }{ + { + name: "10 records - 3 per page", + pageData: datastore.Pageable{PerPage: 3}, + count: 10, + perPage: 3, + }, + { + name: "12 records - 4 per page", + pageData: datastore.Pageable{PerPage: 4}, + count: 12, + perPage: 4, + }, + { + name: "5 records - 3 per page", + pageData: datastore.Pageable{PerPage: 3}, + count: 5, + perPage: 3, + }, + } - { - name: "Load Sources Paged - 5 records", - pageData: datastore.Pageable{PerPage: 3}, - count: 5, - expected: Expected{ - paginationData: datastore.PaginationData{ - PerPage: 3, - }, + for _, pt := range pagingTests { + t.Run(pt.name, func(t *testing.T) { + for i := 0; i < pt.count; i++ { + s := &datastore.Source{ + UID: ulid.Make().String(), + ProjectID: source.ProjectID, + Name: "Convoy-Prod", + MaskID: uniuri.NewLen(16), + Type: datastore.HTTPSource, + Verifier: &datastore.VerifierConfig{ + Type: datastore.HMacVerifier, + HMac: &datastore.HMac{ + Header: "X-Paystack-Signature", + Hash: "SHA512", + Secret: "Paystack Secret", + }, + }, + } + require.NoError(t, 
sourceRepo.CreateSource(context.Background(), s)) + } + + _, pageable, err := sourceRepo.LoadSourcesPaged(context.Background(), source.ProjectID, &datastore.SourceFilter{}, pt.pageData) + require.NoError(t, err) + require.Equal(t, pt.perPage, pageable.PerPage) + }) + } }, }, } @@ -224,32 +204,10 @@ func Test_LoadSourcesPaged(t *testing.T) { db, closeFn := getDB(t) defer closeFn() - sourceRepo := NewSourceRepo(db) - project := seedProject(t, db) - - for i := 0; i < tc.count; i++ { - source := &datastore.Source{ - UID: ulid.Make().String(), - ProjectID: project.UID, - Name: "Convoy-Prod", - MaskID: uniuri.NewLen(16), - Type: datastore.HTTPSource, - Verifier: &datastore.VerifierConfig{ - Type: datastore.HMacVerifier, - HMac: &datastore.HMac{ - Header: "X-Paystack-Signature", - Hash: "SHA512", - Secret: "Paystack Secret", - }, - }, - } - require.NoError(t, sourceRepo.CreateSource(context.Background(), source)) - } - - _, pageable, err := sourceRepo.LoadSourcesPaged(context.Background(), project.UID, &datastore.SourceFilter{}, tc.pageData) + repo := NewSourceRepo(db) + source := generateSource(t, db) - require.NoError(t, err) - require.Equal(t, tc.expected.paginationData.PerPage, pageable.PerPage) + tc.testFunc(t, repo, source) }) } } diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 723fbd97e2..ba1b2c6329 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/frain-dev/convoy/database/hooks" "github.com/frain-dev/convoy/pkg/log" + "github.com/lib/pq" "gopkg.in/guregu/null.v4" "io" "time" @@ -140,3 +141,11 @@ func asNullTime(ts *string) null.Time { } return null.NewTime(t, true) } + +func asStringArray(a *string) pq.StringArray { + if a == nil { + return nil + } + + return pq.StringArray{*a} +} diff --git a/database/sqlite3/subscription.go b/database/sqlite3/subscription.go index 699153da45..9da7bcdb3f 100644 --- a/database/sqlite3/subscription.go +++ b/database/sqlite3/subscription.go @@ -852,8 +852,8 @@ type dbSubscription struct { EndpointID string `db:"endpoint_id"` DeviceID string `db:"device_id"` Function null.String `db:"function"` - Source *datastore.Source `db:"source_metadata"` - Endpoint *datastore.Endpoint `db:"endpoint_metadata"` + Source *dbSource `db:"source_metadata"` + Endpoint *dbEndpoint `db:"endpoint_metadata"` Device *datastore.Device `db:"device_metadata"` AlertConfig *datastore.AlertConfiguration `db:"alert_config"` RetryConfig *datastore.RetryConfiguration `db:"retry_config"` @@ -865,6 +865,17 @@ type dbSubscription struct { } func (ss *dbSubscription) toDatastoreSubscription() *datastore.Subscription { + + var src *datastore.Source + if ss.Source != nil { + src = ss.Source.toDatastoreSource() + } + + var end *datastore.Endpoint + if ss.Endpoint != nil { + end = ss.Endpoint.toDatastoreEndpoint() + } + return &datastore.Subscription{ UID: ss.UID, Name: ss.Name, @@ -874,8 +885,8 @@ func (ss *dbSubscription) toDatastoreSubscription() *datastore.Subscription { EndpointID: ss.EndpointID, DeviceID: ss.DeviceID, Function: ss.Function, - Source: ss.Source, - Endpoint: ss.Endpoint, + Source: src, + Endpoint: end, Device: ss.Device, AlertConfig: ss.AlertConfig, RetryConfig: ss.RetryConfig,