diff --git a/openmeter/entitlement/adapter/entitlement.go b/openmeter/entitlement/adapter/entitlement.go index b9886c38f..6ccbd8174 100644 --- a/openmeter/entitlement/adapter/entitlement.go +++ b/openmeter/entitlement/adapter/entitlement.go @@ -7,7 +7,6 @@ import ( "time" "entgo.io/ent/dialect/sql" - "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/ent/db" db_entitlement "github.com/openmeterio/openmeter/openmeter/ent/db/entitlement" @@ -18,7 +17,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/entitlement/balanceworker" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/convert" - "github.com/openmeterio/openmeter/pkg/defaultx" "github.com/openmeterio/openmeter/pkg/framework/entutils" "github.com/openmeterio/openmeter/pkg/framework/transaction" "github.com/openmeterio/openmeter/pkg/models" @@ -104,7 +102,7 @@ func (a *entitlementDBAdapter) GetEntitlementOfSubject(ctx context.Context, name } func (a *entitlementDBAdapter) CreateEntitlement(ctx context.Context, ent entitlement.CreateEntitlementRepoInputs) (*entitlement.Entitlement, error) { - return entutils.TransactingRepo[entitlement.Entitlement, *entitlementDBAdapter]( + return entutils.TransactingRepo[*entitlement.Entitlement, *entitlementDBAdapter]( ctx, a, func(ctx context.Context, repo *entitlementDBAdapter) (*entitlement.Entitlement, error) { @@ -148,7 +146,7 @@ func (a *entitlementDBAdapter) CreateEntitlement(ctx context.Context, ent entitl } func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlementID models.NamespacedID) error { - _, err := entutils.TransactingRepo[entitlement.Entitlement, *entitlementDBAdapter]( + _, err := entutils.TransactingRepo[*entitlement.Entitlement, *entitlementDBAdapter]( ctx, a, func(ctx context.Context, repo *entitlementDBAdapter) (*entitlement.Entitlement, error) { @@ -169,10 +167,10 @@ func (a *entitlementDBAdapter) DeleteEntitlement(ctx context.Context, entitlemen } func (a *entitlementDBAdapter) ListAffectedEntitlements(ctx context.Context, eventFilters []balanceworker.IngestEventQueryFilter) ([]balanceworker.IngestEventDataResponse, error) { - res, err := entutils.TransactingRepo[[]balanceworker.IngestEventDataResponse, *entitlementDBAdapter]( + return entutils.TransactingRepo[[]balanceworker.IngestEventDataResponse, *entitlementDBAdapter]( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*[]balanceworker.IngestEventDataResponse, error) { + func(ctx context.Context, repo *entitlementDBAdapter) ([]balanceworker.IngestEventDataResponse, error) { if len(eventFilters) == 0 { return nil, fmt.Errorf("no eventFilters provided") } @@ -210,17 +208,15 @@ func (a *entitlementDBAdapter) ListAffectedEntitlements(ctx context.Context, eve }) } - return &result, nil + return result, nil }) - - return defaultx.WithDefault(res, nil), err } func (a *entitlementDBAdapter) GetEntitlementsOfSubject(ctx context.Context, namespace string, subjectKey models.SubjectKey) ([]entitlement.Entitlement, error) { - res, err := entutils.TransactingRepo( + return entutils.TransactingRepo( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*[]entitlement.Entitlement, error) { + func(ctx context.Context, repo *entitlementDBAdapter) ([]entitlement.Entitlement, error) { res, err := withLatestUsageReset(repo.db.Entitlement.Query(), []string{namespace}). Where( db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()), @@ -237,17 +233,16 @@ func (a *entitlementDBAdapter) GetEntitlementsOfSubject(ctx context.Context, nam result = append(result, *mapEntitlementEntity(e)) } - return &result, nil + return result, nil }, ) - return defaultx.WithDefault(res, nil), err } func (a *entitlementDBAdapter) HasEntitlementForMeter(ctx context.Context, namespace string, meterSlug string) (bool, error) { - res, err := entutils.TransactingRepo( + return entutils.TransactingRepo( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*bool, error) { + func(ctx context.Context, repo *entitlementDBAdapter) (bool, error) { exists, err := repo.db.Entitlement.Query(). Where( db_entitlement.Or(db_entitlement.DeletedAtGT(clock.Now()), db_entitlement.DeletedAtIsNil()), @@ -256,20 +251,19 @@ func (a *entitlementDBAdapter) HasEntitlementForMeter(ctx context.Context, names ). Exist(ctx) if err != nil { - return lo.ToPtr(false), err + return false, err } - return &exists, nil + return exists, nil }, ) - return defaultx.WithDefault(res, false), err } func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params entitlement.ListEntitlementsParams) (pagination.PagedResponse[entitlement.Entitlement], error) { - res, err := entutils.TransactingRepo( + return entutils.TransactingRepo( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*pagination.PagedResponse[entitlement.Entitlement], error) { + func(ctx context.Context, repo *entitlementDBAdapter) (pagination.PagedResponse[entitlement.Entitlement], error) { query := repo.db.Entitlement.Query() if len(params.Namespaces) > 0 { @@ -349,7 +343,7 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti entities, err := query.All(ctx) if err != nil { - return &response, err + return response, err } mapped := make([]entitlement.Entitlement, 0, len(entities)) @@ -358,12 +352,12 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti } response.Items = mapped - return &response, nil + return response, nil } paged, err := query.Paginate(ctx, params.Page) if err != nil { - return &response, err + return response, err } result := make([]entitlement.Entitlement, 0, len(paged.Items)) @@ -374,11 +368,9 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti response.TotalCount = paged.TotalCount response.Items = result - return &response, nil + return response, nil }, ) - - return defaultx.WithDefault(res, pagination.PagedResponse[entitlement.Entitlement]{}), err } func mapEntitlementEntity(e *db.Entitlement) *entitlement.Entitlement { @@ -456,10 +448,10 @@ func (a *entitlementDBAdapter) UpdateEntitlementUsagePeriod(ctx context.Context, } func (a *entitlementDBAdapter) ListEntitlementsWithExpiredUsagePeriod(ctx context.Context, namespaces []string, expiredBefore time.Time) ([]entitlement.Entitlement, error) { - res, err := entutils.TransactingRepo( + return entutils.TransactingRepo( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*[]entitlement.Entitlement, error) { + func(ctx context.Context, repo *entitlementDBAdapter) ([]entitlement.Entitlement, error) { query := withLatestUsageReset(repo.db.Entitlement.Query(), namespaces). Where( db_entitlement.CurrentUsagePeriodEndNotNil(), @@ -481,10 +473,9 @@ func (a *entitlementDBAdapter) ListEntitlementsWithExpiredUsagePeriod(ctx contex result = append(result, *mapEntitlementEntity(e)) } - return &result, nil + return result, nil }, ) - return defaultx.WithDefault(res, nil), err } func (a *entitlementDBAdapter) LockEntitlementForTx(ctx context.Context, tx *entutils.TxDriver, entitlementID models.NamespacedID) error { @@ -514,10 +505,10 @@ type namespacesWithCount struct { } func (a *entitlementDBAdapter) ListNamespacesWithActiveEntitlements(ctx context.Context, includeDeletedAfter time.Time) ([]string, error) { - res, err := entutils.TransactingRepo( + return entutils.TransactingRepo( ctx, a, - func(ctx context.Context, repo *entitlementDBAdapter) (*[]string, error) { + func(ctx context.Context, repo *entitlementDBAdapter) ([]string, error) { namespaces := []namespacesWithCount{} err := repo.db.Entitlement.Query(). @@ -531,12 +522,11 @@ func (a *entitlementDBAdapter) ListNamespacesWithActiveEntitlements(ctx context. return nil, err } - return lo.ToPtr(slicesx.Map(namespaces, func(n namespacesWithCount) string { + return slicesx.Map(namespaces, func(n namespacesWithCount) string { return n.Namespace - })), nil + }), nil }, ) - return defaultx.WithDefault(res, nil), err } func withLatestUsageReset(q *db.EntitlementQuery, namespaces []string) *db.EntitlementQuery { diff --git a/openmeter/entitlement/adapter/usage_reset.go b/openmeter/entitlement/adapter/usage_reset.go index 584e2fc2c..7d5326abc 100644 --- a/openmeter/entitlement/adapter/usage_reset.go +++ b/openmeter/entitlement/adapter/usage_reset.go @@ -37,7 +37,7 @@ func (a *usageResetDBAdapter) Save(ctx context.Context, usageResetTime metereden _, err := entutils.TransactingRepo[interface{}, *usageResetDBAdapter]( ctx, a, - func(ctx context.Context, repo *usageResetDBAdapter) (*interface{}, error) { + func(ctx context.Context, repo *usageResetDBAdapter) (interface{}, error) { _, err := repo.db.UsageReset.Create(). SetEntitlementID(usageResetTime.EntitlementID). SetNamespace(usageResetTime.Namespace). diff --git a/pkg/framework/entutils/transaction.go b/pkg/framework/entutils/transaction.go index ba7be06af..84e18ad3a 100644 --- a/pkg/framework/entutils/transaction.go +++ b/pkg/framework/entutils/transaction.go @@ -201,12 +201,13 @@ func TransactingRepo[R, T any]( TxUser[T] TxCreator }, - cb func(ctx context.Context, rep T) (*R, error), -) (*R, error) { - return transaction.Run(ctx, repo, func(ctx context.Context) (*R, error) { + cb func(ctx context.Context, rep T) (R, error), +) (R, error) { + return transaction.Run(ctx, repo, func(ctx context.Context) (R, error) { + var def R tx, err := GetDriverFromContext(ctx) if err != nil { - return nil, err + return def, err } return cb(ctx, repo.WithTx(ctx, tx)) })