Skip to content

Commit

Permalink
Merge pull request #1606 from openmeterio/refactor/tx-repo
Browse files Browse the repository at this point in the history
Use default returns for TX Repo helpers
  • Loading branch information
GAlexIHU authored Oct 3, 2024
2 parents efcc0f1 + c9d00e1 commit a53fb09
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 40 deletions.
60 changes: 25 additions & 35 deletions openmeter/entitlement/adapter/entitlement.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"time"

"entgo.io/ent/dialect/sql"
"github.com/samber/lo"

"github.com/openmeterio/openmeter/openmeter/ent/db"
db_entitlement "github.com/openmeterio/openmeter/openmeter/ent/db/entitlement"
Expand All @@ -18,7 +17,6 @@ import (
"github.com/openmeterio/openmeter/openmeter/entitlement/balanceworker"
"github.com/openmeterio/openmeter/pkg/clock"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/defaultx"
"github.com/openmeterio/openmeter/pkg/framework/entutils"
"github.com/openmeterio/openmeter/pkg/framework/transaction"
"github.com/openmeterio/openmeter/pkg/models"
Expand Down Expand Up @@ -104,7 +102,7 @@ func (a *entitlementDBAdapter) GetEntitlementOfSubject(ctx context.Context, name
}

func (a *entitlementDBAdapter) CreateEntitlement(ctx context.Context, ent entitlement.CreateEntitlementRepoInputs) (*entitlement.Entitlement, error) {
return entutils.TransactingRepo[entitlement.Entitlement, *entitlementDBAdapter](
return entutils.TransactingRepo[*entitlement.Entitlement, *entitlementDBAdapter](
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*entitlement.Entitlement, error) {
Expand Down Expand Up @@ -148,7 +146,7 @@ func (a *entitlementDBAdapter) CreateEntitlement(ctx context.Context, ent entitl
}

func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlementID models.NamespacedID) error {
_, err := entutils.TransactingRepo[entitlement.Entitlement, *entitlementDBAdapter](
_, err := entutils.TransactingRepo[*entitlement.Entitlement, *entitlementDBAdapter](
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*entitlement.Entitlement, error) {
Expand All @@ -169,10 +167,10 @@ func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlemen
}

func (a *entitlementDBAdapter) ListAffectedEntitlements(ctx context.Context, eventFilters []balanceworker.IngestEventQueryFilter) ([]balanceworker.IngestEventDataResponse, error) {
res, err := entutils.TransactingRepo[[]balanceworker.IngestEventDataResponse, *entitlementDBAdapter](
return entutils.TransactingRepo[[]balanceworker.IngestEventDataResponse, *entitlementDBAdapter](
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*[]balanceworker.IngestEventDataResponse, error) {
func(ctx context.Context, repo *entitlementDBAdapter) ([]balanceworker.IngestEventDataResponse, error) {
if len(eventFilters) == 0 {
return nil, fmt.Errorf("no eventFilters provided")
}
Expand Down Expand Up @@ -210,17 +208,15 @@ func (a *entitlementDBAdapter) ListAffectedEntitlements(ctx context.Context, eve
})
}

return &result, nil
return result, nil
})

return defaultx.WithDefault(res, nil), err
}

func (a *entitlementDBAdapter) GetEntitlementsOfSubject(ctx context.Context, namespace string, subjectKey models.SubjectKey) ([]entitlement.Entitlement, error) {
res, err := entutils.TransactingRepo(
return entutils.TransactingRepo(
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*[]entitlement.Entitlement, error) {
func(ctx context.Context, repo *entitlementDBAdapter) ([]entitlement.Entitlement, error) {
res, err := withLatestUsageReset(repo.db.Entitlement.Query(), []string{namespace}).
Where(
db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()),
Expand All @@ -237,17 +233,16 @@ func (a *entitlementDBAdapter) GetEntitlementsOfSubject(ctx context.Context, nam
result = append(result, *mapEntitlementEntity(e))
}

return &result, nil
return result, nil
},
)
return defaultx.WithDefault(res, nil), err
}

func (a *entitlementDBAdapter) HasEntitlementForMeter(ctx context.Context, namespace string, meterSlug string) (bool, error) {
res, err := entutils.TransactingRepo(
return entutils.TransactingRepo(
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*bool, error) {
func(ctx context.Context, repo *entitlementDBAdapter) (bool, error) {
exists, err := repo.db.Entitlement.Query().
Where(
db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()),
Expand All @@ -256,20 +251,19 @@ func (a *entitlementDBAdapter) HasEntitlementForMeter(ctx context.Context, names
).
Exist(ctx)
if err != nil {
return lo.ToPtr(false), err
return false, err
}

return &exists, nil
return exists, nil
},
)
return defaultx.WithDefault(res, false), err
}

func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params entitlement.ListEntitlementsParams) (pagination.PagedResponse[entitlement.Entitlement], error) {
res, err := entutils.TransactingRepo(
return entutils.TransactingRepo(
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*pagination.PagedResponse[entitlement.Entitlement], error) {
func(ctx context.Context, repo *entitlementDBAdapter) (pagination.PagedResponse[entitlement.Entitlement], error) {
query := repo.db.Entitlement.Query()

if len(params.Namespaces) > 0 {
Expand Down Expand Up @@ -349,7 +343,7 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti

entities, err := query.All(ctx)
if err != nil {
return &response, err
return response, err
}

mapped := make([]entitlement.Entitlement, 0, len(entities))
Expand All @@ -358,12 +352,12 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
}

response.Items = mapped
return &response, nil
return response, nil
}

paged, err := query.Paginate(ctx, params.Page)
if err != nil {
return &response, err
return response, err
}

result := make([]entitlement.Entitlement, 0, len(paged.Items))
Expand All @@ -374,11 +368,9 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
response.TotalCount = paged.TotalCount
response.Items = result

return &response, nil
return response, nil
},
)

return defaultx.WithDefault(res, pagination.PagedResponse[entitlement.Entitlement]{}), err
}

func mapEntitlementEntity(e *db.Entitlement) *entitlement.Entitlement {
Expand Down Expand Up @@ -456,10 +448,10 @@ func (a *entitlementDBAdapter) UpdateEntitlementUsagePeriod(ctx context.Context,
}

func (a *entitlementDBAdapter) ListEntitlementsWithExpiredUsagePeriod(ctx context.Context, namespaces []string, expiredBefore time.Time) ([]entitlement.Entitlement, error) {
res, err := entutils.TransactingRepo(
return entutils.TransactingRepo(
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*[]entitlement.Entitlement, error) {
func(ctx context.Context, repo *entitlementDBAdapter) ([]entitlement.Entitlement, error) {
query := withLatestUsageReset(repo.db.Entitlement.Query(), namespaces).
Where(
db_entitlement.CurrentUsagePeriodEndNotNil(),
Expand All @@ -481,10 +473,9 @@ func (a *entitlementDBAdapter) ListEntitlementsWithExpiredUsagePeriod(ctx contex
result = append(result, *mapEntitlementEntity(e))
}

return &result, nil
return result, nil
},
)
return defaultx.WithDefault(res, nil), err
}

func (a *entitlementDBAdapter) LockEntitlementForTx(ctx context.Context, tx *entutils.TxDriver, entitlementID models.NamespacedID) error {
Expand Down Expand Up @@ -514,10 +505,10 @@ type namespacesWithCount struct {
}

func (a *entitlementDBAdapter) ListNamespacesWithActiveEntitlements(ctx context.Context, includeDeletedAfter time.Time) ([]string, error) {
res, err := entutils.TransactingRepo(
return entutils.TransactingRepo(
ctx,
a,
func(ctx context.Context, repo *entitlementDBAdapter) (*[]string, error) {
func(ctx context.Context, repo *entitlementDBAdapter) ([]string, error) {
namespaces := []namespacesWithCount{}

err := repo.db.Entitlement.Query().
Expand All @@ -531,12 +522,11 @@ func (a *entitlementDBAdapter) ListNamespacesWithActiveEntitlements(ctx context.
return nil, err
}

return lo.ToPtr(slicesx.Map(namespaces, func(n namespacesWithCount) string {
return slicesx.Map(namespaces, func(n namespacesWithCount) string {
return n.Namespace
})), nil
}), nil
},
)
return defaultx.WithDefault(res, nil), err
}

func withLatestUsageReset(q *db.EntitlementQuery, namespaces []string) *db.EntitlementQuery {
Expand Down
2 changes: 1 addition & 1 deletion openmeter/entitlement/adapter/usage_reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (a *usageResetDBAdapter) Save(ctx context.Context, usageResetTime metereden
_, err := entutils.TransactingRepo[interface{}, *usageResetDBAdapter](
ctx,
a,
func(ctx context.Context, repo *usageResetDBAdapter) (*interface{}, error) {
func(ctx context.Context, repo *usageResetDBAdapter) (interface{}, error) {
_, err := repo.db.UsageReset.Create().
SetEntitlementID(usageResetTime.EntitlementID).
SetNamespace(usageResetTime.Namespace).
Expand Down
9 changes: 5 additions & 4 deletions pkg/framework/entutils/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,13 @@ func TransactingRepo[R, T any](
TxUser[T]
TxCreator
},
cb func(ctx context.Context, rep T) (*R, error),
) (*R, error) {
return transaction.Run(ctx, repo, func(ctx context.Context) (*R, error) {
cb func(ctx context.Context, rep T) (R, error),
) (R, error) {
return transaction.Run(ctx, repo, func(ctx context.Context) (R, error) {
var def R
tx, err := GetDriverFromContext(ctx)
if err != nil {
return nil, err
return def, err
}
return cb(ctx, repo.WithTx(ctx, tx))
})
Expand Down

0 comments on commit a53fb09

Please sign in to comment.