Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tx support to Notification API #1365

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions internal/notification/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,26 @@ package notification

import (
"context"
"fmt"

"github.com/openmeterio/openmeter/pkg/pagination"
)

type TxRepository interface {
ChannelRepository
RuleRepository
EventRepository

Commit() error
Rollback() error
}

type Repository interface {
ChannelRepository
RuleRepository
EventRepository

WithTx(context.Context) (TxRepository, error)
}

type ChannelRepository interface {
Expand All @@ -36,3 +48,59 @@ type EventRepository interface {
GetEventDeliveryStatus(ctx context.Context, params GetEventDeliveryStatusInput) (*EventDeliveryStatus, error)
UpdateEventDeliveryStatus(ctx context.Context, params UpdateEventDeliveryStatusInput) (*EventDeliveryStatus, error)
}

func WithTxNoValue(ctx context.Context, repo Repository, fn func(ctx context.Context, repo TxRepository) error) error {
var err error

wrapped := func(ctx context.Context, repo TxRepository) (interface{}, error) {
if err = fn(ctx, repo); err != nil {
return nil, err
}

return nil, nil
}

_, err = WithTx[any](ctx, repo, wrapped)

return err
}

func WithTx[T any](ctx context.Context, repo Repository, fn func(ctx context.Context, repo TxRepository) (T, error)) (resp T, err error) {
var txRepo TxRepository

txRepo, err = repo.WithTx(ctx)
if err != nil {
return resp, fmt.Errorf("failed to start transaction: %w", err)
}
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("recovered from panic: %v: %w", r, err)

if e := txRepo.Rollback(); e != nil {
err = fmt.Errorf("failed to rollback transaction: %w: %w", e, err)
}

return
}

if err != nil {
if e := txRepo.Rollback(); e != nil {
err = fmt.Errorf("failed to rollback transaction: %w: %w", e, err)
}

return
}

if e := txRepo.Commit(); e != nil {
err = fmt.Errorf("failed to commit transaction: %w", e)
}
}()

resp, err = fn(ctx, txRepo)
if err != nil {
err = fmt.Errorf("failed to execute transaction: %w", err)
return
}

return
}
120 changes: 97 additions & 23 deletions internal/notification/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,56 @@ var _ notification.Repository = (*repository)(nil)

type repository struct {
db *entdb.Client
tx *entdb.Tx

logger *slog.Logger
}

func (r repository) Commit() error {
if r.tx != nil {
return r.tx.Commit()
}

return nil
}

func (r repository) Rollback() error {
if r.tx != nil {
return r.tx.Rollback()
}

return nil
}

func (r repository) client() *entdb.Client {
if r.tx != nil {
return r.tx.Client()
}

return r.db
}

func (r repository) WithTx(ctx context.Context) (notification.TxRepository, error) {
if r.tx != nil {
return r, nil
}

tx, err := r.db.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}

return &repository{
db: r.db,
tx: tx,
logger: r.logger,
}, nil
}

func (r repository) ListChannels(ctx context.Context, params notification.ListChannelsInput) (pagination.PagedResponse[notification.Channel], error) {
query := r.db.NotificationChannel.Query().
db := r.client()

query := db.NotificationChannel.Query().
Where(channeldb.DeletedAtIsNil()) // Do not return deleted channels

if len(params.Namespaces) > 0 {
Expand Down Expand Up @@ -116,7 +160,9 @@ func (r repository) ListChannels(ctx context.Context, params notification.ListCh
}

func (r repository) CreateChannel(ctx context.Context, params notification.CreateChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.Create().
db := r.client()

query := db.NotificationChannel.Create().
SetType(params.Type).
SetName(params.Name).
SetNamespace(params.Namespace).
Expand All @@ -136,7 +182,9 @@ func (r repository) CreateChannel(ctx context.Context, params notification.Creat
}

func (r repository) DeleteChannel(ctx context.Context, params notification.DeleteChannelInput) error {
query := r.db.NotificationChannel.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationChannel.UpdateOneID(params.ID).
SetDeletedAt(clock.Now().UTC()).
SetDisabled(true)

Expand All @@ -158,7 +206,9 @@ func (r repository) DeleteChannel(ctx context.Context, params notification.Delet
}

func (r repository) GetChannel(ctx context.Context, params notification.GetChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.Query().
db := r.client()

query := db.NotificationChannel.Query().
Where(channeldb.ID(params.ID)).
Where(channeldb.Namespace(params.Namespace))

Expand All @@ -184,7 +234,9 @@ func (r repository) GetChannel(ctx context.Context, params notification.GetChann
}

func (r repository) UpdateChannel(ctx context.Context, params notification.UpdateChannelInput) (*notification.Channel, error) {
query := r.db.NotificationChannel.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationChannel.UpdateOneID(params.ID).
SetUpdatedAt(clock.Now().UTC()).
SetDisabled(params.Disabled).
SetConfig(params.Config).
Expand Down Expand Up @@ -212,7 +264,9 @@ func (r repository) UpdateChannel(ctx context.Context, params notification.Updat
}

func (r repository) ListRules(ctx context.Context, params notification.ListRulesInput) (pagination.PagedResponse[notification.Rule], error) {
query := r.db.NotificationRule.Query().
db := r.client()

query := db.NotificationRule.Query().
Where(ruledb.DeletedAtIsNil()) // Do not return deleted Rules

if len(params.Namespaces) > 0 {
Expand Down Expand Up @@ -273,7 +327,9 @@ func (r repository) ListRules(ctx context.Context, params notification.ListRules
}

func (r repository) CreateRule(ctx context.Context, params notification.CreateRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.Create().
db := r.client()

query := db.NotificationRule.Create().
SetType(params.Type).
SetName(params.Name).
SetNamespace(params.Namespace).
Expand All @@ -290,7 +346,7 @@ func (r repository) CreateRule(ctx context.Context, params notification.CreateRu
return nil, fmt.Errorf("invalid query result: nil notification rule received")
}

channelsQuery := r.db.NotificationChannel.Query().
channelsQuery := db.NotificationChannel.Query().
Where(channeldb.Namespace(params.Namespace)).
Where(channeldb.IDIn(params.Channels...))

Expand All @@ -305,7 +361,9 @@ func (r repository) CreateRule(ctx context.Context, params notification.CreateRu
}

func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRuleInput) error {
query := r.db.NotificationRule.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationRule.UpdateOneID(params.ID).
Where(ruledb.Namespace(params.Namespace)).
SetDeletedAt(clock.Now().UTC()).
SetDisabled(true)
Expand All @@ -328,7 +386,9 @@ func (r repository) DeleteRule(ctx context.Context, params notification.DeleteRu
}

func (r repository) GetRule(ctx context.Context, params notification.GetRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.Query().
db := r.client()

query := db.NotificationRule.Query().
Where(ruledb.ID(params.ID)).
Where(ruledb.Namespace(params.Namespace)).
WithChannels()
Expand All @@ -355,7 +415,9 @@ func (r repository) GetRule(ctx context.Context, params notification.GetRuleInpu
}

func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRuleInput) (*notification.Rule, error) {
query := r.db.NotificationRule.UpdateOneID(params.ID).
db := r.client()

query := db.NotificationRule.UpdateOneID(params.ID).
SetUpdatedAt(clock.Now().UTC()).
SetDisabled(params.Disabled).
SetConfig(params.Config).
Expand All @@ -380,7 +442,7 @@ func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRu
return nil, fmt.Errorf("invalid query result: nil notification rule received")
}

channelsQuery := r.db.NotificationChannel.Query().
channelsQuery := db.NotificationChannel.Query().
Where(channeldb.Namespace(params.Namespace)).
Where(channeldb.IDIn(params.Channels...))

Expand All @@ -395,7 +457,9 @@ func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRu
}

func (r repository) ListEvents(ctx context.Context, params notification.ListEventsInput) (pagination.PagedResponse[notification.Event], error) {
query := r.db.NotificationEvent.Query()
db := r.client()

query := db.NotificationEvent.Query()

if len(params.Namespaces) > 0 {
query = query.Where(eventdb.NamespaceIn(params.Namespaces...))
Expand Down Expand Up @@ -471,7 +535,9 @@ func (r repository) ListEvents(ctx context.Context, params notification.ListEven
}

func (r repository) GetEvent(ctx context.Context, params notification.GetEventInput) (*notification.Event, error) {
query := r.db.NotificationEvent.Query().
db := r.client()

query := db.NotificationEvent.Query().
Where(eventdb.Namespace(params.Namespace)).
Where(eventdb.ID(params.ID)).
WithDeliveryStatuses().
Expand Down Expand Up @@ -509,7 +575,9 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
return nil, fmt.Errorf("failed to serialize notification event payload: %w", err)
}

query := r.db.NotificationEvent.Create().
db := r.client()

query := db.NotificationEvent.Create().
SetType(params.Type).
SetNamespace(params.Namespace).
SetRuleID(params.RuleID).
Expand All @@ -524,7 +592,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
return nil, errors.New("invalid query response: nil notification event received")
}

ruleQuery := r.db.NotificationRule.Query().
ruleQuery := db.NotificationRule.Query().
Where(ruledb.Namespace(params.Namespace)).
Where(ruledb.ID(params.RuleID)).
Where(ruledb.DeletedAtIsNil()).
Expand Down Expand Up @@ -560,7 +628,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
continue
}

q := r.db.NotificationEventDeliveryStatus.Create().
q := db.NotificationEventDeliveryStatus.Create().
SetNamespace(params.Namespace).
SetEventID(eventRow.ID).
SetChannelID(channel.ID).
Expand All @@ -570,7 +638,7 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
statusBulkQuery = append(statusBulkQuery, q)
}

statusQuery := r.db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...)
statusQuery := db.NotificationEventDeliveryStatus.CreateBulk(statusBulkQuery...)

statusRows, err := statusQuery.Save(ctx)
if err != nil {
Expand All @@ -588,7 +656,9 @@ func (r repository) CreateEvent(ctx context.Context, params notification.CreateE
}

func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notification.ListEventsDeliveryStatusInput) (pagination.PagedResponse[notification.EventDeliveryStatus], error) {
query := r.db.NotificationEventDeliveryStatus.Query()
db := r.client()

query := db.NotificationEventDeliveryStatus.Query()

if len(params.Namespaces) > 0 {
query = query.Where(statusdb.NamespaceIn(params.Namespaces...))
Expand Down Expand Up @@ -640,7 +710,9 @@ func (r repository) ListEventsDeliveryStatus(ctx context.Context, params notific
}

func (r repository) GetEventDeliveryStatus(ctx context.Context, params notification.GetEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) {
query := r.db.NotificationEventDeliveryStatus.Query().
db := r.client()

query := db.NotificationEventDeliveryStatus.Query().
Where(statusdb.Namespace(params.Namespace)).
Where(statusdb.ID(params.ID))

Expand All @@ -667,10 +739,12 @@ func (r repository) GetEventDeliveryStatus(ctx context.Context, params notificat
func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notification.UpdateEventDeliveryStatusInput) (*notification.EventDeliveryStatus, error) {
var updateQuery *entdb.NotificationEventDeliveryStatusUpdateOne

db := r.client()

if params.ID != "" {
updateQuery = r.db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State)
updateQuery = db.NotificationEventDeliveryStatus.UpdateOneID(params.ID).SetState(params.State)
} else {
getQuery := r.db.NotificationEventDeliveryStatus.Query().
getQuery := db.NotificationEventDeliveryStatus.Query().
Where(statusdb.Namespace(params.Namespace)).
Where(statusdb.EventID(params.EventID)).
Where(statusdb.ChannelID(params.ChannelID))
Expand All @@ -689,7 +763,7 @@ func (r repository) UpdateEventDeliveryStatus(ctx context.Context, params notifi
return nil, fmt.Errorf("failed to udpate notification event delivery status: %w", err)
}

updateQuery = r.db.NotificationEventDeliveryStatus.UpdateOne(statusRow).
updateQuery = db.NotificationEventDeliveryStatus.UpdateOne(statusRow).
SetState(params.State).
SetReason(params.Reason)
}
Expand Down
Loading
Loading