diff --git a/openmeter/billing/README.md b/openmeter/billing/README.md index 0f0f2aa74..0b8ebdc08 100644 --- a/openmeter/billing/README.md +++ b/openmeter/billing/README.md @@ -149,3 +149,11 @@ The entity's `ChildrenWithIDReuse` call can be used to facilitate the line reuse Then the adapter layer will use those IDs to make decisions if they want to persist or recreate the records. We could do the same logic in the adapter layer, but this approach makes it more flexible on the calculation layer if we want to generate new lines or not. If this becomes a burden we can do the same matching logic as part of the upsert logic in adapter. + +## Subscription adapter + +The subscription adapter is responsible for feeding the billing with line items during the subscription's lifecycle. The generation of items is event-driven, new items are yielded when: +- A subscription is created +- A new invoice is created +- A subscription is modified +- Upgrade/Downgrade is handled as a subscription create/cancel diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 3343b2794..553390cc7 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -45,6 +45,7 @@ type InvoiceLineAdapter interface { ListInvoiceLines(ctx context.Context, input ListInvoiceLinesAdapterInput) ([]*Line, error) AssociateLinesToInvoice(ctx context.Context, input AssociateLinesToInvoiceAdapterInput) ([]*Line, error) GetInvoiceLine(ctx context.Context, input GetInvoiceLineAdapterInput) (*Line, error) + GetLinesForSubscription(ctx context.Context, input GetLinesForSubscriptionInput) ([]*Line, error) GetInvoiceLineOwnership(ctx context.Context, input GetInvoiceLineOwnershipAdapterInput) (GetOwnershipAdapterResponse, error) } diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index 73a516cb8..d82f761c3 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -43,7 +43,7 @@ func (a *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdI WithBillingWorkflowConfig() if in.Expand.Lines { - query = tx.expandInvoiceLineItems(query, in.Expand.DeletedLines) + query = tx.expandInvoiceLineItems(query, in.Expand) } invoice, err := query.Only(ctx) @@ -63,15 +63,20 @@ func (a *adapter) GetInvoiceById(ctx context.Context, in billing.GetInvoiceByIdI }) } -func (a *adapter) expandInvoiceLineItems(query *db.BillingInvoiceQuery, includeDeleted bool) *db.BillingInvoiceQuery { +func (a *adapter) expandInvoiceLineItems(query *db.BillingInvoiceQuery, expand billing.InvoiceExpand) *db.BillingInvoiceQuery { return query.WithBillingInvoiceLines(func(q *db.BillingInvoiceLineQuery) { - if !includeDeleted { + if !expand.DeletedLines { q = q.Where(billinginvoiceline.DeletedAtIsNil()) } + requestedStatuses := []billing.InvoiceLineStatus{billing.InvoiceLineStatusValid} + if expand.SplitLines { + requestedStatuses = append(requestedStatuses, billing.InvoiceLineStatusSplit) + } + q = q.Where( // Detailed lines are sub-lines of a line and should not be included in the top-level invoice - billinginvoiceline.StatusIn(billing.InvoiceLineStatusValid), + billinginvoiceline.StatusIn(requestedStatuses...), ) a.expandLineItems(q) @@ -170,6 +175,10 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn query = query.Where(billinginvoice.StatusIn(input.ExtendedStatuses...)) } + if len(input.IDs) > 0 { + query = query.Where(billinginvoice.IDIn(input.IDs...)) + } + if len(input.Statuses) > 0 { query = query.Where(func(s *sql.Selector) { s.Where(sql.Or( @@ -190,7 +199,7 @@ func (a *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn } if input.Expand.Lines { - query = tx.expandInvoiceLineItems(query, input.Expand.DeletedLines) + query = tx.expandInvoiceLineItems(query, input.Expand) } switch input.OrderBy { diff --git a/openmeter/billing/adapter/invoicelinemapper.go b/openmeter/billing/adapter/invoicelinemapper.go index 84180292c..d04c0dc84 100644 --- a/openmeter/billing/adapter/invoicelinemapper.go +++ b/openmeter/billing/adapter/invoicelinemapper.go @@ -52,6 +52,7 @@ func (a *adapter) mapInvoiceLineFromDB(ctx context.Context, invoiceLines []*db.B if err != nil { return nil, err } + mappedEntities[dbLine.ID] = &entity } diff --git a/openmeter/billing/adapter/invoicelines.go b/openmeter/billing/adapter/invoicelines.go index 0d9229369..9f3b987a8 100644 --- a/openmeter/billing/adapter/invoicelines.go +++ b/openmeter/billing/adapter/invoicelines.go @@ -103,7 +103,7 @@ func (a *adapter) UpsertInvoiceLines(ctx context.Context, inputIn billing.Upsert SetNillableInvoicingAppExternalID(lo.EmptyableToPtr(line.ExternalIDs.Invoicing)) if line.Subscription != nil { - create = create.SetSubscriptionID(line.Subscription.ItemID). + create = create.SetSubscriptionID(line.Subscription.SubscriptionID). SetSubscriptionPhaseID(line.Subscription.PhaseID). SetSubscriptionItemID(line.Subscription.ItemID) } @@ -492,3 +492,28 @@ func (a *adapter) GetInvoiceLineOwnership(ctx context.Context, in billing.GetInv }, nil }) } + +func (a *adapter) GetLinesForSubscription(ctx context.Context, in billing.GetLinesForSubscriptionInput) ([]*billing.Line, error) { + if err := in.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) ([]*billing.Line, error) { + query := tx.db.BillingInvoiceLine.Query(). + Where(billinginvoiceline.Namespace(in.Namespace)). + Where(billinginvoiceline.SubscriptionID(in.SubscriptionID)). + // TODO[OM-1038]: document issues with deleted lines + Where(billinginvoiceline.ParentLineIDIsNil()) // This one is required so that we are not fetching split line's children directly, the mapper will handle that + + query = tx.expandLineItems(query) + + dbLines, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("fetching lines: %w", err) + } + + return tx.mapInvoiceLineFromDB(ctx, dbLines) + }) +} diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index 5e45e7bc8..602ef6d60 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -137,6 +137,7 @@ type InvoiceExpand struct { WorkflowApps bool Lines bool DeletedLines bool + SplitLines bool } var InvoiceExpandAll = InvoiceExpand{ @@ -144,6 +145,7 @@ var InvoiceExpandAll = InvoiceExpand{ WorkflowApps: true, Lines: true, DeletedLines: false, + SplitLines: false, } func (e InvoiceExpand) Validate() error { @@ -160,6 +162,11 @@ func (e InvoiceExpand) SetDeletedLines(v bool) InvoiceExpand { return e } +func (e InvoiceExpand) SetSplitLines(v bool) InvoiceExpand { + e.SplitLines = v + return e +} + type InvoiceBase struct { Namespace string `json:"namespace"` ID string `json:"id"` @@ -373,6 +380,7 @@ type ListInvoicesInput struct { pagination.Page Namespace string + IDs []string Customers []string // Statuses searches by short InvoiceStatus (e.g. draft, issued) Statuses []string @@ -497,3 +505,21 @@ type GetOwnershipAdapterResponse struct { } type DeleteInvoiceInput = InvoiceID + +type UpdateInvoiceLinesInternalInput struct { + Namespace string + CustomerID string + Lines []*Line +} + +func (i UpdateInvoiceLinesInternalInput) Validate() error { + if i.Namespace == "" { + return errors.New("namespace is required") + } + + if i.CustomerID == "" { + return errors.New("customer ID is required") + } + + return nil +} diff --git a/openmeter/billing/invoiceline.go b/openmeter/billing/invoiceline.go index 28f0e6255..a655dc2ef 100644 --- a/openmeter/billing/invoiceline.go +++ b/openmeter/billing/invoiceline.go @@ -99,6 +99,10 @@ func (p Period) Contains(t time.Time) bool { return t.After(p.Start) && t.Before(p.End) } +func (p Period) Duration() time.Duration { + return p.End.Sub(p.Start) +} + // LineBase represents the common fields for an invoice item. type LineBase struct { Namespace string `json:"namespace"` @@ -409,6 +413,21 @@ func (i Line) ValidateUsageBased() error { return nil } +// DissacociateChildren removes the Children both from the DBState and the current line, so that the +// line can be safely persisted/managed without the children. +// +// The childrens receive DBState objects, so that they can be safely persisted/managed without the parent. +func (i *Line) DisassociateChildren() { + if i.Children.IsAbsent() { + return + } + + i.Children = LineChildren{} + if i.DBState != nil { + i.DBState.Children = LineChildren{} + } +} + // TODO[OM-1016]: For events we need a json marshaler type LineChildren struct { mo.Option[[]*Line] @@ -471,7 +490,11 @@ func (c *LineChildren) ReplaceByID(id string, newLine *Line) bool { for i, line := range lines { if line.ID == id { + // Let's preserve the DB state of the original line (as we are only replacing the current state) + originalDBState := line.DBState + lines[i] = newLine + lines[i].DBState = originalDBState return true } } @@ -952,3 +975,20 @@ type GetInvoiceLineInput = LineID type GetInvoiceLineOwnershipAdapterInput = LineID type DeleteInvoiceLineInput = LineID + +type GetLinesForSubscriptionInput struct { + Namespace string + SubscriptionID string +} + +func (i GetLinesForSubscriptionInput) Validate() error { + if i.Namespace == "" { + return errors.New("namespace is required") + } + + if i.SubscriptionID == "" { + return errors.New("subscription id is required") + } + + return nil +} diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index 982aa3c9e..4d4b1c40c 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -32,7 +32,9 @@ type CustomerOverrideService interface { type InvoiceLineService interface { CreatePendingInvoiceLines(ctx context.Context, input CreateInvoiceLinesInput) ([]*Line, error) GetInvoiceLine(ctx context.Context, input GetInvoiceLineInput) (*Line, error) + GetLinesForSubscription(ctx context.Context, input GetLinesForSubscriptionInput) ([]*Line, error) UpdateInvoiceLine(ctx context.Context, input UpdateInvoiceLineInput) (*Line, error) + DeleteInvoiceLine(ctx context.Context, input DeleteInvoiceLineInput) error } @@ -48,4 +50,11 @@ type InvoiceService interface { ApproveInvoice(ctx context.Context, input ApproveInvoiceInput) (Invoice, error) RetryInvoice(ctx context.Context, input RetryInvoiceInput) (Invoice, error) DeleteInvoice(ctx context.Context, input DeleteInvoiceInput) error + + // UpdateInvoiceLinesInternal updates the specified invoice lines and ensures that invoice states are properly syncronized + // This method is intended to be used by OpenMeter internal services only, as it allows for updating invoice line values, + // that are not allowed to be updated by external services. + // + // The call also ensures that the invoice's state is properly updated and invoice immutability is also considered. + UpdateInvoiceLinesInternal(ctx context.Context, input UpdateInvoiceLinesInternalInput) error } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 0cc48a147..b11a4774a 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -2,6 +2,7 @@ package billingservice import ( "context" + "errors" "fmt" "slices" "strings" @@ -540,3 +541,194 @@ func (s *Service) DeleteInvoice(ctx context.Context, input billing.DeleteInvoice return err } + +func (s *Service) UpdateInvoiceLinesInternal(ctx context.Context, input billing.UpdateInvoiceLinesInternalInput) error { + if err := input.Validate(); err != nil { + return billing.ValidationError{ + Err: err, + } + } + + if len(input.Lines) == 0 { + return nil + } + + return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error { + // Split line child updates should be done invoice by invoice, so let's flatten split lines + lines := s.flattenSplitLines(input.Lines) + + linesByInvoice := lo.GroupBy(lines, func(line *billing.Line) string { + return line.InvoiceID + }) + + // We want to upsert the new lines at the end, so that any updates on the gathering invoice will not interfere + newPendingLines := linesByInvoice[""] + delete(linesByInvoice, "") + + for invoiceID, lines := range linesByInvoice { + invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: billing.InvoiceID{ + ID: invoiceID, + Namespace: input.Namespace, + }, + Expand: billing.InvoiceExpand{ + Lines: true, + DeletedLines: true, + SplitLines: true, + }, + }) + if err != nil { + return fmt.Errorf("fetching invoice: %w", err) + } + + if input.CustomerID != invoice.Customer.CustomerID { + return billing.ValidationError{ + Err: fmt.Errorf("customer mismatch: [input.customer=%s] vs [invoice.customer=%s]", input.CustomerID, invoice.Customer.CustomerID), + } + } + + if invoice.Status == billing.InvoiceStatusGathering { + for _, line := range lines { + if !invoice.Lines.ReplaceByID(line.ID, line) { + return fmt.Errorf("line[%s] not found in invoice[%s]", line.ID, invoice.ID) + } + } + + if _, err := s.adapter.UpdateInvoice(ctx, invoice); err != nil { + return fmt.Errorf("updating gathering invoice: %w", err) + } + } else { + if err := s.handleNonGatheringInvoiceLineUpdate(ctx, invoice, lines); err != nil { + return fmt.Errorf("handling invoice line update: %w", err) + } + } + } + + _, err := s.CreatePendingInvoiceLines(ctx, billing.CreateInvoiceLinesInput{ + Namespace: input.Namespace, + Lines: lo.Map(newPendingLines, func(line *billing.Line, _ int) billing.LineWithCustomer { + return billing.LineWithCustomer{ + Line: *line, + CustomerID: input.CustomerID, + } + }), + }) + if err != nil { + return fmt.Errorf("creating new pending lines: %w", err) + } + + // Note: The gathering invoice will be maintained by the CreatePendingInvoiceLines call, so we don't need to care for any empty gathering + // invoices here. + return nil + }) +} + +func (s *Service) handleNonGatheringInvoiceLineUpdate(ctx context.Context, invoice billing.Invoice, lines []*billing.Line) error { + if invoice.Lines.IsAbsent() { + return errors.New("cannot update invoice without expanded lines") + } + + existingInvoiceLinesByID := lo.GroupBy(invoice.Lines.OrEmpty(), func(line *billing.Line) string { + return line.ID + }) + + // Let's look for the lines that have been updated + changedLines := make([]*billing.Line, 0, len(lines)) + + for _, line := range lines { + existingLines, existingLineFound := existingInvoiceLinesByID[line.ID] + + if existingLineFound { + if len(existingLines) != 1 { + return fmt.Errorf("line[%s] has more than one entry in the invoice", line.ID) + } + + existingLine := existingLines[0] + if !existingLine.LineBase.Equal(line.LineBase) { + changedLines = append(changedLines, line) + } + } else { + changedLines = append(changedLines, line) + } + } + + if len(changedLines) == 0 { + return nil + } + + // Let's try to avoid touching an immutable invoice + if invoice.StatusDetails.Immutable { + // We only care about lines that are affecting the balance at this stage, as + // there's a chance that an invoice being created and a subscription update is + // happening in the same time. + + return fmt.Errorf("invoice is immutable, but voiding is not implemented yet: invoice[%s] lineIDs:[%s]", + invoice.ID, + strings.Join(lo.Map(changedLines, func(line *billing.Line, _ int) string { + return line.ID + }), ","), + ) + } + + // Note: in the current setup this could only happen if there's a parallel progressive invoice creation and + // subscription edit. + for _, line := range changedLines { + // Should not happen as split lines can only live on gathering invoices + if line.Status == billing.InvoiceLineStatusSplit { + return fmt.Errorf("split line[%s] cannot be updated", line.ID) + } + + // Let's update the snapshot of the line, as we might have changed the period + srv, err := s.lineService.FromEntity(line) + if err != nil { + return fmt.Errorf("creating line service: %w", err) + } + + if err := srv.Validate(ctx, &invoice); err != nil { + return fmt.Errorf("validating line: %w", err) + } + + if err := srv.SnapshotQuantity(ctx, &invoice); err != nil { + return fmt.Errorf("snapshotting quantity: %w", err) + } + + if err := srv.CalculateDetailedLines(); err != nil { + return fmt.Errorf("calculating detailed lines: %w", err) + } + + if err := srv.UpdateTotals(); err != nil { + return fmt.Errorf("updating totals: %w", err) + } + } + + invoice, err := s.executeTriggerOnInvoice( + ctx, + invoice.InvoiceID(), + triggerUpdated, + ExecuteTriggerWithAllowInStates(billing.InvoiceStatusDraftUpdating), + ExecuteTriggerWithEditCallback(func(sm *InvoiceStateMachine) error { + for _, line := range changedLines { + if !invoice.Lines.ReplaceByID(line.ID, line) { + return fmt.Errorf("line[%s] not found in invoice[%s]", line.ID, invoice.ID) + } + } + return nil + }), + ) + + return err +} + +func (s *Service) flattenSplitLines(lines []*billing.Line) []*billing.Line { + out := make([]*billing.Line, 0, len(lines)) + for _, line := range lines { + out = append(out, line) + + if line.Status == billing.InvoiceLineStatusSplit { + out = append(out, line.Children.OrEmpty()...) + line.DisassociateChildren() + } + } + + return out +} diff --git a/openmeter/billing/service/invoicecalc/details.go b/openmeter/billing/service/invoicecalc/details.go index aa1b6774a..d529a40c8 100644 --- a/openmeter/billing/service/invoicecalc/details.go +++ b/openmeter/billing/service/invoicecalc/details.go @@ -50,6 +50,11 @@ func RecalculateDetailedLinesAndTotals(invoice *billing.Invoice, deps Calculator return billing.Totals{} } + // Split lines cannot contribute to the totals, as they are superseded by the child lines + if line.Status == billing.InvoiceLineStatusSplit { + return billing.Totals{} + } + return line.Totals })...) diff --git a/openmeter/billing/service/invoiceline.go b/openmeter/billing/service/invoiceline.go index 17abbf80e..17a5962f1 100644 --- a/openmeter/billing/service/invoiceline.go +++ b/openmeter/billing/service/invoiceline.go @@ -383,3 +383,13 @@ func (s *Service) DeleteInvoiceLine(ctx context.Context, input billing.DeleteInv return err }) } + +func (s *Service) GetLinesForSubscription(ctx context.Context, input billing.GetLinesForSubscriptionInput) ([]*billing.Line, error) { + if err := input.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return s.adapter.GetLinesForSubscription(ctx, input) +} diff --git a/openmeter/billing/service/lineservice/feeline.go b/openmeter/billing/service/lineservice/feeline.go index 59d1a9413..ce1545bc0 100644 --- a/openmeter/billing/service/lineservice/feeline.go +++ b/openmeter/billing/service/lineservice/feeline.go @@ -21,6 +21,8 @@ func (l feeLine) PrepareForCreate(context.Context) (Line, error) { } func (l feeLine) CanBeInvoicedAsOf(_ context.Context, t time.Time) (*billing.Period, error) { + // TODO[later]: Prorate can be implemented here for progressive billing/pro-rating of the fee + if !t.Before(l.line.InvoiceAt) { return &l.line.Period, nil } diff --git a/openmeter/billing/service/lineservice/linebase.go b/openmeter/billing/service/lineservice/linebase.go index 30023bd64..3ec728e98 100644 --- a/openmeter/billing/service/lineservice/linebase.go +++ b/openmeter/billing/service/lineservice/linebase.go @@ -21,6 +21,8 @@ type UpdateInput struct { // PreventChildChanges is used to prevent any child changes to the line by the adapter. PreventChildChanges bool + + ResetChildUniqueReferenceID bool } func (i UpdateInput) apply(line *billing.Line) { @@ -43,6 +45,10 @@ func (i UpdateInput) apply(line *billing.Line) { if i.PreventChildChanges { line.Children = billing.LineChildren{} } + + if i.ResetChildUniqueReferenceID { + line.ChildUniqueReferenceID = nil + } } type SplitResult struct { @@ -215,16 +221,18 @@ func (l lineBase) Split(ctx context.Context, splitAt time.Time) (SplitResult, er // Let's create the child lines preSplitAtLine := l.CloneForCreate(UpdateInput{ - ParentLine: mo.Some(parentLine.ToEntity()), - Status: billing.InvoiceLineStatusValid, - PeriodEnd: splitAt, - InvoiceAt: splitAt, + ParentLine: mo.Some(parentLine.ToEntity()), + Status: billing.InvoiceLineStatusValid, + PeriodEnd: splitAt, + InvoiceAt: splitAt, + ResetChildUniqueReferenceID: true, }) postSplitAtLine := l.CloneForCreate(UpdateInput{ - ParentLine: mo.Some(parentLine.ToEntity()), - Status: billing.InvoiceLineStatusValid, - PeriodStart: splitAt, + ParentLine: mo.Some(parentLine.ToEntity()), + Status: billing.InvoiceLineStatusValid, + PeriodStart: splitAt, + ResetChildUniqueReferenceID: true, }) splitLines, err := l.service.UpsertLines(ctx, l.line.Namespace, preSplitAtLine, postSplitAtLine) @@ -240,17 +248,19 @@ func (l lineBase) Split(ctx context.Context, splitAt time.Time) (SplitResult, er // We have alredy split the line once, we just need to create a new line and update the existing line postSplitAtLine, err := l.CloneForCreate(UpdateInput{ - Status: billing.InvoiceLineStatusValid, - PeriodStart: splitAt, - ParentLine: mo.Some(l.line.ParentLine), + Status: billing.InvoiceLineStatusValid, + PeriodStart: splitAt, + ParentLine: mo.Some(l.line.ParentLine), + ResetChildUniqueReferenceID: true, }).Save(ctx) if err != nil { return SplitResult{}, fmt.Errorf("creating split lines: %w", err) } preSplitAtLine, err := l.Update(UpdateInput{ - PeriodEnd: splitAt, - InvoiceAt: splitAt, + PeriodEnd: splitAt, + InvoiceAt: splitAt, + ResetChildUniqueReferenceID: true, }).Save(ctx) if err != nil { return SplitResult{}, fmt.Errorf("updating parent line: %w", err) diff --git a/openmeter/billing/worker/subscription/phaseiterator.go b/openmeter/billing/worker/subscription/phaseiterator.go new file mode 100644 index 000000000..0e36ae412 --- /dev/null +++ b/openmeter/billing/worker/subscription/phaseiterator.go @@ -0,0 +1,302 @@ +package billingworkersubscription + +import ( + "fmt" + "slices" + "strings" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timex" +) + +// timeInfinity is a big enough time that we can use to represent infinity (biggest possible date for our system) +var timeInfinity = time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC) + +type PhaseIterator struct { + // subscriptionID is the ID of the subscription that is being iterated (used for unique ID generation) + subscriptionID string + // phaseKey is the key of the phase that is being iterated (used for unique ID generation) + phaseKey string + // phaseID is the database ID of the phase that is being iterated (used for DB references) + phaseID string + // phaseCadence is the cadence of the phase that is being iterated + phaseCadence models.CadencedModel + + items [][]subscription.SubscriptionItemView +} + +type subscriptionItemWithPeriod struct { + subscription.SubscriptionItemView + Period billing.Period + UniqueID string + NonTruncatedPeriod billing.Period + PhaseID string +} + +func (r subscriptionItemWithPeriod) IsTruncated() bool { + return !r.Period.Equal(r.NonTruncatedPeriod) +} + +// PeriodPercentage returns the percentage of the period that is actually billed, compared to the non-truncated period +// can be used to calculate prorated prices +func (r subscriptionItemWithPeriod) PeriodPercentage() alpacadecimal.Decimal { + return alpacadecimal.NewFromInt(int64(r.Period.Duration())).Div(alpacadecimal.NewFromInt(int64(r.NonTruncatedPeriod.Duration()))) +} + +func NewPhaseIterator(subs subscription.SubscriptionView, phaseKey string) (*PhaseIterator, error) { + it := &PhaseIterator{ + subscriptionID: subs.Subscription.ID, + phaseKey: phaseKey, + } + + return it, it.ResolvePhaseData(subs, phaseKey) +} + +func (it *PhaseIterator) ResolvePhaseData(subs subscription.SubscriptionView, phaseKey string) error { + phaseCadence := models.CadencedModel{} + var currentPhase *subscription.SubscriptionPhaseView + + slices.SortFunc(subs.Phases, func(i, j subscription.SubscriptionPhaseView) int { + return timex.Compare(i.SubscriptionPhase.ActiveFrom, j.SubscriptionPhase.ActiveFrom) + }) + + for i, phase := range subs.Phases { + if phase.SubscriptionPhase.Key == phaseKey { + phaseCadence.ActiveFrom = phase.SubscriptionPhase.ActiveFrom + + if i < len(subs.Phases)-1 { + phaseCadence.ActiveTo = lo.ToPtr(subs.Phases[i+1].SubscriptionPhase.ActiveFrom) + } + + currentPhase = &phase + + break + } + } + + if currentPhase == nil { + return fmt.Errorf("phase %s not found in subscription %s", phaseKey, subs.Subscription.ID) + } + + it.phaseCadence = phaseCadence + it.phaseID = currentPhase.SubscriptionPhase.ID + + it.items = make([][]subscription.SubscriptionItemView, 0, len(currentPhase.ItemsByKey)) + for _, items := range currentPhase.ItemsByKey { + slices.SortFunc(items, func(i, j subscription.SubscriptionItemView) int { + return timex.Compare(i.SubscriptionItem.ActiveFrom, j.SubscriptionItem.ActiveFrom) + }) + + it.items = append(it.items, items) + } + + return nil +} + +func (it *PhaseIterator) HasInvoicableItems() bool { + for _, itemsByKey := range it.items { + for _, item := range itemsByKey { + if item.Spec.RateCard.Price != nil { + return true + } + } + } + + return false +} + +func (it *PhaseIterator) PhaseEnd() *time.Time { + return it.phaseCadence.ActiveTo +} + +func (it *PhaseIterator) PhaseStart() time.Time { + return it.phaseCadence.ActiveFrom +} + +// GetMinimumBillableTime returns the minimum time that we can bill for the phase (e.g. the first time we would be +// yielding a line item) +func (it *PhaseIterator) GetMinimumBillableTime() time.Time { + minTime := timeInfinity + for _, itemsByKey := range it.items { + for _, item := range itemsByKey { + if item.Spec.RateCard.Price == nil { + continue + } + + if item.SubscriptionItem.RateCard.Price.Type() == productcatalog.FlatPriceType { + if item.SubscriptionItem.ActiveFrom.Before(minTime) { + minTime = item.SubscriptionItem.ActiveFrom + } + } else { + // Let's make sure that truncation won't filter out the item + period := billing.Period{ + Start: item.SubscriptionItem.ActiveFrom, + End: timeInfinity, + } + + if item.SubscriptionItem.ActiveTo != nil { + period.End = *item.SubscriptionItem.ActiveTo + } + + if it.phaseCadence.ActiveTo != nil && period.End.After(*it.phaseCadence.ActiveTo) { + period.End = *it.phaseCadence.ActiveTo + } + + period = period.Truncate(billing.DefaultMeterResolution) + if period.IsEmpty() { + continue + } + + if period.Start.Before(minTime) { + minTime = period.Start + } + } + } + } + + return minTime +} + +func (it *PhaseIterator) Generate(iterationEnd time.Time) ([]subscriptionItemWithPeriod, error) { + out := []subscriptionItemWithPeriod{} + for _, itemsByKey := range it.items { + slices.SortFunc(itemsByKey, func(i, j subscription.SubscriptionItemView) int { + return timex.Compare(i.SubscriptionItem.ActiveFrom, j.SubscriptionItem.ActiveFrom) + }) + + for versionID, item := range itemsByKey { + // Let's drop non-billable items + if item.Spec.RateCard.Price == nil { + continue + } + + if item.Spec.RateCard.BillingCadence == nil { + generatedItem, err := it.generateOneTimeItem(item, versionID) + if err != nil { + return nil, err + } + out = append(out, generatedItem) + continue + } + + start := item.SubscriptionItem.ActiveFrom + periodID := 0 + + for { + end, _ := item.Spec.RateCard.BillingCadence.AddTo(start) + + nonTruncatedPeriod := billing.Period{ + Start: start, + End: end, + } + + if item.SubscriptionItem.ActiveTo != nil && item.SubscriptionItem.ActiveTo.Before(end) { + end = *item.SubscriptionItem.ActiveTo + } + + if it.phaseCadence.ActiveTo != nil && end.After(*it.phaseCadence.ActiveTo) { + end = *it.phaseCadence.ActiveTo + } + + generatedItem := subscriptionItemWithPeriod{ + SubscriptionItemView: item, + Period: billing.Period{ + Start: start, + End: end, + }, + + UniqueID: strings.Join([]string{ + it.subscriptionID, + it.phaseKey, + item.Spec.ItemKey, + fmt.Sprintf("v[%d]", versionID), + fmt.Sprintf("period[%d]", periodID), + }, "/"), + + NonTruncatedPeriod: nonTruncatedPeriod, + PhaseID: it.phaseID, + } + + out = append(out, generatedItem) + + periodID++ + start = end + + // Either we have reached the end of the phase + if it.phaseCadence.ActiveTo != nil && !start.Before(*it.phaseCadence.ActiveTo) { + break + } + + // We have reached the end of the active range + if item.SubscriptionItem.ActiveTo != nil && !start.Before(*item.SubscriptionItem.ActiveTo) { + break + } + + // Or we have reached the iteration end + if !start.Before(iterationEnd) { + break + } + } + } + } + + return it.truncateItemsIfNeeded(out), nil +} + +func (it *PhaseIterator) truncateItemsIfNeeded(in []subscriptionItemWithPeriod) []subscriptionItemWithPeriod { + out := make([]subscriptionItemWithPeriod, 0, len(in)) + // We need to sanitize the output to compensate for the 1min resolution of meters + for _, item := range in { + // We only need to sanitize the items that are not flat priced, flat prices can be handled in any resolution + if item.Spec.RateCard.Price != nil && item.Spec.RateCard.Price.Type() == productcatalog.FlatPriceType { + out = append(out, item) + continue + } + + item.Period = item.Period.Truncate(billing.DefaultMeterResolution) + if item.Period.IsEmpty() { + continue + } + + item.NonTruncatedPeriod = item.NonTruncatedPeriod.Truncate(billing.DefaultMeterResolution) + + out = append(out, item) + } + + return out +} + +func (it *PhaseIterator) generateOneTimeItem(item subscription.SubscriptionItemView, versionID int) (subscriptionItemWithPeriod, error) { + end := lo.CoalesceOrEmpty(item.SubscriptionItem.ActiveTo, it.phaseCadence.ActiveTo) + if end == nil { + // TODO[later]: implement open ended gathering line items, as that's a valid use case to for example: + // Have a plan, that has an open ended billing item for flat fee, then the end user uses progressive billing + // to bill the end user if the usage gets above $1000. Non-gathering lines must have a period end. + return subscriptionItemWithPeriod{}, fmt.Errorf("cannot determine phase end for item %s", item.Spec.ItemKey) + } + + period := billing.Period{ + Start: item.SubscriptionItem.ActiveFrom, + End: *end, + } + + return subscriptionItemWithPeriod{ + SubscriptionItemView: item, + Period: period, + NonTruncatedPeriod: period, + UniqueID: strings.Join([]string{ + it.subscriptionID, + it.phaseKey, + item.Spec.ItemKey, + fmt.Sprintf("v[%d]", versionID), + }, "/"), + PhaseID: it.phaseID, + }, nil +} diff --git a/openmeter/billing/worker/subscription/phaseiterator_test.go b/openmeter/billing/worker/subscription/phaseiterator_test.go new file mode 100644 index 000000000..09f9ad393 --- /dev/null +++ b/openmeter/billing/worker/subscription/phaseiterator_test.go @@ -0,0 +1,487 @@ +package billingworkersubscription + +import ( + "testing" + "time" + + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/pkg/datex" + "github.com/openmeterio/openmeter/pkg/models" +) + +const NotSet = "" + +type PhaseIteratorTestSuite struct { + *require.Assertions + suite.Suite +} + +func TestPhaseIterator(t *testing.T) { + suite.Run(t, new(PhaseIteratorTestSuite)) +} + +func (s *PhaseIteratorTestSuite) SetupSuite() { + s.Assertions = require.New(s.T()) +} + +type expectedIterations struct { + Start time.Time + End time.Time + Key string + NonTruncatedEnd time.Time +} + +type subscriptionItemViewMock struct { + Key string + Cadence string + + ActiveFrom string + ActiveTo string + + Type productcatalog.PriceType +} + +const NoPriceType = productcatalog.PriceType("NoPrice") + +func (s *PhaseIteratorTestSuite) mustParseTime(t string) time.Time { + return lo.Must(time.Parse(time.RFC3339, t)) +} + +func (s *PhaseIteratorTestSuite) TestPhaseIterator() { + tcs := []struct { + name string + items []subscriptionItemViewMock + end time.Time + expected []expectedIterations + phaseEnd *time.Time + expectError bool + }{ + { + name: "empty", + items: []subscriptionItemViewMock{}, + end: s.mustParseTime("2021-01-01T00:00:00Z"), + expected: []expectedIterations{}, + }, + { + name: "sanity", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + }, + }, + }, + { + name: "sanity-non-billable-filtering", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + }, + { + Key: "item-key-no-price", + Type: NoPriceType, + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + }, + }, + }, + { + name: "sanity-phase-end", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-02T15:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + NonTruncatedEnd: s.mustParseTime("2021-01-03T00:00:00Z"), + }, + }, + phaseEnd: lo.ToPtr(s.mustParseTime("2021-01-02T15:00:00Z")), + }, + { + name: "different cadence", + items: []subscriptionItemViewMock{ + { + Key: "item-key-1d", + Cadence: "P1D", + }, + { + Key: "item-key-2d", + Cadence: "P2D", + }, + }, + end: s.mustParseTime("2021-01-04T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key-1d/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key-1d/v[0]/period[1]", + }, + { + Start: s.mustParseTime("2021-01-03T00:00:00Z"), + End: s.mustParseTime("2021-01-04T00:00:00Z"), + Key: "subID/phase-test/item-key-1d/v[0]/period[2]", + }, + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key-2d/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-03T00:00:00Z"), + End: s.mustParseTime("2021-01-05T00:00:00Z"), + Key: "subID/phase-test/item-key-2d/v[0]/period[1]", + }, + }, + }, + { + // Note: this happens on subscription updates, but the active to/from is always disjunct + name: "active-from-to-matching-period", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + ActiveTo: "2021-01-02T00:00:00Z", + }, + { + Key: "item-key", + Cadence: "P1D", + ActiveFrom: "2021-01-02T00:00:00Z", + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key/v[1]/period[0]", + }, + }, + }, + { + name: "active-from-to-missmatching-period", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + ActiveTo: "2021-01-02T20:00:00Z", + }, + { + Key: "item-key", + Cadence: "P1D", + ActiveFrom: "2021-01-02T20:00:00Z", + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-02T20:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + NonTruncatedEnd: s.mustParseTime("2021-01-03T00:00:00Z"), + }, + { + Start: s.mustParseTime("2021-01-02T20:00:00Z"), + End: s.mustParseTime("2021-01-03T20:00:00Z"), + Key: "subID/phase-test/item-key/v[1]/period[0]", + }, + }, + }, + { + name: "ubp-time truncated", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + ActiveTo: "2021-01-02T20:00:02Z", + Type: productcatalog.UnitPriceType, + }, + { + Key: "item-key", + Cadence: "P1D", + ActiveFrom: "2021-01-02T20:00:02Z", + ActiveTo: "2021-01-02T20:00:03Z", + Type: productcatalog.UnitPriceType, + }, + { + Key: "item-key", + Cadence: "P1D", + ActiveFrom: "2021-01-02T20:00:03Z", + ActiveTo: "2021-01-02T20:00:04Z", + Type: productcatalog.UnitPriceType, + }, + { + Key: "item-key", + Cadence: "P1D", + ActiveFrom: "2021-01-02T20:00:04Z", + Type: productcatalog.UnitPriceType, + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-02T20:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + NonTruncatedEnd: s.mustParseTime("2021-01-03T00:00:00Z"), + }, + { + Start: s.mustParseTime("2021-01-02T20:00:00Z"), + End: s.mustParseTime("2021-01-03T20:00:00Z"), + Key: "subID/phase-test/item-key/v[3]/period[0]", + }, + }, + }, + { + name: "flat-fee recurring", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Cadence: "P1D", + Type: productcatalog.FlatPriceType, + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-03T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + }, + }, + }, + { + name: "flat-fee one-time", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Type: productcatalog.FlatPriceType, + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + phaseEnd: lo.ToPtr(s.mustParseTime("2021-01-05T00:00:00Z")), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-05T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]", + }, + }, + }, + { + name: "flat-fee recurring, edited", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Type: productcatalog.FlatPriceType, + Cadence: "P1D", + ActiveTo: "2021-01-02T20:00:00Z", + }, + { + Key: "item-key", + Type: productcatalog.FlatPriceType, + Cadence: "P1D", + ActiveFrom: "2021-01-02T20:00:00Z", + }, + }, + end: s.mustParseTime("2021-01-03T00:00:00Z"), + expected: []expectedIterations{ + { + Start: s.mustParseTime("2021-01-01T00:00:00Z"), + End: s.mustParseTime("2021-01-02T00:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[0]", + }, + { + Start: s.mustParseTime("2021-01-02T00:00:00Z"), + End: s.mustParseTime("2021-01-02T20:00:00Z"), + Key: "subID/phase-test/item-key/v[0]/period[1]", + NonTruncatedEnd: s.mustParseTime("2021-01-03T00:00:00Z"), + }, + { + Start: s.mustParseTime("2021-01-02T20:00:00Z"), + End: s.mustParseTime("2021-01-03T20:00:00Z"), + Key: "subID/phase-test/item-key/v[1]/period[0]", + }, + }, + }, + { + name: "flat-fee one-time, no phase end", + items: []subscriptionItemViewMock{ + { + Key: "item-key", + Type: productcatalog.FlatPriceType, + }, + }, + expectError: true, + }, + } + + for _, tc := range tcs { + s.Run(tc.name, func() { + phase := subscription.SubscriptionPhaseView{ + SubscriptionPhase: subscription.SubscriptionPhase{ + ActiveFrom: lo.Must(time.Parse(time.RFC3339, "2021-01-01T00:00:00Z")), + Key: "phase-test", + }, + ItemsByKey: map[string][]subscription.SubscriptionItemView{}, + } + + for _, item := range tc.items { + sitem := subscription.SubscriptionItemView{} + + sitem.Spec.ItemKey = item.Key + switch item.Type { + case productcatalog.UnitPriceType: + sitem.Spec.RateCard.Price = productcatalog.NewPriceFrom(productcatalog.UnitPrice{}) + case productcatalog.FlatPriceType: + sitem.Spec.RateCard.Price = productcatalog.NewPriceFrom(productcatalog.FlatPrice{}) + case NoPriceType: + sitem.Spec.RateCard.Price = nil + default: + sitem.Spec.RateCard.Price = productcatalog.NewPriceFrom(productcatalog.UnitPrice{}) + } + + if item.Cadence != "" { + sitem.Spec.RateCard.BillingCadence = lo.ToPtr(datex.MustParse(s.T(), item.Cadence)) + } + + if item.ActiveFrom != "" { + sitem.SubscriptionItem.ActiveFrom = lo.Must(time.Parse(time.RFC3339, item.ActiveFrom)) + } + + if item.ActiveTo != "" { + sitem.SubscriptionItem.ActiveTo = lo.ToPtr(lo.Must(time.Parse(time.RFC3339, item.ActiveTo))) + } + + if sitem.SubscriptionItem.ActiveFrom.IsZero() { + sitem.SubscriptionItem.ActiveFrom = phase.SubscriptionPhase.ActiveFrom + } + + phase.ItemsByKey[sitem.Spec.ItemKey] = append(phase.ItemsByKey[sitem.Spec.ItemKey], sitem) + } + + subs := subscription.SubscriptionView{ + Subscription: subscription.Subscription{ + NamespacedID: models.NamespacedID{ + ID: "subID", + }, + }, + Phases: []subscription.SubscriptionPhaseView{phase}, + } + + if tc.phaseEnd != nil { + subs.Phases = append(subs.Phases, subscription.SubscriptionPhaseView{ + SubscriptionPhase: subscription.SubscriptionPhase{ + ActiveFrom: *tc.phaseEnd, + }, + }) + } + + it, err := NewPhaseIterator( + subs, + phase.SubscriptionPhase.Key, + ) + s.NoError(err) + + out, err := it.Generate(tc.end) + if tc.expectError { + s.Error(err) + return + } else { + s.NoError(err) + } + + outAsExpect := make([]expectedIterations, 0, len(out)) + for i, item := range out { + // For now we never truncate the start, so we can just codify this + s.Equal(item.Period.Start, item.NonTruncatedPeriod.Start) + + nonTruncatedEnd := time.Time{} + if !item.NonTruncatedPeriod.End.Equal(item.Period.End) { + nonTruncatedEnd = item.NonTruncatedPeriod.End + } + + outAsExpect = append(outAsExpect, expectedIterations{ + Start: item.Period.Start, + End: item.Period.End, + Key: item.UniqueID, + NonTruncatedEnd: nonTruncatedEnd, + }) + + s.T().Logf("out[%d]: [%s..%s] %s (non-truncated: %s)\n", i, item.Period.Start, item.Period.End, item.UniqueID, nonTruncatedEnd) + } + + for i, item := range tc.expected { + s.T().Logf("expected[%d]: [%s..%s] %s (non-truncated: %s)\n", i, item.Start, item.End, item.Key, item.NonTruncatedEnd) + } + + s.ElementsMatch(tc.expected, outAsExpect) + }) + } +} diff --git a/openmeter/billing/worker/subscription/scanario_test.go b/openmeter/billing/worker/subscription/scanario_test.go new file mode 100644 index 000000000..8377884a3 --- /dev/null +++ b/openmeter/billing/worker/subscription/scanario_test.go @@ -0,0 +1,732 @@ +package billingworkersubscription + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/invopop/gobl/currency" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/credit" + grantrepo "github.com/openmeterio/openmeter/openmeter/credit/adapter" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" + enttx "github.com/openmeterio/openmeter/openmeter/ent/tx" + "github.com/openmeterio/openmeter/openmeter/entitlement" + entitlementrepo "github.com/openmeterio/openmeter/openmeter/entitlement/adapter" + booleanentitlement "github.com/openmeterio/openmeter/openmeter/entitlement/boolean" + meteredentitlement "github.com/openmeterio/openmeter/openmeter/entitlement/metered" + staticentitlement "github.com/openmeterio/openmeter/openmeter/entitlement/static" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" + "github.com/openmeterio/openmeter/openmeter/productcatalog/plan" + planadapter "github.com/openmeterio/openmeter/openmeter/productcatalog/plan/adapter" + planservice "github.com/openmeterio/openmeter/openmeter/productcatalog/plan/service" + plansubscription "github.com/openmeterio/openmeter/openmeter/productcatalog/subscription" + productcatalogsubscription "github.com/openmeterio/openmeter/openmeter/productcatalog/subscription" + "github.com/openmeterio/openmeter/openmeter/subscription" + subscriptionentitlementadatapter "github.com/openmeterio/openmeter/openmeter/subscription/adapters/entitlement" + subscriptionrepo "github.com/openmeterio/openmeter/openmeter/subscription/repo" + subscriptionservice "github.com/openmeterio/openmeter/openmeter/subscription/service" + "github.com/openmeterio/openmeter/openmeter/watermill/eventbus" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/datex" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/pagination" + billingtest "github.com/openmeterio/openmeter/test/billing" +) + +type SubscriptionHandlerTestSuite struct { + billingtest.BaseSuite + + PlanService plan.Service + SubscriptionService subscription.Service + SubscrpiptionPlanAdapter plansubscription.Adapter + SubscriptionWorkflowService subscription.WorkflowService + + Handler *Handler +} + +func (s *SubscriptionHandlerTestSuite) SetupSuite() { + s.BaseSuite.SetupSuite() + + planAdapter, err := planadapter.New(planadapter.Config{ + Client: s.DBClient, + Logger: slog.Default(), + }) + s.NoError(err) + + planService, err := planservice.New(planservice.Config{ + Feature: s.FeatureService, + Adapter: planAdapter, + Logger: slog.Default(), + }) + s.NoError(err) + + s.PlanService = planService + + subsRepo := subscriptionrepo.NewSubscriptionRepo(s.DBClient) + subsItemRepo := subscriptionrepo.NewSubscriptionItemRepo(s.DBClient) + + s.SubscriptionService = subscriptionservice.New(subscriptionservice.ServiceConfig{ + SubscriptionRepo: subsRepo, + SubscriptionPhaseRepo: subscriptionrepo.NewSubscriptionPhaseRepo(s.DBClient), + SubscriptionItemRepo: subsItemRepo, + // connectors + CustomerService: s.CustomerService, + // adapters + EntitlementAdapter: subscriptionentitlementadatapter.NewSubscriptionEntitlementAdapter( + s.SetupEntitlements(), + subsItemRepo, + subsRepo, + ), + // framework + TransactionManager: subsRepo, + }) + + s.SubscrpiptionPlanAdapter = plansubscription.NewPlanSubscriptionAdapter(plansubscription.PlanSubscriptionAdapterConfig{ + PlanService: planService, + Logger: slog.Default(), + }) + + s.SubscriptionWorkflowService = subscriptionservice.NewWorkflowService(subscriptionservice.WorkflowServiceConfig{ + Service: s.SubscriptionService, + CustomerService: s.CustomerService, + TransactionManager: subsRepo, + }) + + handler, err := New(Config{ + BillingService: s.BillingService, + Logger: slog.Default(), + TxCreator: s.BillingAdapter, + }) + s.NoError(err) + + s.Handler = handler +} + +func (s *SubscriptionHandlerTestSuite) SetupEntitlements() entitlement.Connector { + // Init grants/credit + grantRepo := grantrepo.NewPostgresGrantRepo(s.DBClient) + balanceSnapshotRepo := grantrepo.NewPostgresBalanceSnapshotRepo(s.DBClient) + + // Init entitlements + entitlementRepo := entitlementrepo.NewPostgresEntitlementRepo(s.DBClient) + usageResetRepo := entitlementrepo.NewPostgresUsageResetRepo(s.DBClient) + + mockPublisher := eventbus.NewMock(s.T()) + + owner := meteredentitlement.NewEntitlementGrantOwnerAdapter( + s.FeatureRepo, + entitlementRepo, + usageResetRepo, + s.MeterRepo, + slog.Default(), + ) + + transactionManager := enttx.NewCreator(s.DBClient) + + creditConnector := credit.NewCreditConnector( + grantRepo, + balanceSnapshotRepo, + owner, + s.MockStreamingConnector, + slog.Default(), + time.Minute, + mockPublisher, + transactionManager, + ) + + meteredEntitlementConnector := meteredentitlement.NewMeteredEntitlementConnector( + s.MockStreamingConnector, + owner, + creditConnector, + creditConnector, + grantRepo, + entitlementRepo, + mockPublisher, + ) + + staticEntitlementConnector := staticentitlement.NewStaticEntitlementConnector() + booleanEntitlementConnector := booleanentitlement.NewBooleanEntitlementConnector() + + return entitlement.NewEntitlementConnector( + entitlementRepo, + s.FeatureService, + s.MeterRepo, + meteredEntitlementConnector, + staticEntitlementConnector, + booleanEntitlementConnector, + mockPublisher, + ) +} + +func TestSubscriptionHandlerScenarios(t *testing.T) { + suite.Run(t, new(SubscriptionHandlerTestSuite)) +} + +func (s *SubscriptionHandlerTestSuite) mustParseTime(t string) time.Time { + return lo.Must(time.Parse(time.RFC3339, t)) +} + +func (s *SubscriptionHandlerTestSuite) TestSubscriptionHappyPath() { + ctx := context.Background() + namespace := "test-subs-happy-path" + start := s.mustParseTime("2024-01-01T00:00:00Z") + clock.SetTime(start) + defer clock.ResetTime() + + _ = s.InstallSandboxApp(s.T(), namespace) + + minimalCreateProfileInput := billingtest.MinimalCreateProfileInputTemplate + minimalCreateProfileInput.Namespace = namespace + + profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) + s.NoError(err) + s.NotNil(profile) + + apiRequestsTotalMeterSlug := "api-requests-total" + + s.MeterRepo.ReplaceMeters(ctx, []models.Meter{ + { + Namespace: namespace, + Slug: apiRequestsTotalMeterSlug, + WindowSize: models.WindowSizeMinute, + Aggregation: models.MeterAggregationSum, + }, + }) + defer s.MeterRepo.ReplaceMeters(ctx, []models.Meter{}) + + apiRequestsTotalFeatureKey := "api-requests-total" + + apiRequestsTotalFeature, err := s.FeatureService.CreateFeature(ctx, feature.CreateFeatureInputs{ + Namespace: namespace, + Name: "api-requests-total", + Key: apiRequestsTotalFeatureKey, + MeterSlug: lo.ToPtr("api-requests-total"), + }) + s.NoError(err) + + customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ + Namespace: namespace, + + CustomerMutate: customerentity.CustomerMutate{ + Name: "Test Customer", + PrimaryEmail: lo.ToPtr("test@test.com"), + BillingAddress: &models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + Currency: lo.ToPtr(currencyx.Code(currency.USD)), + UsageAttribution: customerentity.CustomerUsageAttribution{ + SubjectKeys: []string{"test"}, + }, + }, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), customerEntity) + require.NotEmpty(s.T(), customerEntity.ID) + + plan, err := s.PlanService.CreatePlan(ctx, plan.CreatePlanInput{ + NamespacedModel: models.NamespacedModel{ + Namespace: namespace, + }, + Plan: productcatalog.Plan{ + PlanMeta: productcatalog.PlanMeta{ + Name: "Test Plan", + Key: "test-plan", + Version: 1, + Currency: currency.USD, + }, + + Phases: []productcatalog.Phase{ + { + PhaseMeta: productcatalog.PhaseMeta{ + Name: "free trial", + Key: "free-trial", + StartAfter: datex.MustParse(s.T(), "P0D"), + }, + // TODO[OM-1031]: let's add discount handling (as this could be a 100% discount for the first month) + RateCards: productcatalog.RateCards{ + &productcatalog.UsageBasedRateCard{ + RateCardMeta: productcatalog.RateCardMeta{ + Key: apiRequestsTotalFeatureKey, + Name: apiRequestsTotalFeatureKey, + Feature: &apiRequestsTotalFeature, + }, + BillingCadence: datex.MustParse(s.T(), "P1M"), + }, + }, + }, + { + PhaseMeta: productcatalog.PhaseMeta{ + Name: "discounted phase", + Key: "discounted-phase", + StartAfter: datex.MustParse(s.T(), "P1M"), + }, + // TODO[OM-1031]: 50% discount + RateCards: productcatalog.RateCards{ + &productcatalog.UsageBasedRateCard{ + RateCardMeta: productcatalog.RateCardMeta{ + Key: apiRequestsTotalFeatureKey, + Name: apiRequestsTotalFeatureKey, + Feature: &apiRequestsTotalFeature, + Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(5), + }), + }, + BillingCadence: datex.MustParse(s.T(), "P1M"), + }, + }, + }, + { + PhaseMeta: productcatalog.PhaseMeta{ + Name: "final phase", + Key: "final-phase", + StartAfter: datex.MustParse(s.T(), "P3M"), + }, + RateCards: productcatalog.RateCards{ + &productcatalog.UsageBasedRateCard{ + RateCardMeta: productcatalog.RateCardMeta{ + Key: apiRequestsTotalFeatureKey, + Name: apiRequestsTotalFeatureKey, + Feature: &apiRequestsTotalFeature, + Price: productcatalog.NewPriceFrom(productcatalog.UnitPrice{ + Amount: alpacadecimal.NewFromFloat(10), + }), + }, + BillingCadence: datex.MustParse(s.T(), "P1M"), + }, + }, + }, + }, + }, + }) + + s.NoError(err) + s.NotNil(plan) + + subscriptionPlan, err := s.SubscrpiptionPlanAdapter.GetVersion(ctx, namespace, productcatalogsubscription.PlanRefInput{ + Key: plan.Key, + Version: lo.ToPtr(1), + }) + s.NoError(err) + + subsView, err := s.SubscriptionWorkflowService.CreateFromPlan(ctx, subscription.CreateSubscriptionWorkflowInput{ + Namespace: namespace, + ActiveFrom: start, + CustomerID: customerEntity.ID, + Name: "subs-1", + }, subscriptionPlan) + + s.NoError(err) + s.NotNil(subsView) + + freeTierPhase := getPhraseByKey(s.T(), subsView, "free-trial") + s.Equal(lo.ToPtr(datex.MustParse(s.T(), "P1M")), freeTierPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].Spec.RateCard.BillingCadence) + + discountedPhase := getPhraseByKey(s.T(), subsView, "discounted-phase") + var gatheringInvoiceID billing.InvoiceID + + // let's provision the first set of items + s.Run("provision first set of items", func() { + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + // then there should be a gathering invoice + invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Page: pagination.Page{ + PageSize: 10, + PageNumber: 1, + }, + Expand: billing.InvoiceExpandAll, + }) + s.NoError(err) + s.Len(invoices.Items, 1) + + invoice := invoices.Items[0] + s.Equal(billing.InvoiceStatusGathering, invoice.Status) + s.Len(invoice.Lines.OrEmpty(), 1) + + line := invoice.Lines.OrEmpty()[0] + s.Equal(line.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(line.Subscription.PhaseID, discountedPhase.SubscriptionPhase.ID) + s.Equal(line.Subscription.ItemID, discountedPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].SubscriptionItem.ID) + // 1 month free tier + in arrears billing with 1 month cadence + s.Equal(line.InvoiceAt, s.mustParseTime("2024-03-01T00:00:00Z")) + + // When we advance the clock the invoice doesn't get changed + clock.SetTime(s.mustParseTime("2024-02-01T00:00:00Z")) + + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: invoice.InvoiceID(), + Expand: billing.InvoiceExpandAll, + }) + s.NoError(err) + gatheringInvoiceID = gatheringInvoice.InvoiceID() + + gatheringLine := gatheringInvoice.Lines.OrEmpty()[0] + + // TODO[OM-1039]: the invoice's updated at gets updated even if the invoice is not changed + s.Equal(billing.InvoiceStatusGathering, gatheringInvoice.Status) + s.Equal(line.UpdatedAt, gatheringLine.UpdatedAt) + }) + + s.NoError(gatheringInvoiceID.Validate()) + + // Progressive billing updates + s.Run("progressive billing updates", func() { + s.MockStreamingConnector.AddSimpleEvent( + apiRequestsTotalMeterSlug, + 100, + s.mustParseTime("2024-02-02T00:00:00Z")) + clock.SetTime(s.mustParseTime("2024-02-15T00:00:00Z")) + + // we invoice the customer + invoices, err := s.BillingService.InvoicePendingLines(ctx, billing.InvoicePendingLinesInput{ + Customer: customerentity.CustomerID{ + ID: customerEntity.ID, + Namespace: namespace, + }, + }) + s.NoError(err) + s.Len(invoices, 1) + invoice := invoices[0] + + s.Equal(billing.InvoiceStatusDraftWaitingAutoApproval, invoice.Status) + s.Equal(float64(5*100), invoice.Totals.Total.InexactFloat64()) + + s.Len(invoice.Lines.OrEmpty(), 1) + line := invoice.Lines.OrEmpty()[0] + s.Equal(line.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(line.Subscription.PhaseID, discountedPhase.SubscriptionPhase.ID) + s.Equal(line.Subscription.ItemID, discountedPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].SubscriptionItem.ID) + s.Equal(line.InvoiceAt, s.mustParseTime("2024-02-15T00:00:00Z")) + s.Equal(line.Period, billing.Period{ + Start: s.mustParseTime("2024-02-01T00:00:00Z"), + End: s.mustParseTime("2024-02-15T00:00:00Z"), + }) + + // let's fetch the gathering invoice + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.InvoiceExpandAll, + }) + s.NoError(err) + + s.Len(gatheringInvoice.Lines.OrEmpty(), 1) + gatheringLine := gatheringInvoice.Lines.OrEmpty()[0] + s.Equal(gatheringLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(gatheringLine.Subscription.PhaseID, discountedPhase.SubscriptionPhase.ID) + s.Equal(gatheringLine.Subscription.ItemID, discountedPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].SubscriptionItem.ID) + s.Equal(gatheringLine.InvoiceAt, s.mustParseTime("2024-03-01T00:00:00Z")) + s.Equal(gatheringLine.Period, billing.Period{ + Start: s.mustParseTime("2024-02-15T00:00:00Z"), + End: s.mustParseTime("2024-03-01T00:00:00Z"), + }) + + // TODO[OM-1037]: let's add/change some items of the subscription then expect that the new item appears on the gathering + // invoice, but the draft invoice is untouched. + }) + + s.Run("subscription cancellation", func() { + clock.SetTime(s.mustParseTime("2024-02-20T00:00:00Z")) + + cancelAt := s.mustParseTime("2024-02-22T00:00:00Z") + subs, err := s.SubscriptionService.Cancel(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subsView.Subscription.ID, + }, cancelAt) + s.NoError(err) + + subsView, err = s.SubscriptionService.GetView(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subs.ID, + }) + s.NoError(err) + + // Subscription has set the cancellation date, and the view's subscription items are updated to have the cadence + // set properly up to the cancellation date. + + // If we are now resyncing the subscription, the gathering invoice should be updated to reflect the new cadence. + + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.InvoiceExpandAll.SetSplitLines(true), + }) + s.NoError(err) + + s.Len(gatheringInvoice.Lines.OrEmpty(), 2) + gatheringLinesByType := lo.GroupBy(gatheringInvoice.Lines.OrEmpty(), func(line *billing.Line) billing.InvoiceLineStatus { + return line.Status + }) + + s.Len(gatheringLinesByType[billing.InvoiceLineStatusValid], 1) + gatheringLine := gatheringLinesByType[billing.InvoiceLineStatusValid][0] + + s.Equal(gatheringLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(gatheringLine.Subscription.PhaseID, discountedPhase.SubscriptionPhase.ID) + s.Equal(gatheringLine.Subscription.ItemID, discountedPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].SubscriptionItem.ID) + + s.Equal(gatheringLine.Period, billing.Period{ + Start: s.mustParseTime("2024-02-15T00:00:00Z"), + End: cancelAt, + }) + s.Equal(gatheringLine.InvoiceAt, cancelAt) + + // split line + s.Len(gatheringLinesByType[billing.InvoiceLineStatusSplit], 1) + splitLine := gatheringLinesByType[billing.InvoiceLineStatusSplit][0] + + s.Equal(splitLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(splitLine.Period, billing.Period{ + Start: s.mustParseTime("2024-02-01T00:00:00Z"), + End: s.mustParseTime("2024-02-22T00:00:00Z"), + }) + }) + + s.Run("continue subscription", func() { + clock.SetTime(s.mustParseTime("2024-02-21T00:00:00Z")) + + subs, err := s.SubscriptionService.Continue(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subsView.Subscription.ID, + }) + s.NoError(err) + + subsView, err = s.SubscriptionService.GetView(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subs.ID, + }) + s.NoError(err) + + // If we are now resyncing the subscription, the gathering invoice should be updated to reflect the original cadence + + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: gatheringInvoiceID, + Expand: billing.InvoiceExpandAll.SetSplitLines(true), + }) + s.NoError(err) + + s.Len(gatheringInvoice.Lines.OrEmpty(), 2) + gatheringLinesByType := lo.GroupBy(gatheringInvoice.Lines.OrEmpty(), func(line *billing.Line) billing.InvoiceLineStatus { + return line.Status + }) + + s.Len(gatheringLinesByType[billing.InvoiceLineStatusValid], 1) + gatheringLine := gatheringLinesByType[billing.InvoiceLineStatusValid][0] + + s.Equal(gatheringLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(gatheringLine.Subscription.PhaseID, discountedPhase.SubscriptionPhase.ID) + s.Equal(gatheringLine.Subscription.ItemID, discountedPhase.ItemsByKey[apiRequestsTotalFeatureKey][0].SubscriptionItem.ID) + + s.Equal(gatheringLine.Period, billing.Period{ + Start: s.mustParseTime("2024-02-15T00:00:00Z"), + End: s.mustParseTime("2024-03-01T00:00:00Z"), + }) + s.Equal(gatheringLine.InvoiceAt, s.mustParseTime("2024-03-01T00:00:00Z")) + + // split line + s.Len(gatheringLinesByType[billing.InvoiceLineStatusSplit], 1) + splitLine := gatheringLinesByType[billing.InvoiceLineStatusSplit][0] + + s.Equal(splitLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(splitLine.Period, billing.Period{ + Start: s.mustParseTime("2024-02-01T00:00:00Z"), + End: s.mustParseTime("2024-03-01T00:00:00Z"), + }) + }) +} + +func (s *SubscriptionHandlerTestSuite) TestInArrearsProrating() { + ctx := context.Background() + namespace := "test-subs-pro-rating" + start := s.mustParseTime("2024-01-01T00:00:00Z") + clock.SetTime(start) + defer clock.ResetTime() + + _ = s.InstallSandboxApp(s.T(), namespace) + + minimalCreateProfileInput := billingtest.MinimalCreateProfileInputTemplate + minimalCreateProfileInput.Namespace = namespace + + profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) + s.NoError(err) + s.NotNil(profile) + + customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ + Namespace: namespace, + + CustomerMutate: customerentity.CustomerMutate{ + Name: "Test Customer", + PrimaryEmail: lo.ToPtr("test@test.com"), + BillingAddress: &models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + Currency: lo.ToPtr(currencyx.Code(currency.USD)), + UsageAttribution: customerentity.CustomerUsageAttribution{ + SubjectKeys: []string{"test"}, + }, + }, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), customerEntity) + require.NotEmpty(s.T(), customerEntity.ID) + + plan, err := s.PlanService.CreatePlan(ctx, plan.CreatePlanInput{ + NamespacedModel: models.NamespacedModel{ + Namespace: namespace, + }, + Plan: productcatalog.Plan{ + PlanMeta: productcatalog.PlanMeta{ + Name: "Test Plan", + Key: "test-plan", + Version: 1, + Currency: currency.USD, + }, + + Phases: []productcatalog.Phase{ + { + PhaseMeta: productcatalog.PhaseMeta{ + Name: "first-phase", + Key: "first-phase", + StartAfter: datex.MustParse(s.T(), "P0D"), + }, + RateCards: productcatalog.RateCards{ + &productcatalog.UsageBasedRateCard{ + RateCardMeta: productcatalog.RateCardMeta{ + Key: "in-arrears", + Name: "in-arrears", + Price: productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: alpacadecimal.NewFromFloat(5), + PaymentTerm: productcatalog.InArrearsPaymentTerm, + }), + }, + BillingCadence: datex.MustParse(s.T(), "P1D"), + }, + }, + }, + }, + }, + }) + + s.NoError(err) + s.NotNil(plan) + + subscriptionPlan, err := s.SubscrpiptionPlanAdapter.GetVersion(ctx, namespace, productcatalogsubscription.PlanRefInput{ + Key: plan.Key, + Version: lo.ToPtr(1), + }) + s.NoError(err) + + subsView, err := s.SubscriptionWorkflowService.CreateFromPlan(ctx, subscription.CreateSubscriptionWorkflowInput{ + Namespace: namespace, + ActiveFrom: start, + CustomerID: customerEntity.ID, + Name: "subs-1", + }, subscriptionPlan) + + s.NoError(err) + s.NotNil(subsView) + + // let's provision the first set of items + s.Run("provision first set of items", func() { + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + // then there should be a gathering invoice + invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Page: pagination.Page{ + PageSize: 10, + PageNumber: 1, + }, + Expand: billing.InvoiceExpandAll, + }) + s.NoError(err) + s.Len(invoices.Items, 1) + + lines := invoices.Items[0].Lines.OrEmpty() + s.Len(lines, 1) + + flatFeeLine := lines[0] + s.Equal(flatFeeLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(flatFeeLine.Subscription.PhaseID, subsView.Phases[0].SubscriptionPhase.ID) + s.Equal(flatFeeLine.Subscription.ItemID, subsView.Phases[0].ItemsByKey["in-arrears"][0].SubscriptionItem.ID) + s.Equal(flatFeeLine.InvoiceAt, s.mustParseTime("2024-01-02T00:00:00Z")) + s.Equal(flatFeeLine.Period, billing.Period{ + Start: s.mustParseTime("2024-01-01T00:00:00Z"), + End: s.mustParseTime("2024-01-02T00:00:00Z"), + }) + s.Equal(flatFeeLine.FlatFee.PerUnitAmount.InexactFloat64(), 5.0) + s.Equal(flatFeeLine.FlatFee.Quantity.InexactFloat64(), 1.0) + }) + + s.Run("canceling the subscription causes the existing item to be pro-rated", func() { + clock.SetTime(s.mustParseTime("2024-01-01T10:00:00Z")) + + cancelAt := s.mustParseTime("2024-01-01T12:00:00Z") + subs, err := s.SubscriptionService.Cancel(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subsView.Subscription.ID, + }, cancelAt) + s.NoError(err) + + subsView, err = s.SubscriptionService.GetView(ctx, models.NamespacedID{ + Namespace: namespace, + ID: subs.ID, + }) + s.NoError(err) + + s.NoError(s.Handler.SyncronizeSubscription(ctx, subsView, clock.Now())) + + // then there should be a gathering invoice + invoices, err := s.BillingService.ListInvoices(ctx, billing.ListInvoicesInput{ + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Page: pagination.Page{ + PageSize: 10, + PageNumber: 1, + }, + Expand: billing.InvoiceExpandAll, + }) + s.NoError(err) + s.Len(invoices.Items, 1) + + lines := invoices.Items[0].Lines.OrEmpty() + s.Len(lines, 1) + + flatFeeLine := lines[0] + s.Equal(flatFeeLine.Subscription.SubscriptionID, subsView.Subscription.ID) + s.Equal(flatFeeLine.InvoiceAt, cancelAt) + s.Equal(flatFeeLine.Period, billing.Period{ + Start: s.mustParseTime("2024-01-01T00:00:00Z"), + End: cancelAt, + }) + s.Equal(flatFeeLine.FlatFee.PerUnitAmount.InexactFloat64(), 2.5) + s.Equal(flatFeeLine.FlatFee.Quantity.InexactFloat64(), 1.0) + }) +} + +func getPhraseByKey(t *testing.T, subsView subscription.SubscriptionView, key string) subscription.SubscriptionPhaseView { + for _, phase := range subsView.Phases { + if phase.SubscriptionPhase.Key == key { + return phase + } + } + + t.Fatalf("phase with key %s not found", key) + return subscription.SubscriptionPhaseView{} +} diff --git a/openmeter/billing/worker/subscription/sync.go b/openmeter/billing/worker/subscription/sync.go new file mode 100644 index 000000000..bec0319e2 --- /dev/null +++ b/openmeter/billing/worker/subscription/sync.go @@ -0,0 +1,400 @@ +package billingworkersubscription + +import ( + "context" + "fmt" + "log/slog" + "slices" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/transaction" + "github.com/openmeterio/openmeter/pkg/slicesx" + "github.com/openmeterio/openmeter/pkg/timex" +) + +type Config struct { + BillingService billing.Service + TxCreator transaction.Creator + Logger *slog.Logger +} + +func (c Config) Validate() error { + if c.BillingService == nil { + return fmt.Errorf("billing service is required") + } + + if c.TxCreator == nil { + return fmt.Errorf("transaction creator is required") + } + + if c.Logger == nil { + return fmt.Errorf("logger is required") + } + + return nil +} + +type Handler struct { + billingService billing.Service + txCreator transaction.Creator + logger *slog.Logger +} + +func New(config Config) (*Handler, error) { + if err := config.Validate(); err != nil { + return nil, err + } + return &Handler{ + billingService: config.BillingService, + txCreator: config.TxCreator, + logger: config.Logger, + }, nil +} + +func (h *Handler) SyncronizeSubscription(ctx context.Context, subs subscription.SubscriptionView, asOf time.Time) error { + // TODO[later]: Right now we are getting the billing profile as a validation step, but later if we allow more collection + // alignment settings, we should use the collection settings from here to determine the generation end (overriding asof). + _, err := h.billingService.GetProfileWithCustomerOverride( + ctx, + billing.GetProfileWithCustomerOverrideInput{ + Namespace: subs.Subscription.Namespace, + CustomerID: subs.Subscription.CustomerId, + }, + ) + if err != nil { + return fmt.Errorf("getting billing profile: %w", err) + } + + // Let's see what's in scope for the subscription + slices.SortFunc(subs.Phases, func(i, j subscription.SubscriptionPhaseView) int { + return timex.Compare(i.SubscriptionPhase.ActiveFrom, j.SubscriptionPhase.ActiveFrom) + }) + + inScopeLines, err := h.collectUpcomingLines(subs, asOf) + if err != nil { + return fmt.Errorf("collecting upcoming lines: %w", err) + } + + if len(inScopeLines) == 0 { + // The subscription has no invoicable items, so we can return early + return nil + } + + inScopeLinesByUniqueID, unique := slicesx.UniqueGroupBy(inScopeLines, func(i subscriptionItemWithPeriod) string { + return i.UniqueID + }) + if !unique { + return fmt.Errorf("duplicate unique ids in the upcoming lines") + } + + // Let's load the existing lines for the subscription + existingLines, err := h.billingService.GetLinesForSubscription(ctx, billing.GetLinesForSubscriptionInput{ + Namespace: subs.Subscription.Namespace, + SubscriptionID: subs.Subscription.ID, + }) + if err != nil { + return fmt.Errorf("getting existing lines: %w", err) + } + + existingLinesByUniqueID, unique := slicesx.UniqueGroupBy( + lo.Filter(existingLines, func(l *billing.Line, _ int) bool { + return l.ChildUniqueReferenceID != nil + }), + func(l *billing.Line) string { + return *l.ChildUniqueReferenceID + }) + if !unique { + return fmt.Errorf("duplicate unique ids in the existing lines") + } + + existingLineUniqueIDs := lo.Keys(existingLinesByUniqueID) + inScopeLineUniqueIDs := lo.Keys(inScopeLinesByUniqueID) + // Let's execute the synchronization + deletedLines, newLines := lo.Difference(existingLineUniqueIDs, inScopeLineUniqueIDs) + linesToUpsert := lo.Intersect(existingLineUniqueIDs, inScopeLineUniqueIDs) + + currency, err := subs.Spec.Currency.Calculator() + if err != nil { + return fmt.Errorf("getting currency calculator: %w", err) + } + + return transaction.RunWithNoValue(ctx, h.txCreator, func(ctx context.Context) error { + // Let's stage new lines + newLines, err := slicesx.MapWithErr(newLines, func(id string) (billing.LineWithCustomer, error) { + line, err := h.lineFromSubscritionRateCard(subs, inScopeLinesByUniqueID[id], currency) + if err != nil { + return billing.LineWithCustomer{}, fmt.Errorf("generating line[%s]: %w", id, err) + } + + return billing.LineWithCustomer{ + Line: *line, + CustomerID: subs.Subscription.CustomerId, + }, nil + }) + if err != nil { + return fmt.Errorf("creating new lines: %w", err) + } + + _, err = h.billingService.CreatePendingInvoiceLines(ctx, billing.CreateInvoiceLinesInput{ + Namespace: subs.Subscription.Namespace, + Lines: newLines, + }) + if err != nil { + return fmt.Errorf("creating pending invoice lines: %w", err) + } + + // Let's flag deleted lines deleted + nowPtr := lo.ToPtr(clock.Now()) + for _, uniqueID := range deletedLines { + existingLinesByUniqueID[uniqueID].DeletedAt = nowPtr + } + + // Let's update the existing lines + for _, uniqueID := range linesToUpsert { + expectedLine, err := h.lineFromSubscritionRateCard(subs, inScopeLinesByUniqueID[uniqueID], currency) + if err != nil { + return fmt.Errorf("generating expected line[%s]: %w", uniqueID, err) + } + + if err := h.updateInScopeLine(existingLinesByUniqueID[uniqueID], expectedLine); err != nil { + return fmt.Errorf("updating line[%s]: %w", uniqueID, err) + } + } + + return h.billingService.UpdateInvoiceLinesInternal(ctx, billing.UpdateInvoiceLinesInternalInput{ + Namespace: subs.Subscription.Namespace, + CustomerID: subs.Subscription.CustomerId, + Lines: existingLines, + }) + }) +} + +// TODO[OM-1038]: manually deleted lines might come back to draft/gathering invoices (see ticket) + +// collectUpcomingLines collects the upcoming lines for the subscription, if it does not return any lines the subscription doesn't +// have any invoicable items. +// +// AsOf is a guideline for the end of generation, but the actual end of generation can be different based on the collection (as we +// always yield at least one line if an invoicable line exists). +// +// This approach allows us to not to have to poll all the subscriptions periodically, but we can act when an invoice is created or when +// a subscription is updated. +func (h *Handler) collectUpcomingLines(subs subscription.SubscriptionView, asOf time.Time) ([]subscriptionItemWithPeriod, error) { + inScopeLines := make([]subscriptionItemWithPeriod, 0, 128) + + for _, phase := range subs.Phases { + iterator, err := NewPhaseIterator(subs, phase.SubscriptionPhase.Key) + if err != nil { + return nil, fmt.Errorf("creating phase iterator: %w", err) + } + + if !iterator.HasInvoicableItems() { + continue + } + + generationLimit := asOf + if phaseStart := iterator.PhaseStart(); phaseStart.After(asOf) { + // We need to have invoicable items, so we need to advance the limit here at least to phaseStart to see + // if we can have any invoicable items. + + generationLimit = iterator.GetMinimumBillableTime() + + if generationLimit.IsZero() { + // This should not happen, but if it does, we should skip this phase + continue + } + } + + items, err := iterator.Generate(generationLimit) + if err != nil { + return nil, fmt.Errorf("generating items: %w", err) + } + + inScopeLines = append(inScopeLines, items...) + + if phaseEnd := iterator.PhaseEnd(); phaseEnd != nil && !phaseEnd.Before(asOf) { + // we are done with the generation, as the phase end is after the asOf, and we have invoicable items + break + } + } + + return inScopeLines, nil +} + +func (h *Handler) lineFromSubscritionRateCard(subs subscription.SubscriptionView, item subscriptionItemWithPeriod, currency currencyx.Calculator) (*billing.Line, error) { + line := &billing.Line{ + LineBase: billing.LineBase{ + Namespace: subs.Subscription.Namespace, + Name: item.Spec.RateCard.Name, + Description: item.Spec.RateCard.Description, + Currency: subs.Spec.Currency, + Status: billing.InvoiceLineStatusValid, + ChildUniqueReferenceID: &item.UniqueID, + TaxConfig: item.Spec.RateCard.TaxConfig, + Period: item.Period, + + Subscription: &billing.SubscriptionReference{ + SubscriptionID: subs.Subscription.ID, + PhaseID: item.PhaseID, + ItemID: item.SubscriptionItem.ID, + }, + }, + } + + switch item.SubscriptionItem.RateCard.Price.Type() { + case productcatalog.FlatPriceType: + price, err := item.SubscriptionItem.RateCard.Price.AsFlat() + if err != nil { + return nil, fmt.Errorf("converting price to flat: %w", err) + } + + perUnitAmount := price.Amount + switch price.PaymentTerm { + case productcatalog.InArrearsPaymentTerm: + line.InvoiceAt = item.Period.End + // TODO[OM-1040]: We should support rounding errors in prorating calculations (such as 1/3 of a dollar is $0.33, 3*$0.33 is $0.99, if we bill + // $1.00 in three equal pieces we should charge the customer $0.01 as the last split) + perUnitAmount = currency.RoundToPrecision(price.Amount.Mul(item.PeriodPercentage())) + case productcatalog.InAdvancePaymentTerm: + // In case of inAdvance we should always invoice at the start of the period and if there's a change + // prorating should void the item and credit the customer. + // + // Warning: We are not supporting voiding or crediting right now, so we are going to overcharge on + // inAdvance items in case of a change. + line.InvoiceAt = item.Period.Start + default: + return nil, fmt.Errorf("unsupported payment term: %v", price.PaymentTerm) + } + + line.Type = billing.InvoiceLineTypeFee + line.FlatFee = billing.FlatFeeLine{ + PerUnitAmount: perUnitAmount, + Quantity: alpacadecimal.NewFromInt(1), + PaymentTerm: price.PaymentTerm, + Category: billing.FlatFeeCategoryRegular, + } + + case productcatalog.UnitPriceType, productcatalog.TieredPriceType: + // Should not happen, but let's be safe + if item.SubscriptionItem.RateCard.FeatureKey == nil { + return nil, fmt.Errorf("feature must be defined for usage based price") + } + + if item.SubscriptionItem.RateCard.Price == nil { + return nil, fmt.Errorf("price must be defined for usage based price") + } + + line.Type = billing.InvoiceLineTypeUsageBased + line.InvoiceAt = item.Period.End + line.UsageBased = billing.UsageBasedLine{ + Price: *item.SubscriptionItem.RateCard.Price, + FeatureKey: *item.SubscriptionItem.RateCard.FeatureKey, + } + + default: + return nil, fmt.Errorf("unsupported price type: %v", item.SubscriptionItem.RateCard.Price.Type()) + } + + return line, nil +} + +func (h *Handler) updateInScopeLine(existingLine *billing.Line, expectedLine *billing.Line) error { + // TODO/WARNING[later]: This logic should be fine with everything that can be billed progressively, however the following use-cases + // will behave strangely: + // + // - An in advance flat fee cannot be prorated, as that would require void/credit logic. + // - If a volume based item's tiered are changed, then the old volume based item will be billed at that rate, and the + // new volume based item's calculation will start from the new tiered rates beginning. (e.g. if we have a 1 month long tiered + // price, then we change the tiered price in the middle of the month, the old tiered price will be billed for the first half + // and the new tiered price will be billed for the second half, meaning that the customer will be billed for less sum traffic) [OM] + // - If a meter is unique counted, then the unique count will be reset at the change's time + + // This is a non-split line, so it's either assigned to a gathering invoice or an already paid invoice, we can just update the line + // and let billing service handle the rest + if existingLine.Status == billing.InvoiceLineStatusValid { + h.mergeChangesFromLine(existingLine, expectedLine) + + return nil + } + + // Parts of the line has been already invoiced using progressive invoicing, so we need to examine the children + if existingLine.Status == billing.InvoiceLineStatusSplit { + // Nothing to do here, as split lines are UBP lines and thus we don't need the flat fee corrections + // TODO[later]: When we implement progressive billing based pro-rating, we need to support adjusting flat fee + // segments here. + + if existingLine.Period.End.Before(expectedLine.Period.End) { + // Expansion of the line (e.g. continue subscription) + + children := existingLine.Children.OrEmpty() + if len(children) > 0 { + slices.SortFunc(children, func(i, j *billing.Line) int { + return timex.Compare(i.Period.End, j.Period.End) + }) + + lastChild := children[len(children)-1] + lastChild.Period.End = expectedLine.Period.End + lastChild.InvoiceAt = expectedLine.Period.End + } + + existingLine.Period.End = expectedLine.Period.End + existingLine.InvoiceAt = expectedLine.Period.End + } else { + // Shrink of the line (e.g. subscription cancled, subscription item edit) + + for _, child := range existingLine.Children.OrEmpty() { + if child.Period.End.Before(expectedLine.Period.End) { + // The child is not affected by the period shrink, so we can skip it + continue + } + + if child.Period.Start.After(expectedLine.Period.End) { + // The child is after the period shrink, so we need to delete it as it became invalid + child.DeletedAt = lo.ToPtr(clock.Now()) + continue + } + + // The child is partially affected by the period shrink, so we need to adjust the period + if !child.Period.End.Equal(expectedLine.Period.End) { + child.Period.End = expectedLine.Period.End + + if child.InvoiceAt.After(expectedLine.Period.End) { + // The child is invoiced after the period end, so we need to adjust the invoice date + child.InvoiceAt = expectedLine.Period.End + } + } + } + // Split lines are always associated with gathering invoices, so we can safely update the line without checking for + // snapshot update requirements + + existingLine.Period.End = expectedLine.Period.End + existingLine.InvoiceAt = expectedLine.Period.End + } + + return nil + } + + // There is no other state in which a line can be in, so we can safely return an error here + return fmt.Errorf("could not handle line update [lineID=%s, status=%s]", existingLine.ID, existingLine.Status) +} + +func (h *Handler) mergeChangesFromLine(existingLine *billing.Line, expectedLine *billing.Line) { + // We assume that only the period can change, maybe some pricing data due to prorating (for flat lines) + + existingLine.Period = expectedLine.Period + + existingLine.InvoiceAt = expectedLine.InvoiceAt + + // Let's handle the flat fee prorating + if existingLine.Type == billing.InvoiceLineTypeFee { + existingLine.FlatFee.PerUnitAmount = expectedLine.FlatFee.PerUnitAmount + } +} diff --git a/pkg/slicesx/groupby.go b/pkg/slicesx/groupby.go new file mode 100644 index 000000000..aa9a399af --- /dev/null +++ b/pkg/slicesx/groupby.go @@ -0,0 +1,18 @@ +package slicesx + +import "github.com/samber/lo" + +func UniqueGroupBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) U) (map[U]T, bool) { + res := lo.GroupBy(collection, iteratee) + out := make(map[U]T, len(res)) + + for k, v := range res { + if len(v) > 1 { + return nil, false + } + + out[k] = v[0] + } + + return out, true +} diff --git a/pkg/timex/compare.go b/pkg/timex/compare.go new file mode 100644 index 000000000..87ef28636 --- /dev/null +++ b/pkg/timex/compare.go @@ -0,0 +1,7 @@ +package timex + +import "time" + +func Compare(a, b time.Time) int { + return int(a.Sub(b)) +} diff --git a/test/billing/adapter_test.go b/test/billing/adapter_test.go index 980f7e99e..76fc8be7d 100644 --- a/test/billing/adapter_test.go +++ b/test/billing/adapter_test.go @@ -1,4 +1,4 @@ -package billing_test +package billing import ( "context" @@ -52,9 +52,9 @@ func (s *BillingAdapterTestSuite) setupInvoice(ctx context.Context, ns string) * require.NotEmpty(s.T(), customerEntity.ID) // Given we have a profile - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = ns profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) diff --git a/test/billing/customeroverride_test.go b/test/billing/customeroverride_test.go index 60b18663b..d452a305f 100644 --- a/test/billing/customeroverride_test.go +++ b/test/billing/customeroverride_test.go @@ -1,4 +1,4 @@ -package billing_test +package billing import ( "context" @@ -43,7 +43,7 @@ func (s *CustomerOverrideTestSuite) TestDefaultProfileHandling() { ns := "test-ns-default-profile-handling" ctx := context.Background() - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) // Given we have an existing customer customer, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ @@ -72,7 +72,7 @@ func (s *CustomerOverrideTestSuite) TestDefaultProfileHandling() { s.T().Run("customer with default profile, no override", func(t *testing.T) { // Given having a default profile - profileInput := minimalCreateProfileInputTemplate + profileInput := MinimalCreateProfileInputTemplate profileInput.Namespace = ns defaultProfile, err = s.BillingService.CreateProfile(ctx, profileInput) @@ -142,7 +142,7 @@ func (s *CustomerOverrideTestSuite) TestPinnedProfileHandling() { ns := "test-ns-pinned-profile-handling" ctx := context.Background() - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) // Given we have an existing customer customer, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ @@ -156,7 +156,7 @@ func (s *CustomerOverrideTestSuite) TestPinnedProfileHandling() { customerID := customer.ID // Given we have a non-default profile - profileInput := minimalCreateProfileInputTemplate + profileInput := MinimalCreateProfileInputTemplate profileInput.Namespace = ns profileInput.Default = false @@ -215,7 +215,7 @@ func (s *CustomerOverrideTestSuite) TestSanityOverrideOperations() { ns := "test-sanity-override-operations" ctx := context.Background() - s.installSandboxApp(s.T(), ns) + s.InstallSandboxApp(s.T(), ns) customer, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ Namespace: ns, @@ -237,7 +237,7 @@ func (s *CustomerOverrideTestSuite) TestSanityOverrideOperations() { require.ErrorAs(t, err, &billing.NotFoundError{}) }) - profileInput := minimalCreateProfileInputTemplate + profileInput := MinimalCreateProfileInputTemplate profileInput.Namespace = ns defaultProfile, err := s.BillingService.CreateProfile(ctx, profileInput) @@ -310,7 +310,7 @@ func (s *CustomerOverrideTestSuite) TestCustomerIntegration() { ns := "test-customer-integration" ctx := context.Background() - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) customer, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ Namespace: ns, @@ -328,7 +328,7 @@ func (s *CustomerOverrideTestSuite) TestCustomerIntegration() { require.NoError(s.T(), err) require.NotNil(s.T(), customer) - profileInput := minimalCreateProfileInputTemplate + profileInput := MinimalCreateProfileInputTemplate profileInput.Namespace = ns defaultProfile, err := s.BillingService.CreateProfile(ctx, profileInput) @@ -356,7 +356,7 @@ func (s *CustomerOverrideTestSuite) TestNullSetting() { ns := "test-null-setting" ctx := context.Background() - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) customer, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ Namespace: ns, @@ -369,7 +369,7 @@ func (s *CustomerOverrideTestSuite) TestNullSetting() { require.NoError(s.T(), err) require.NotNil(s.T(), customer) - profileInput := minimalCreateProfileInputTemplate + profileInput := MinimalCreateProfileInputTemplate profileInput.Namespace = ns defaultProfile, err := s.BillingService.CreateProfile(ctx, profileInput) diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index af7e2f72f..52bc16980 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -1,4 +1,4 @@ -package billing_test +package billing import ( "context" @@ -43,7 +43,7 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { periodStart := periodEnd.Add(-time.Hour * 24 * 30) issueAt := now.Add(-time.Minute) - _ = s.installSandboxApp(s.T(), namespace) + _ = s.InstallSandboxApp(s.T(), namespace) ctx := context.Background() @@ -96,7 +96,7 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { var billingProfile billing.Profile s.T().Run("create default profile", func(t *testing.T) { - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = namespace profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) @@ -429,7 +429,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { line1IssueAt := now.Add(-2 * time.Hour) line2IssueAt := now.Add(-time.Hour) - _ = s.installSandboxApp(s.T(), namespace) + _ = s.InstallSandboxApp(s.T(), namespace) ctx := context.Background() @@ -453,7 +453,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Given we have a default profile for the namespace - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = namespace profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) @@ -839,7 +839,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlow() { s.T().Run(tc.name, func(t *testing.T) { namespace := fmt.Sprintf("ns-invoicing-flow-happy-path-%d", i) - _ = s.installSandboxApp(s.T(), namespace) + _ = s.InstallSandboxApp(s.T(), namespace) // Given we have a test customer customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ @@ -859,7 +859,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlow() { require.NotEmpty(s.T(), customerEntity.ID) // Given we have a billing profile - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = namespace minimalCreateProfileInput.WorkflowConfig = tc.workflowConfig @@ -1177,7 +1177,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { s.T().Run(tc.name, func(t *testing.T) { namespace := fmt.Sprintf("ns-invoicing-flow-valid-%d", i) - _ = s.installSandboxApp(s.T(), namespace) + _ = s.InstallSandboxApp(s.T(), namespace) mockApp := s.SandboxApp.EnableMock(t) defer s.SandboxApp.DisableMock() @@ -1200,7 +1200,7 @@ func (s *InvoicingTestSuite) TestInvoicingFlowErrorHandling() { require.NotEmpty(s.T(), customerEntity.ID) // Given we have a billing profile - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = namespace minimalCreateProfileInput.WorkflowConfig = tc.workflowConfig @@ -1236,7 +1236,7 @@ func (s *InvoicingTestSuite) TestUBPInvoicing() { periodStart := lo.Must(time.Parse(time.RFC3339, "2024-09-02T12:13:14Z")) periodEnd := lo.Must(time.Parse(time.RFC3339, "2024-09-03T12:13:14Z")) - _ = s.installSandboxApp(s.T(), namespace) + _ = s.InstallSandboxApp(s.T(), namespace) s.MeterRepo.ReplaceMeters(ctx, []models.Meter{ { @@ -1332,7 +1332,7 @@ func (s *InvoicingTestSuite) TestUBPInvoicing() { require.NotEmpty(s.T(), customerEntity.ID) // Given we have a default profile for the namespace - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = namespace profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) diff --git a/test/billing/profile.go b/test/billing/profile.go new file mode 100644 index 000000000..26aa66607 --- /dev/null +++ b/test/billing/profile.go @@ -0,0 +1,49 @@ +package billing + +import ( + "github.com/samber/lo" + + appentitybase "github.com/openmeterio/openmeter/openmeter/app/entity/base" + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/pkg/datex" + "github.com/openmeterio/openmeter/pkg/models" +) + +var MinimalCreateProfileInputTemplate = billing.CreateProfileInput{ + Name: "Awesome Profile", + Default: true, + + WorkflowConfig: billing.WorkflowConfig{ + Collection: billing.CollectionConfig{ + Alignment: billing.AlignmentKindSubscription, + Interval: lo.Must(datex.ISOString("PT2H").Parse()), + }, + Invoicing: billing.InvoicingConfig{ + AutoAdvance: true, + DraftPeriod: lo.Must(datex.ISOString("P1D").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billing.PaymentConfig{ + CollectionMethod: billing.CollectionMethodChargeAutomatically, + }, + }, + + Supplier: billing.SupplierContact{ + Name: "Awesome Supplier", + Address: models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + }, + + Apps: billing.CreateProfileAppsInput{ + Invoicing: billing.AppReference{ + Type: appentitybase.AppTypeSandbox, + }, + Payment: billing.AppReference{ + Type: appentitybase.AppTypeSandbox, + }, + Tax: billing.AppReference{ + Type: appentitybase.AppTypeSandbox, + }, + }, +} diff --git a/test/billing/profile_test.go b/test/billing/profile_test.go index 3aa613d5c..a8af1e40b 100644 --- a/test/billing/profile_test.go +++ b/test/billing/profile_test.go @@ -1,4 +1,4 @@ -package billing_test +package billing import ( "context" @@ -15,45 +15,6 @@ import ( "github.com/openmeterio/openmeter/pkg/models" ) -var minimalCreateProfileInputTemplate = billing.CreateProfileInput{ - Name: "Awesome Profile", - Default: true, - - WorkflowConfig: billing.WorkflowConfig{ - Collection: billing.CollectionConfig{ - Alignment: billing.AlignmentKindSubscription, - Interval: lo.Must(datex.ISOString("PT2H").Parse()), - }, - Invoicing: billing.InvoicingConfig{ - AutoAdvance: true, - DraftPeriod: lo.Must(datex.ISOString("P1D").Parse()), - DueAfter: lo.Must(datex.ISOString("P1W").Parse()), - }, - Payment: billing.PaymentConfig{ - CollectionMethod: billing.CollectionMethodChargeAutomatically, - }, - }, - - Supplier: billing.SupplierContact{ - Name: "Awesome Supplier", - Address: models.Address{ - Country: lo.ToPtr(models.CountryCode("US")), - }, - }, - - Apps: billing.CreateProfileAppsInput{ - Invoicing: billing.AppReference{ - Type: appentitybase.AppTypeSandbox, - }, - Payment: billing.AppReference{ - Type: appentitybase.AppTypeSandbox, - }, - Tax: billing.AppReference{ - Type: appentitybase.AppTypeSandbox, - }, - }, -} - type ProfileTestSuite struct { BaseSuite } @@ -66,7 +27,7 @@ func (s *ProfileTestSuite) TestProfileLifecycle() { ctx := context.Background() ns := "test_create_billing_profile" - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) s.T().Run("missing default profile", func(t *testing.T) { defaultProfile, err := s.BillingService.GetDefaultProfile(ctx, billing.GetDefaultProfileInput{ @@ -79,7 +40,7 @@ func (s *ProfileTestSuite) TestProfileLifecycle() { var profile *billing.Profile var err error - minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput := MinimalCreateProfileInputTemplate minimalCreateProfileInput.Namespace = ns s.T().Run("create default profile", func(t *testing.T) { @@ -163,7 +124,7 @@ func (s *ProfileTestSuite) TestProfileFieldSetting() { t := s.T() ns := "test_profile_field_setting" - app := s.installSandboxApp(s.T(), ns) + app := s.InstallSandboxApp(s.T(), ns) input := billing.CreateProfileInput{ Namespace: ns, @@ -277,7 +238,7 @@ func (s *ProfileTestSuite) TestProfileUpdates() { ctx := context.Background() ns := "test_profile_updates" - _ = s.installSandboxApp(s.T(), ns) + _ = s.InstallSandboxApp(s.T(), ns) input := billing.CreateProfileInput{ Namespace: ns, @@ -285,7 +246,7 @@ func (s *ProfileTestSuite) TestProfileUpdates() { Name: "Awesome Default Profile", - Apps: minimalCreateProfileInputTemplate.Apps, + Apps: MinimalCreateProfileInputTemplate.Apps, WorkflowConfig: billing.WorkflowConfig{ Collection: billing.CollectionConfig{ diff --git a/test/billing/suite_test.go b/test/billing/suite.go similarity index 92% rename from test/billing/suite_test.go rename to test/billing/suite.go index 53d0bf8c1..64bc83db0 100644 --- a/test/billing/suite_test.go +++ b/test/billing/suite.go @@ -1,4 +1,4 @@ -package billing_test +package billing import ( "context" @@ -33,6 +33,7 @@ import ( type BaseSuite struct { suite.Suite + *require.Assertions TestDB *testutils.TestDB DBClient *db.Client @@ -42,6 +43,7 @@ type BaseSuite struct { InvoiceCalculator *invoicecalc.MockableInvoiceCalculator FeatureService feature.FeatureConnector + FeatureRepo feature.FeatureRepo MeterRepo *meter.InMemoryRepository MockStreamingConnector *streamingtestutils.MockStreamingConnector @@ -54,11 +56,13 @@ type BaseSuite struct { func (s *BaseSuite) SetupSuite() { t := s.T() t.Log("setup suite") + s.Assertions = require.New(t) s.TestDB = testutils.InitPostgresDB(t) // init db dbClient := db.NewClient(db.Driver(s.TestDB.EntDriver.Driver())) + s.DBClient = dbClient if os.Getenv("TEST_DISABLE_ATLAS") != "" { s.Require().NoError(dbClient.Schema.Create(context.Background())) @@ -74,9 +78,8 @@ func (s *BaseSuite) SetupSuite() { s.MockStreamingConnector = streamingtestutils.NewMockStreamingConnector(t) // Feature - featureRepo := featureadapter.NewPostgresFeatureRepo(dbClient, slog.Default()) - - s.FeatureService = feature.NewFeatureConnector(featureRepo, s.MeterRepo) + s.FeatureRepo = featureadapter.NewPostgresFeatureRepo(dbClient, slog.Default()) + s.FeatureService = feature.NewFeatureConnector(s.FeatureRepo, s.MeterRepo) // Customer @@ -138,7 +141,7 @@ func (s *BaseSuite) SetupSuite() { s.BillingService = billingService.WithInvoiceCalculator(s.InvoiceCalculator) } -func (s *BaseSuite) installSandboxApp(t *testing.T, ns string) appentity.App { +func (s *BaseSuite) InstallSandboxApp(t *testing.T, ns string) appentity.App { ctx := context.Background() _, err := s.AppService.CreateApp(ctx, appentity.CreateAppInput{