Skip to content

Commit

Permalink
Merge pull request #1365 from openmeterio/notification-api
Browse files Browse the repository at this point in the history
feat: add tx support to Notification API
  • Loading branch information
chrisgacsal authored Aug 14, 2024
2 parents 41c7f0e + 8555afa commit bafb1cf
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 139 deletions.
68 changes: 68 additions & 0 deletions internal/notification/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,26 @@ package notification

import (
"context"
"fmt"

"github.com/openmeterio/openmeter/pkg/pagination"
)

type TxRepository interface {
ChannelRepository
RuleRepository
EventRepository

Commit() error
Rollback() error
}

type Repository interface {
ChannelRepository
RuleRepository
EventRepository

WithTx(context.Context) (TxRepository, error)
}

type ChannelRepository interface {
Expand All @@ -36,3 +48,59 @@ type EventRepository interface {
GetEventDeliveryStatus(ctx context.Context, params GetEventDeliveryStatusInput) (*EventDeliveryStatus, error)
UpdateEventDeliveryStatus(ctx context.Context, params UpdateEventDeliveryStatusInput) (*EventDeliveryStatus, error)
}

func WithTxNoValue(ctx context.Context, repo Repository, fn func(ctx context.Context, repo TxRepository) error) error {
var err error

wrapped := func(ctx context.Context, repo TxRepository) (interface{}, error) {
if err = fn(ctx, repo); err != nil {
return nil, err
}

return nil, nil
}

_, err = WithTx[any](ctx, repo, wrapped)

return err
}

func WithTx[T any](ctx context.Context, repo Repository, fn func(ctx context.Context, repo TxRepository) (T, error)) (resp T, err error) {
var txRepo TxRepository

txRepo, err = repo.WithTx(ctx)
if err != nil {
return resp, fmt.Errorf("failed to start transaction: %w", err)
}
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("recovered from panic: %v: %w", r, err)

if e := txRepo.Rollback(); e != nil {
err = fmt.Errorf("failed to rollback transaction: %w: %w", e, err)
}

return
}

if err != nil {
if e := txRepo.Rollback(); e != nil {
err = fmt.Errorf("failed to rollback transaction: %w: %w", e, err)
}

return
}

if e := txRepo.Commit(); e != nil {
err = fmt.Errorf("failed to commit transaction: %w", e)
}
}()

resp, err = fn(ctx, txRepo)
if err != nil {
err = fmt.Errorf("failed to execute transaction: %w", err)
return
}

return
}
120 changes: 97 additions & 23 deletions internal/notification/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,56 @@ var _ notification.Repository = (*repository)(nil)

type repository struct {
db *entdb.Client
tx *entdb.Tx

logger *slog.Logger
}

func (r repository) Commit() error {
if r.tx != nil {
return r.tx.Commit()
}

return nil
}

func (r repository) Rollback() error {
if r.tx != nil {
return r.tx.Rollback()
}

return nil
}

func (r repository) client() *entdb.Client {
if r.tx != nil {
return r.tx.Client()
}

return r.db
}

func (r repository) WithTx(ctx context.Context) (notification.TxRepository, error) {
if r.tx != nil {
return r, nil
}

tx, err := r.db.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}

return &repository{
db: r.db,
tx: tx,
logger: r.logger,
}, nil
}

func (r repository) ListChannels(ctx context.Context, params notification.ListChannelsInput) (pagination.PagedResponse[notification.Channel], error) {
query := r.db.NotificationChannel.Query().
db := r.client()

query := db.NotificationChannel.Query().
Where(channeldb.DeletedAtIsNil()) // Do not return deleted channels

if len(params.Namespaces) > 0 {
Expand Down Expand Up @@ -116,7 +160,9 @@ func (r repository) ListChannels(ctx context.Context, params notification.ListCh
}

func (r repository) CreateChannel(ctx context.Context, params notification.CreateChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.Create().
db := r.client()

query := db.NotificationChannel.Create().
SetType(params.Type).
SetName(params.Name).
SetNamespace(params.Namespace).
Expand All @@ -136,7 +182,9 @@ func (r repository) CreateChannel(ctx context.Context, params notification.Creat
}

func (r repository) DeleteChannel(ctx context.Context, params notification.DeleteChannelInput) error {
query := r.db.NotificationChannel.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationChannel.UpdateOneID(params.ID).
SetDeletedAt(clock.Now().UTC()).
SetDisabled(true)

Expand All @@ -158,7 +206,9 @@ func (r repository) DeleteChannel(ctx context.Context, params notification.Delet
}

func (r repository) GetChannel(ctx context.Context, params notification.GetChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.Query().
db := r.client()

query := db.NotificationChannel.Query().
Where(channeldb.ID(params.ID)).
Where(channeldb.Namespace(params.Namespace))

Expand All @@ -184,7 +234,9 @@ func (r repository) GetChannel(ctx context.Context, params notification.GetChann
}

func (r repository) UpdateChannel(ctx context.Context, params notification.UpdateChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationChannel.UpdateOneID(params.ID).
SetUpdatedAt(clock.Now().UTC()).
SetDisabled(params.Disabled).
SetConfig(params.Config).
Expand Down Expand Up @@ -212,7 +264,9 @@ func (r repository) UpdateChannel(ctx context.Context, params notification.Updat
}

func (r repository) ListRules(ctx context.Context, params notification.ListRulesInput) (pagination.PagedResponse[notification.Rule], error) {
query := r.db.NotificationRule.Query().
db := r.client()

query := db.NotificationRule.Query().
Where(ruledb.DeletedAtIsNil()) // Do not return deleted Rules

if len(params.Namespaces) > 0 {
Expand Down Expand Up @@ -273,7 +327,9 @@ func (r repository) ListRules(ctx context.Context, params notification.ListRules
}

func (r repository) CreateRule(ctx context.Context, params notification.CreateRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.Create().
db := r.client()

query := db.NotificationRule.Create().
SetType(params.Type).
SetName(params.Name).
SetNamespace(params.Namespace).
Expand All @@ -290,7 +346,7 @@ func (r repository) CreateRule(ctx context.Context, params notification.CreateRu
return nil, fmt.Errorf("invalid query result: nil notification rule received")
}

channelsQuery := r.db.NotificationChannel.Query().
channelsQuery := db.NotificationChannel.Query().
Where(channeldb.Namespace(params.Namespace)).
Where(channeldb.IDIn(params.Channels...))

Expand All @@ -305,7 +361,9 @@ func (r repository) CreateRule(ctx context.Context, params notification.CreateRu
}

func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRuleInput) error {
query := r.db.NotificationRule.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationRule.UpdateOneID(params.ID).
Where(ruledb.Namespace(params.Namespace)).
SetDeletedAt(clock.Now().UTC()).
SetDisabled(true)
Expand All @@ -328,7 +386,9 @@ func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRu
}

func (r repository) GetRule(ctx context.Context, params notification.GetRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.Query().
db := r.client()

query := db.NotificationRule.Query().
Where(ruledb.ID(params.ID)).
Where(ruledb.Namespace(params.Namespace)).
WithChannels()
Expand All @@ -355,7 +415,9 @@ func (r repository) GetRule(ctx context.Context, params notification.GetRuleInpu
}

func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationRule.UpdateOneID(params.ID).
SetUpdatedAt(clock.Now().UTC()).
SetDisabled(params.Disabled).
SetConfig(params.Config).
Expand All @@ -380,7 +442,7 @@ func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRu
return nil, fmt.Errorf("invalid query result: nil notification rule received")
}

channelsQuery := r.db.NotificationChannel.Query().
channelsQuery := db.NotificationChannel.Query().
Where(channeldb.Namespace(params.Namespace)).
Where(channeldb.IDIn(params.Channels...))

Expand All @@ -395,7 +457,9 @@ func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRu
}

func (r repository) ListEvents(ctx context.Context, params notification.ListEventsInput) (pagination.PagedResponse[notification.Event], error) {
query := r.db.NotificationEvent.Query()
db := r.client()

query := db.NotificationEvent.Query()

if len(params.Namespaces) > 0 {
query = query.Where(eventdb.NamespaceIn(params.Namespaces...))
Expand Down Expand Up @@ -471,7 +535,9 @@ func (r repository) ListEvents(ctx context.Context, params notification.ListEven
}

func (r repository) GetEvent(ctx context.Context, params notification.GetEventInput) (*notification.Event, error) {
query := r.db.NotificationEvent.Query().
db := r.client()

query := db.NotificationEvent.Query().
Where(eventdb.Namespace(params.Namespace)).
Where(eventdb.ID(params.ID)).
WithDeliveryStatuses().
Expand Down Expand Up @@ -509,7 +575,9 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
return nil, fmt.Errorf("failed to serialize notification event payload: %w", err)
}

query := r.db.NotificationEvent.Create().
db := r.client()

query := db.NotificationEvent.Create().
SetType(params.Type).
SetNamespace(params.Namespace).
SetRuleID(params.RuleID).
Expand All @@ -524,7 +592,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
return nil, errors.New("invalid query response: nil notification event received")
}

ruleQuery := r.db.NotificationRule.Query().
ruleQuery := db.NotificationRule.Query().
Where(ruledb.Namespace(params.Namespace)).
Where(ruledb.ID(params.RuleID)).
Where(ruledb.DeletedAtIsNil()).
Expand Down Expand Up @@ -560,7 +628,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
continue
}

q := r.db.NotificationEventDeliveryStatus.Create().
q := db.NotificationEventDeliveryStatus.Create().
SetNamespace(params.Namespace).
SetEventID(eventRow.ID).
SetChannelID(channel.ID).
Expand All @@ -570,7 +638,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
statusBulkQuery = append(statusBulkQuery, q)
}

statusQuery := r.db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...)
statusQuery := db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...)

statusRows, err := statusQuery.Save(ctx)
if err != nil {
Expand All @@ -588,7 +656,9 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
}

func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notification.ListEventsDeliveryStatusInput) (pagination.PagedResponse[notification.EventDeliveryStatus], error) {
query := r.db.NotificationEventDeliveryStatus.Query()
db := r.client()

query := db.NotificationEventDeliveryStatus.Query()

if len(params.Namespaces) > 0 {
query = query.Where(statusdb.NamespaceIn(params.Namespaces...))
Expand Down Expand Up @@ -640,7 +710,9 @@ func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notific
}

func (r repository) GetEventDeliveryStatus(ctx context.Context, params notification.GetEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) {
query := r.db.NotificationEventDeliveryStatus.Query().
db := r.client()

query := db.NotificationEventDeliveryStatus.Query().
Where(statusdb.Namespace(params.Namespace)).
Where(statusdb.ID(params.ID))

Expand All @@ -667,10 +739,12 @@ func (r repository) GetEventDeliveryStatus(ctx context.Context, params notificat
func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notification.UpdateEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) {
var updateQuery *entdb.NotificationEventDeliveryStatusUpdateOne

db := r.client()

if params.ID != "" {
updateQuery = r.db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State)
updateQuery = db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State)
} else {
getQuery := r.db.NotificationEventDeliveryStatus.Query().
getQuery := db.NotificationEventDeliveryStatus.Query().
Where(statusdb.Namespace(params.Namespace)).
Where(statusdb.EventID(params.EventID)).
Where(statusdb.ChannelID(params.ChannelID))
Expand All @@ -689,7 +763,7 @@ func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notifi
return nil, fmt.Errorf("failed to udpate notification event delivery status: %w", err)
}

updateQuery = r.db.NotificationEventDeliveryStatus.UpdateOne(statusRow).
updateQuery = db.NotificationEventDeliveryStatus.UpdateOne(statusRow).
SetState(params.State).
SetReason(params.Reason)
}
Expand Down
Loading

0 comments on commit bafb1cf

Please sign in to comment.