diff --git a/go.mod b/go.mod index e7f39ee88..f1d486c00 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.60.0 + github.com/qmuntal/stateless v1.7.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/redis/go-redis/extra/redisotel/v9 v9.5.3 github.com/redis/go-redis/v9 v9.7.0 diff --git a/go.sum b/go.sum index 2182e253b..3b8a263b6 100644 --- a/go.sum +++ b/go.sum @@ -1161,6 +1161,8 @@ github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0 h1:sadMIs github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0/go.mod h1:jgxiZysxFPM+iWKwQwPR+y+Jvo54ARd4EisXxKYpB5c= github.com/pusher/pusher-http-go v4.0.1+incompatible h1:4u6tomPG1WhHaST7Wi9mw83Y+MS/j2EplR2YmDh8Xp4= github.com/pusher/pusher-http-go v4.0.1+incompatible/go.mod h1:XAv1fxRmVTI++2xsfofDhg7whapsLRG/gH/DXbF3a18= +github.com/qmuntal/stateless v1.7.1 h1:dI+BtLHq/nD6u46POkOINTDjY9uE33/4auEzfX3TWp0= +github.com/qmuntal/stateless v1.7.1/go.mod h1:n1HjRBM/cq4uCr3rfUjaMkgeGcd+ykAZwkjLje6jGBM= github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc h1:hK577yxEJ2f5s8w2iy2KimZmgrdAUZUNftE1ESmg2/Q= github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc/go.mod h1:OQt6Zo5B3Zs+C49xul8kcHo+fZ1mCLPvd0LFxiZ2DHc= github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 173080a99..d9cce93a6 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -56,4 +56,5 @@ type InvoiceAdapter interface { DeleteInvoices(ctx context.Context, input DeleteInvoicesAdapterInput) error ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error) + UpdateInvoice(ctx context.Context, input UpdateInvoiceAdapterInput) error } diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index c629f5b79..3271b4ec5 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent/dialect/sql" "github.com/samber/lo" "github.com/openmeterio/openmeter/api" @@ -144,8 +145,18 @@ func (r *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn query = query.Where(billinginvoice.IssuedAtLTE(*input.IssuedBefore)) } + if len(input.ExtendedStatuses) > 0 { + query = query.Where(billinginvoice.StatusIn(input.ExtendedStatuses...)) + } + if len(input.Statuses) > 0 { - query = query.Where(billinginvoice.StatusIn(input.Statuses...)) + query = query.Where(func(s *sql.Selector) { + s.Where(sql.Or( + lo.Map(input.Statuses, func(status string, _ int) *sql.Predicate { + return sql.Like(billinginvoice.FieldStatus, status+"%") + })..., + )) + }) } if len(input.Currencies) > 0 { @@ -269,7 +280,7 @@ func (r *adapter) CreateInvoice(ctx context.Context, input billing.CreateInvoice // Let's add required edges for mapping newInvoice.Edges.BillingWorkflowConfig = clonedWorkflowConfig - return mapInvoiceFromDB(*newInvoice, billing.InvoiceExpandAll) + return mapInvoiceFromDB(*newInvoice, billingentity.InvoiceExpandAll) } type lineCountQueryOut struct { @@ -316,7 +327,122 @@ func (r *adapter) AssociatedLineCounts(ctx context.Context, input billing.Associ }, nil } -func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (billingentity.Invoice, error) { +func (r *adapter) validateUpdateRequest(req billing.UpdateInvoiceAdapterInput, existing *db.BillingInvoice) error { + // The user is expected to submit the updatedAt of the source invoice version it based the update on + // if this doesn't match the current updatedAt, we can't allow the update as it might overwrite some already + // changed values. + if !existing.UpdatedAt.Equal(req.UpdatedAt) { + return billing.ConflictError{ + Entity: billing.EntityInvoice, + ID: req.ID, + Err: fmt.Errorf("invoice has been updated since last read"), + } + } + + if req.Currency != existing.Currency { + return billing.ValidationError{ + Err: fmt.Errorf("currency cannot be changed"), + } + } + + if req.Type != existing.Type { + return billing.ValidationError{ + Err: fmt.Errorf("type cannot be changed"), + } + } + + if req.Customer.CustomerID != existing.CustomerID { + return billing.ValidationError{ + Err: fmt.Errorf("customer cannot be changed"), + } + } + + return nil +} + +// UpdateInvoice updates the specified invoice. It does not return the new invoice, as we would either +// ways need to re-fetch the invoice to get the updated edges. +func (r *adapter) UpdateInvoice(ctx context.Context, in billing.UpdateInvoiceAdapterInput) error { + existingInvoice, err := r.db.BillingInvoice.Query(). + Where(billinginvoice.ID(in.ID)). + Where(billinginvoice.Namespace(in.Namespace)). + Only(ctx) + if err != nil { + return err + } + + if err := r.validateUpdateRequest(in, existingInvoice); err != nil { + return err + } + + updateQuery := r.db.BillingInvoice.UpdateOneID(in.ID). + Where(billinginvoice.Namespace(in.Namespace)). + SetMetadata(in.Metadata). + // Currency is immutable + SetStatus(in.Status). + // Type is immutable + SetOrClearNumber(in.Number). + SetOrClearDescription(in.Description). + SetOrClearDueAt(in.DueAt). + SetOrClearDraftUntil(in.DraftUntil). + SetOrClearIssuedAt(in.IssuedAt) + + if in.Period != nil { + updateQuery = updateQuery. + SetPeriodStart(in.Period.Start). + SetPeriodEnd(in.Period.End) + } else { + updateQuery = updateQuery. + ClearPeriodStart(). + ClearPeriodEnd() + } + + // Supplier + updateQuery = updateQuery. + SetSupplierName(in.Supplier.Name). + SetOrClearSupplierAddressCountry(in.Supplier.Address.Country). + SetOrClearSupplierAddressPostalCode(in.Supplier.Address.PostalCode). + SetOrClearSupplierAddressCity(in.Supplier.Address.City). + SetOrClearSupplierAddressState(in.Supplier.Address.State). + SetOrClearSupplierAddressLine1(in.Supplier.Address.Line1). + SetOrClearSupplierAddressLine2(in.Supplier.Address.Line2). + SetOrClearSupplierAddressPhoneNumber(in.Supplier.Address.PhoneNumber) + + // Customer + updateQuery = updateQuery. + // CustomerID is immutable + SetCustomerName(in.Customer.Name). + SetOrClearCustomerAddressCountry(in.Customer.BillingAddress.Country). + SetOrClearCustomerAddressPostalCode(in.Customer.BillingAddress.PostalCode). + SetOrClearCustomerAddressCity(in.Customer.BillingAddress.City). + SetOrClearCustomerAddressState(in.Customer.BillingAddress.State). + SetOrClearCustomerAddressLine1(in.Customer.BillingAddress.Line1). + SetOrClearCustomerAddressLine2(in.Customer.BillingAddress.Line2). + SetOrClearCustomerAddressPhoneNumber(in.Customer.BillingAddress.PhoneNumber). + SetOrClearCustomerTimezone(in.Customer.Timezone) + + _, err = updateQuery.Save(ctx) + if err != nil { + return err + } + + if in.ExpandedFields.Workflow { + // Update the workflow config + _, err := r.updateWorkflowConfig(ctx, in.Namespace, in.Workflow.Config.ID, in.Workflow.Config) + if err != nil { + return err + } + } + + if in.ExpandedFields.Lines { + // TODO[later]: line updates (with changed flag) + r.logger.Warn("line updates are not yet implemented") + } + + return nil +} + +func mapInvoiceFromDB(invoice db.BillingInvoice, expand billingentity.InvoiceExpand) (billingentity.Invoice, error) { res := billingentity.Invoice{ ID: invoice.ID, Namespace: invoice.Namespace, @@ -327,6 +453,7 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) ( Number: invoice.Number, Description: invoice.Description, DueAt: invoice.DueAt, + DraftUntil: invoice.DraftUntil, Supplier: billingentity.SupplierContact{ Name: invoice.SupplierName, Address: models.Address{ @@ -360,6 +487,8 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) ( CreatedAt: invoice.CreatedAt, UpdatedAt: invoice.UpdatedAt, DeletedAt: invoice.DeletedAt, + + ExpandedFields: expand, } if expand.Workflow { @@ -369,7 +498,7 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) ( } res.Workflow = &billingentity.InvoiceWorkflow{ - WorkflowConfig: workflowConfig, + Config: workflowConfig, SourceBillingProfileID: invoice.SourceBillingProfileID, AppReferences: billingentity.ProfileAppReferences{ diff --git a/openmeter/billing/adapter/profile.go b/openmeter/billing/adapter/profile.go index 87a0c8c9c..3dd2cf478 100644 --- a/openmeter/billing/adapter/profile.go +++ b/openmeter/billing/adapter/profile.go @@ -4,8 +4,6 @@ import ( "context" "fmt" - "github.com/samber/lo" - "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/openmeter/billing" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" @@ -68,7 +66,7 @@ func (a *adapter) createWorkflowConfig(ctx context.Context, ns string, input bil SetNamespace(ns). SetCollectionAlignment(input.Collection.Alignment). SetLineCollectionPeriod(input.Collection.Interval.ISOString()). - SetInvoiceAutoAdvance(*input.Invoicing.AutoAdvance). + SetInvoiceAutoAdvance(input.Invoicing.AutoAdvance). SetInvoiceDraftPeriod(input.Invoicing.DraftPeriod.ISOString()). SetInvoiceDueAfter(input.Invoicing.DueAfter.ISOString()). SetInvoiceCollectionMethod(input.Payment.CollectionMethod). @@ -239,15 +237,7 @@ func (a adapter) UpdateProfile(ctx context.Context, input billing.UpdateProfileA return nil, err } - updatedWorkflowConfig, err := a.db.BillingWorkflowConfig.UpdateOneID(input.WorkflowConfigID). - Where(billingworkflowconfig.Namespace(targetState.Namespace)). - SetCollectionAlignment(targetState.WorkflowConfig.Collection.Alignment). - SetLineCollectionPeriod(targetState.WorkflowConfig.Collection.Interval.ISOString()). - SetInvoiceAutoAdvance(*targetState.WorkflowConfig.Invoicing.AutoAdvance). - SetInvoiceDraftPeriod(targetState.WorkflowConfig.Invoicing.DraftPeriod.ISOString()). - SetInvoiceDueAfter(targetState.WorkflowConfig.Invoicing.DueAfter.ISOString()). - SetInvoiceCollectionMethod(targetState.WorkflowConfig.Payment.CollectionMethod). - Save(ctx) + updatedWorkflowConfig, err := a.updateWorkflowConfig(ctx, targetState.Namespace, input.WorkflowConfigID, targetState.WorkflowConfig) if err != nil { return nil, err } @@ -256,6 +246,18 @@ func (a adapter) UpdateProfile(ctx context.Context, input billing.UpdateProfileA return mapProfileFromDB(updatedProfile) } +func (a adapter) updateWorkflowConfig(ctx context.Context, ns string, id string, input billingentity.WorkflowConfig) (*db.BillingWorkflowConfig, error) { + return a.db.BillingWorkflowConfig.UpdateOneID(id). + Where(billingworkflowconfig.Namespace(ns)). + SetCollectionAlignment(input.Collection.Alignment). + SetLineCollectionPeriod(input.Collection.Interval.ISOString()). + SetInvoiceAutoAdvance(input.Invoicing.AutoAdvance). + SetInvoiceDraftPeriod(input.Invoicing.DraftPeriod.ISOString()). + SetInvoiceDueAfter(input.Invoicing.DueAfter.ISOString()). + SetInvoiceCollectionMethod(input.Payment.CollectionMethod). + Save(ctx) +} + func mapProfileFromDB(dbProfile *db.BillingProfile) (*billingentity.BaseProfile, error) { if dbProfile == nil { return nil, nil @@ -310,11 +312,9 @@ func mapWorkflowConfigToDB(wc billingentity.WorkflowConfig) *db.BillingWorkflowC UpdatedAt: wc.UpdatedAt, DeletedAt: wc.DeletedAt, - CollectionAlignment: wc.Collection.Alignment, - LineCollectionPeriod: wc.Collection.Interval.ISOString(), - InvoiceAutoAdvance: lo.FromPtrOr( - wc.Invoicing.AutoAdvance, - *billingentity.DefaultWorkflowConfig.Invoicing.AutoAdvance), + CollectionAlignment: wc.Collection.Alignment, + LineCollectionPeriod: wc.Collection.Interval.ISOString(), + InvoiceAutoAdvance: wc.Invoicing.AutoAdvance, InvoiceDraftPeriod: wc.Invoicing.DraftPeriod.ISOString(), InvoiceDueAfter: wc.Invoicing.DueAfter.ISOString(), InvoiceCollectionMethod: wc.Payment.CollectionMethod, @@ -350,7 +350,7 @@ func mapWorkflowConfigFromDB(dbWC *db.BillingWorkflowConfig) (billingentity.Work }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.ToPtr(dbWC.InvoiceAutoAdvance), + AutoAdvance: dbWC.InvoiceAutoAdvance, DraftPeriod: draftPeriod, DueAfter: dueAfter, }, diff --git a/openmeter/billing/entity/invoice.go b/openmeter/billing/entity/invoice.go index fbc26f312..4e1fed2d3 100644 --- a/openmeter/billing/entity/invoice.go +++ b/openmeter/billing/entity/invoice.go @@ -1,11 +1,12 @@ package billingentity import ( + "errors" "fmt" + "strings" "time" "github.com/invopop/gobl/bill" - "github.com/invopop/gobl/cbc" "github.com/samber/lo" "github.com/openmeterio/openmeter/pkg/currencyx" @@ -37,83 +38,78 @@ func (t InvoiceType) Validate() error { return fmt.Errorf("invalid invoice type: %s", t) } -// TODO: remove with gobl (once the calculations are in place) -func (t InvoiceType) CBCKey() cbc.Key { - return cbc.Key(t) -} - type InvoiceStatus string const ( // InvoiceStatusGathering is the status of an invoice that is gathering the items to be invoiced. InvoiceStatusGathering InvoiceStatus = "gathering" - // InvoiceStatusPendingCreation is the status of an invoice summarizing the pending items. - InvoiceStatusPendingCreation InvoiceStatus = "pending_creation" - // InvoiceStatusCreated is the status of an invoice that has been created. - InvoiceStatusCreated InvoiceStatus = "created" - // InvoiceStatusValidationFailed is the status of an invoice that failed validation. - InvoiceStatusValidationFailed InvoiceStatus = "validation_failed" - // InvoiceStatusDraft is the status of an invoice that is in draft both on OpenMeter and the provider side. - InvoiceStatusDraft InvoiceStatus = "draft" - // InvoiceStatusDraftSync is the status of an invoice that is being synced with the provider. - InvoiceStatusDraftSync InvoiceStatus = "draft_sync" - // InvoiceStatusDraftSyncFailed is the status of an invoice that failed to sync with the provider. - InvoiceStatusDraftSyncFailed InvoiceStatus = "draft_sync_failed" - // InvoiceStatusIssuing is the status of an invoice that is being issued. - InvoiceStatusIssuing InvoiceStatus = "issuing" - // InvoiceStatusIssued is the status of an invoice that has been issued both on OpenMeter and provider side. + + InvoiceStatusDraftCreated InvoiceStatus = "draft_created" + InvoiceStatusDraftManualApprovalNeeded InvoiceStatus = "draft_manual_approval_needed" + InvoiceStatusDraftValidating InvoiceStatus = "draft_validating" + InvoiceStatusDraftInvalid InvoiceStatus = "draft_invalid" + InvoiceStatusDraftSyncing InvoiceStatus = "draft_syncing" + InvoiceStatusDraftSyncFailed InvoiceStatus = "draft_sync_failed" + InvoiceStatusDraftWaitingAutoApproval InvoiceStatus = "draft_waiting_auto_approval" + InvoiceStatusDraftReadyToIssue InvoiceStatus = "draft_ready_to_issue" + + InvoiceStatusIssuing InvoiceStatus = "issuing_syncing" + InvoiceStatusIssuingSyncFailed InvoiceStatus = "issuing_sync_failed" + + // InvoiceStatusIssued is the status of an invoice that has been issued. InvoiceStatusIssued InvoiceStatus = "issued" - // InvoiceStatusIssuingFailed is the status of an invoice that failed to issue on the provider or OpenMeter side. - InvoiceStatusIssuingFailed InvoiceStatus = "issuing_failed" - // InvoiceStatusManualApprovalNeeded is the status of an invoice that needs manual approval. (due to AutoApprove is disabled) - InvoiceStatusManualApprovalNeeded InvoiceStatus = "manual_approval_needed" - // InvoiceStatusDeleted is the status of an invoice that has been deleted (e.g. removed from the database before being issued). - InvoiceStatusDeleted InvoiceStatus = "deleted" ) -// InvoiceImmutableStatuses are the statuses that forbid any changes to the invoice. -var InvoiceImmutableStatuses = []InvoiceStatus{ +var validStatuses = []InvoiceStatus{ + InvoiceStatusGathering, + InvoiceStatusDraftCreated, + InvoiceStatusDraftManualApprovalNeeded, + InvoiceStatusDraftValidating, + InvoiceStatusDraftInvalid, + InvoiceStatusDraftSyncing, + InvoiceStatusDraftSyncFailed, + InvoiceStatusDraftWaitingAutoApproval, + InvoiceStatusDraftReadyToIssue, + InvoiceStatusIssuing, + InvoiceStatusIssuingSyncFailed, InvoiceStatusIssued, - InvoiceStatusDeleted, } func (s InvoiceStatus) Values() []string { return lo.Map( - []InvoiceStatus{ - InvoiceStatusGathering, - InvoiceStatusCreated, - InvoiceStatusDraft, - InvoiceStatusDraftSync, - InvoiceStatusDraftSyncFailed, - InvoiceStatusIssuing, - InvoiceStatusIssued, - InvoiceStatusIssuingFailed, - InvoiceStatusManualApprovalNeeded, - }, + validStatuses, func(item InvoiceStatus, _ int) string { return string(item) }, ) } -func (s InvoiceStatus) Validate() error { - for _, status := range s.Values() { - if string(s) == status { - return nil - } - } - - return fmt.Errorf("invalid invoice status: %s", s) +func (s InvoiceStatus) ShortStatus() string { + parts := strings.SplitN(string(s), "_", 2) + return parts[0] } +var immutableStatuses = []InvoiceStatus{InvoiceStatusIssued} + func (s InvoiceStatus) IsMutable() bool { - for _, status := range InvoiceImmutableStatuses { - if s == status { - return false - } + return !lo.Contains(immutableStatuses, s) +} + +var failedStatuses = []InvoiceStatus{ + InvoiceStatusDraftSyncFailed, + InvoiceStatusIssuingSyncFailed, +} + +func (s InvoiceStatus) IsFailed() bool { + return lo.Contains(failedStatuses, s) +} + +func (s InvoiceStatus) Validate() error { + if !lo.Contains(validStatuses, s) { + return fmt.Errorf("invalid invoice status: %s", s) } - return true + return nil } type InvoiceID models.NamespacedID @@ -122,6 +118,28 @@ func (i InvoiceID) Validate() error { return models.NamespacedID(i).Validate() } +type InvoiceExpand struct { + Lines bool + Preceding bool + Workflow bool + WorkflowApps bool +} + +var InvoiceExpandAll = InvoiceExpand{ + Lines: true, + Preceding: true, + Workflow: true, + WorkflowApps: true, +} + +func (e InvoiceExpand) Validate() error { + if !e.Workflow && e.WorkflowApps { + return errors.New("workflow.apps can only be expanded when workflow is expanded") + } + + return nil +} + type Invoice struct { Namespace string `json:"namespace"` ID string `json:"id"` @@ -133,19 +151,21 @@ type Invoice struct { Metadata map[string]string `json:"metadata"` - Currency currencyx.Code `json:"currency,omitempty"` - Timezone timezone.Timezone `json:"timezone,omitempty"` - Status InvoiceStatus `json:"status"` + Currency currencyx.Code `json:"currency,omitempty"` + Timezone timezone.Timezone `json:"timezone,omitempty"` + Status InvoiceStatus `json:"status"` + StatusDetails InvoiceStatusDetails `json:"statusDetail,omitempty"` Period *Period `json:"period,omitempty"` DueAt *time.Time `json:"dueDate,omitempty"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` - VoidedAt *time.Time `json:"voidedAt,omitempty"` - IssuedAt *time.Time `json:"issuedAt,omitempty"` - DeletedAt *time.Time `json:"deletedAt,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + VoidedAt *time.Time `json:"voidedAt,omitempty"` + DraftUntil *time.Time `json:"draftUntil,omitempty"` + IssuedAt *time.Time `json:"issuedAt,omitempty"` + DeletedAt *time.Time `json:"deletedAt,omitempty"` // Customer is either a snapshot of the contact information of the customer at the time of invoice being sent // or the data from the customer entity (draft state) @@ -156,6 +176,41 @@ type Invoice struct { // Line items Lines []Line `json:"lines,omitempty"` + + // private fields required by the service + Changed bool `json:"-"` + ExpandedFields InvoiceExpand `json:"-"` +} + +func (i *Invoice) Calculate() error { + for _, calc := range InvoiceCalculations { + changed, err := calc(i) + if err != nil { + return err + } + + if changed { + i.Changed = true + } + } + + return nil +} + +type InvoiceAction string + +const ( + InvoiceActionAdvance InvoiceAction = "advance" + InvoiceActionApprove InvoiceAction = "approve" + InvoiceActionDelete InvoiceAction = "delete" + InvoiceActionRetry InvoiceAction = "retry" + InvoiceActionVoid InvoiceAction = "void" +) + +type InvoiceStatusDetails struct { + Immutable bool `json:"immutable"` + Failed bool `json:"failed"` + AvailableActions []InvoiceAction `json:"availableActions"` } type InvoiceWithValidation struct { diff --git a/openmeter/billing/entity/invoicecalc.go b/openmeter/billing/entity/invoicecalc.go new file mode 100644 index 000000000..63715653b --- /dev/null +++ b/openmeter/billing/entity/invoicecalc.go @@ -0,0 +1,21 @@ +package billingentity + +type InvoiceCalculation func(*Invoice) (bool, error) + +var InvoiceCalculations = []InvoiceCalculation{ + CalculateDraftUntilIfMissing, +} + +// CalculateDraftUntilIfMissing calculates the draft until date if it is missing. +// If it's set we are not updating it as the user should update that instead of manipulating the +// workflow config. +func CalculateDraftUntilIfMissing(i *Invoice) (bool, error) { + if !i.ExpandedFields.Workflow || i.DraftUntil != nil || !i.Workflow.Config.Invoicing.AutoAdvance { + return false, nil + } + + draftUntil, _ := i.Workflow.Config.Invoicing.DraftPeriod.AddTo(i.CreatedAt) + i.DraftUntil = &draftUntil + + return true, nil +} diff --git a/openmeter/billing/entity/profile.go b/openmeter/billing/entity/profile.go index 987c6f6f6..9603ef71d 100644 --- a/openmeter/billing/entity/profile.go +++ b/openmeter/billing/entity/profile.go @@ -26,21 +26,6 @@ const ( AlignmentKindSubscription AlignmentKind = "subscription" ) -var DefaultWorkflowConfig = WorkflowConfig{ - Collection: CollectionConfig{ - Alignment: AlignmentKindSubscription, - Interval: lo.Must(datex.ISOString("PT2H").Parse()), - }, - Invoicing: InvoicingConfig{ - AutoAdvance: lo.ToPtr(true), - DraftPeriod: lo.Must(datex.ISOString("P1D").Parse()), - DueAfter: lo.Must(datex.ISOString("P1W").Parse()), - }, - Payment: PaymentConfig{ - CollectionMethod: CollectionMethodChargeAutomatically, - }, -} - func (k AlignmentKind) Values() []string { return []string{ string(AlignmentKindSubscription), @@ -114,13 +99,13 @@ func (c *CollectionConfig) Validate() error { // InvoiceConfig groups fields related to invoice settings. type InvoicingConfig struct { - AutoAdvance *bool `json:"autoAdvance,omitempty"` + AutoAdvance bool `json:"autoAdvance,omitempty"` DraftPeriod datex.Period `json:"draftPeriod,omitempty"` DueAfter datex.Period `json:"dueAfter,omitempty"` } func (c *InvoicingConfig) Validate() error { - if c.DraftPeriod.IsNegative() && c.AutoAdvance != nil && *c.AutoAdvance { + if c.DraftPeriod.IsNegative() && c.AutoAdvance { return fmt.Errorf("draft period must be greater or equal to 0") } @@ -266,7 +251,7 @@ func (p Profile) Merge(o *CustomerOverride) Profile { } p.WorkflowConfig.Invoicing = InvoicingConfig{ - AutoAdvance: lo.CoalesceOrEmpty(o.Invoicing.AutoAdvance, p.WorkflowConfig.Invoicing.AutoAdvance), + AutoAdvance: lo.FromPtrOr(o.Invoicing.AutoAdvance, p.WorkflowConfig.Invoicing.AutoAdvance), DraftPeriod: lo.FromPtrOr(o.Invoicing.DraftPeriod, p.WorkflowConfig.Invoicing.DraftPeriod), DueAfter: lo.FromPtrOr(o.Invoicing.DueAfter, p.WorkflowConfig.Invoicing.DueAfter), } @@ -317,5 +302,5 @@ type InvoiceWorkflow struct { AppReferences ProfileAppReferences `json:"appReferences"` Apps *ProfileApps `json:"apps,omitempty"` SourceBillingProfileID string `json:"sourceBillingProfileId,omitempty"` - WorkflowConfig WorkflowConfig `json:"workflow"` + Config WorkflowConfig `json:"config"` } diff --git a/openmeter/billing/errors.go b/openmeter/billing/errors.go index c95c70c16..00b04ce9f 100644 --- a/openmeter/billing/errors.go +++ b/openmeter/billing/errors.go @@ -67,6 +67,24 @@ func (e NotFoundError) Unwrap() error { return e.Err } +type ConflictError struct { + ID string + Entity string + Err error +} + +func (e ConflictError) Error() string { + if e.ID == "" { + return e.Err.Error() + } + + return fmt.Sprintf("%s [%s/%s]", e.Err, e.Entity, e.ID) +} + +func (e ConflictError) Unwrap() error { + return e.Err +} + type genericError struct { Err error } diff --git a/openmeter/billing/httpdriver/defaults.go b/openmeter/billing/httpdriver/defaults.go index 7c6a96edb..141bad61b 100644 --- a/openmeter/billing/httpdriver/defaults.go +++ b/openmeter/billing/httpdriver/defaults.go @@ -1,8 +1,30 @@ package httpdriver +import ( + "github.com/samber/lo" + + billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + "github.com/openmeterio/openmeter/pkg/datex" +) + const ( DefaultPageSize = 100 DefaultPageNumber = 1 DefaultIncludeArchived = false DefaultInvoiceTimezone = "UTC" ) + +var defaultWorkflowConfig = billingentity.WorkflowConfig{ + Collection: billingentity.CollectionConfig{ + Alignment: billingentity.AlignmentKindSubscription, + Interval: lo.Must(datex.ISOString("PT2H").Parse()), + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: true, + DraftPeriod: lo.Must(datex.ISOString("P1D").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billingentity.PaymentConfig{ + CollectionMethod: billingentity.CollectionMethodChargeAutomatically, + }, +} diff --git a/openmeter/billing/httpdriver/errors.go b/openmeter/billing/httpdriver/errors.go index 76c5944f5..3da7ca1c9 100644 --- a/openmeter/billing/httpdriver/errors.go +++ b/openmeter/billing/httpdriver/errors.go @@ -16,6 +16,7 @@ func errorEncoder() httptransport.ErrorEncoder { return func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request) bool { return commonhttp.HandleErrorIfTypeMatches[*models.GenericUserError](ctx, http.StatusBadRequest, err, w) || commonhttp.HandleErrorIfTypeMatches[billing.NotFoundError](ctx, http.StatusNotFound, err, w) || + commonhttp.HandleErrorIfTypeMatches[billing.ConflictError](ctx, http.StatusConflict, err, w) || commonhttp.HandleErrorIfTypeMatches[billing.ValidationError](ctx, http.StatusBadRequest, err, w) || commonhttp.HandleErrorIfTypeMatches[billing.UpdateAfterDeleteError](ctx, http.StatusConflict, err, w) || // dependency: customer diff --git a/openmeter/billing/httpdriver/invoice.go b/openmeter/billing/httpdriver/invoice.go index 97536635e..14b7473ac 100644 --- a/openmeter/billing/httpdriver/invoice.go +++ b/openmeter/billing/httpdriver/invoice.go @@ -37,7 +37,13 @@ func (h *handler) ListInvoices() ListInvoicesHandler { Customers: lo.FromPtrOr(input.Customers, nil), Statuses: lo.Map( lo.FromPtrOr(input.Statuses, nil), - func(status api.BillingInvoiceStatus, _ int) billingentity.InvoiceStatus { + func(status api.BillingInvoiceStatus, _ int) string { + return string(status) + }, + ), + ExtendedStatuses: lo.Map( + lo.FromPtrOr(input.ExtendedStatuses, nil), + func(status api.BillingInvoiceExtendedStatus, _ int) billingentity.InvoiceStatus { return billingentity.InvoiceStatus(status) }, ), @@ -108,13 +114,14 @@ func mapInvoiceToAPI(invoice billingentity.Invoice) (api.BillingInvoice, error) out := api.BillingInvoice{ Id: invoice.ID, - CreatedAt: invoice.CreatedAt, - UpdatedAt: invoice.UpdatedAt, - DeletedAt: invoice.DeletedAt, - IssuedAt: invoice.IssuedAt, - VoidedAt: invoice.VoidedAt, - DueAt: invoice.DueAt, - Period: mapPeriodToAPI(invoice.Period), + CreatedAt: invoice.CreatedAt, + UpdatedAt: invoice.UpdatedAt, + DeletedAt: invoice.DeletedAt, + IssuedAt: invoice.IssuedAt, + VoidedAt: invoice.VoidedAt, + DueAt: invoice.DueAt, + DraftUntil: invoice.DraftUntil, + Period: mapPeriodToAPI(invoice.Period), Currency: string(invoice.Currency), Customer: mapInvoiceCustomerToAPI(invoice.Customer), @@ -123,7 +130,16 @@ func mapInvoiceToAPI(invoice billingentity.Invoice) (api.BillingInvoice, error) Description: invoice.Description, Metadata: lo.EmptyableToPtr(invoice.Metadata), - Status: api.BillingInvoiceStatus(invoice.Status), + Status: api.BillingInvoiceStatus(invoice.Status.ShortStatus()), + StatusDetails: api.BillingInvoiceStatusDetails{ + Failed: invoice.StatusDetails.Failed, + Immutable: invoice.StatusDetails.Immutable, + ExtendedStatus: api.BillingInvoiceExtendedStatus(invoice.Status), + + AvailableActions: lo.Map(invoice.StatusDetails.AvailableActions, func(a billingentity.InvoiceAction, _ int) api.BillingInvoiceAction { + return api.BillingInvoiceAction(a) + }), + }, Supplier: mapSupplierContactToAPI(invoice.Supplier), // TODO[OM-942]: This needs to be (re)implemented Totals: api.BillingInvoiceTotals{}, @@ -136,7 +152,7 @@ func mapInvoiceToAPI(invoice billingentity.Invoice) (api.BillingInvoice, error) out.Workflow = &api.BillingInvoiceWorkflowSettings{ Apps: apps, SourceBillingProfileID: invoice.Workflow.SourceBillingProfileID, - Workflow: mapWorkflowConfigSettingsToAPI(invoice.Workflow.WorkflowConfig), + Workflow: mapWorkflowConfigSettingsToAPI(invoice.Workflow.Config), Timezone: string(invoice.Timezone), } } @@ -189,13 +205,13 @@ func mapInvoiceCustomerToAPI(c billingentity.InvoiceCustomer) api.BillingParty { } } -func mapInvoiceExpandToEntity(expand []api.BillingInvoiceExpand) billing.InvoiceExpand { +func mapInvoiceExpandToEntity(expand []api.BillingInvoiceExpand) billingentity.InvoiceExpand { if len(expand) == 0 { - return billing.InvoiceExpand{} + return billingentity.InvoiceExpand{} } if slices.Contains(expand, api.BillingInvoiceExpandAll) { - return billing.InvoiceExpand{ + return billingentity.InvoiceExpand{ Lines: true, Preceding: true, Workflow: true, @@ -203,7 +219,7 @@ func mapInvoiceExpandToEntity(expand []api.BillingInvoiceExpand) billing.Invoice } } - return billing.InvoiceExpand{ + return billingentity.InvoiceExpand{ Lines: slices.Contains(expand, api.BillingInvoiceExpandLines), Preceding: slices.Contains(expand, api.BillingInvoiceExpandPreceding), Workflow: slices.Contains(expand, api.BillingInvoiceExpandWorkflow), diff --git a/openmeter/billing/httpdriver/profile.go b/openmeter/billing/httpdriver/profile.go index 57c2aee4e..7582121a1 100644 --- a/openmeter/billing/httpdriver/profile.go +++ b/openmeter/billing/httpdriver/profile.go @@ -356,7 +356,7 @@ func fromAPIBillingAppIdOrType(i string) billingentity.AppReference { } func fromAPIBillingWorkflow(i api.BillingWorkflow) (billingentity.WorkflowConfig, error) { - def := billingentity.DefaultWorkflowConfig + def := defaultWorkflowConfig if i.Collection == nil { i.Collection = &api.BillingWorkflowCollectionSettings{} @@ -397,7 +397,7 @@ func fromAPIBillingWorkflow(i api.BillingWorkflow) (billingentity.WorkflowConfig }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.CoalesceOrEmpty(i.Invoicing.AutoAdvance, def.Invoicing.AutoAdvance), + AutoAdvance: lo.FromPtrOr(i.Invoicing.AutoAdvance, def.Invoicing.AutoAdvance), DraftPeriod: draftPeriod, DueAfter: dueAfter, }, @@ -576,7 +576,7 @@ func mapWorkflowConfigToAPI(c billingentity.WorkflowConfig) api.BillingWorkflow }, Invoicing: &api.BillingWorkflowInvoicingSettings{ - AutoAdvance: c.Invoicing.AutoAdvance, + AutoAdvance: lo.ToPtr(c.Invoicing.AutoAdvance), DraftPeriod: lo.EmptyableToPtr(c.Invoicing.DraftPeriod.String()), DueAfter: lo.EmptyableToPtr(c.Invoicing.DueAfter.String()), }, @@ -595,7 +595,7 @@ func mapWorkflowConfigSettingsToAPI(c billingentity.WorkflowConfig) api.BillingW }, Invoicing: &api.BillingWorkflowInvoicingSettings{ - AutoAdvance: c.Invoicing.AutoAdvance, + AutoAdvance: lo.ToPtr(c.Invoicing.AutoAdvance), DraftPeriod: lo.EmptyableToPtr(c.Invoicing.DraftPeriod.String()), DueAfter: lo.EmptyableToPtr(c.Invoicing.DueAfter.String()), }, diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index 3536b62a2..d91be597e 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -14,31 +14,9 @@ import ( "github.com/openmeterio/openmeter/pkg/sortx" ) -type InvoiceExpand struct { - Lines bool - Preceding bool - Workflow bool - WorkflowApps bool -} - -var InvoiceExpandAll = InvoiceExpand{ - Lines: true, - Preceding: true, - Workflow: true, - WorkflowApps: true, -} - -func (e InvoiceExpand) Validate() error { - if !e.Workflow && e.WorkflowApps { - return errors.New("workflow.apps can only be expanded when workflow is expanded") - } - - return nil -} - type GetInvoiceByIdInput struct { Invoice billingentity.InvoiceID - Expand InvoiceExpand + Expand billingentity.InvoiceExpand } func (i GetInvoiceByIdInput) Validate() error { @@ -79,15 +57,18 @@ type ( type ListInvoicesInput struct { pagination.Page - Namespace string - Customers []string - Statuses []billingentity.InvoiceStatus - Currencies []currencyx.Code + Namespace string + Customers []string + // Statuses searches by short InvoiceStatus (e.g. draft, issued) + Statuses []string + // ExtendedStatuses searches by exact InvoiceStatus + ExtendedStatuses []billingentity.InvoiceStatus + Currencies []currencyx.Code IssuedAfter *time.Time IssuedBefore *time.Time - Expand InvoiceExpand + Expand billingentity.InvoiceExpand OrderBy api.BillingInvoiceOrderBy Order sortx.Order @@ -177,3 +158,10 @@ func (i CreateInvoiceInput) Validate() error { type AssociatedLineCountsAdapterResponse struct { Counts map[billingentity.InvoiceID]int64 } + +type ( + AdvanceInvoiceInput = billingentity.InvoiceID + ApproveInvoiceInput = billingentity.InvoiceID +) + +type UpdateInvoiceAdapterInput = billingentity.Invoice diff --git a/openmeter/billing/profile.go b/openmeter/billing/profile.go index 60f52189a..1acccacdf 100644 --- a/openmeter/billing/profile.go +++ b/openmeter/billing/profile.go @@ -4,8 +4,6 @@ import ( "errors" "fmt" - "github.com/samber/lo" - "github.com/openmeterio/openmeter/api" billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" "github.com/openmeterio/openmeter/pkg/models" @@ -53,37 +51,6 @@ func (i CreateProfileInput) Validate() error { return nil } -func (i CreateProfileInput) WithDefaults() CreateProfileInput { - i.WorkflowConfig = billingentity.WorkflowConfig{ - Collection: billingentity.CollectionConfig{ - Alignment: lo.CoalesceOrEmpty( - i.WorkflowConfig.Collection.Alignment, - billingentity.DefaultWorkflowConfig.Collection.Alignment), - Interval: lo.CoalesceOrEmpty( - i.WorkflowConfig.Collection.Interval, - billingentity.DefaultWorkflowConfig.Collection.Interval), - }, - Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.CoalesceOrEmpty( - i.WorkflowConfig.Invoicing.AutoAdvance, - billingentity.DefaultWorkflowConfig.Invoicing.AutoAdvance), - DraftPeriod: lo.CoalesceOrEmpty( - i.WorkflowConfig.Invoicing.DraftPeriod, - billingentity.DefaultWorkflowConfig.Invoicing.DraftPeriod), - DueAfter: lo.CoalesceOrEmpty( - i.WorkflowConfig.Invoicing.DueAfter, - billingentity.DefaultWorkflowConfig.Invoicing.DueAfter), - }, - Payment: billingentity.PaymentConfig{ - CollectionMethod: lo.CoalesceOrEmpty( - i.WorkflowConfig.Payment.CollectionMethod, - billingentity.DefaultWorkflowConfig.Payment.CollectionMethod), - }, - } - - return i -} - type CreateProfileAppsInput = billingentity.ProfileAppReferences type ListProfilesResult = pagination.PagedResponse[billingentity.Profile] @@ -215,9 +182,5 @@ func (i UpdateProfileAdapterInput) Validate() error { return fmt.Errorf("workflow config id is required") } - if i.TargetState.WorkflowConfig.Invoicing.AutoAdvance == nil { - return fmt.Errorf("invoicing auto advance is required") - } - return nil } diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index ffe2edddc..cf0f83779 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -39,4 +39,10 @@ type InvoiceService interface { ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) GetInvoiceByID(ctx context.Context, input GetInvoiceByIdInput) (billingentity.Invoice, error) CreateInvoice(ctx context.Context, input CreateInvoiceInput) ([]billingentity.Invoice, error) + // AdvanceInvoice advances the invoice to the next stage, the advancement is stopped until: + // - an error is occurred + // - the invoice is in a state that cannot be advanced (e.g. waiting for draft period to expire) + // - the invoice is advanced to the final state + AdvanceInvoice(ctx context.Context, input AdvanceInvoiceInput) (*billingentity.Invoice, error) + ApproveInvoice(ctx context.Context, input ApproveInvoiceInput) (*billingentity.Invoice, error) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index a4502b47a..8d7b44c6f 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -24,19 +24,10 @@ func (s *Service) ListInvoices(ctx context.Context, input billing.ListInvoicesIn return billing.ListInvoicesResponse{}, err } - if input.Expand.WorkflowApps { - for i := range invoices.Items { - invoice := &invoices.Items[i] - resolvedApps, err := s.resolveApps(ctx, input.Namespace, invoice.Workflow.AppReferences) - if err != nil { - return billing.ListInvoicesResponse{}, fmt.Errorf("error resolving apps for invoice [%s]: %w", invoice.ID, err) - } - - invoice.Workflow.Apps = &billingentity.ProfileApps{ - Tax: resolvedApps.Tax.App, - Invoicing: resolvedApps.Invoicing.App, - Payment: resolvedApps.Payment.App, - } + for i := range invoices.Items { + invoices.Items[i], err = s.addInvoiceFields(ctx, invoices.Items[i]) + if err != nil { + return billing.ListInvoicesResponse{}, fmt.Errorf("error adding fields to invoice [%s]: %w", invoices.Items[i].ID, err) } } @@ -44,6 +35,27 @@ func (s *Service) ListInvoices(ctx context.Context, input billing.ListInvoicesIn }) } +func (s *Service) addInvoiceFields(ctx context.Context, invoice billingentity.Invoice) (billingentity.Invoice, error) { + if invoice.ExpandedFields.WorkflowApps { + resolvedApps, err := s.resolveApps(ctx, invoice.Namespace, invoice.Workflow.AppReferences) + if err != nil { + return invoice, fmt.Errorf("error resolving apps for invoice [%s]: %w", invoice.ID, err) + } + + invoice.Workflow.Apps = &billingentity.ProfileApps{ + Tax: resolvedApps.Tax.App, + Invoicing: resolvedApps.Invoicing.App, + Payment: resolvedApps.Payment.App, + } + } + + // let's resolve the statatus details + invoice.StatusDetails = NewInvoiceStateMachine(&invoice). + StatusDetails(ctx) + + return invoice, nil +} + func (s *Service) GetInvoiceByID(ctx context.Context, input billing.GetInvoiceByIdInput) (billingentity.Invoice, error) { return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (billingentity.Invoice, error) { invoice, err := txAdapter.GetInvoiceById(ctx, input) @@ -51,19 +63,10 @@ func (s *Service) GetInvoiceByID(ctx context.Context, input billing.GetInvoiceBy return billingentity.Invoice{}, err } - if input.Expand.WorkflowApps { - resolvedApps, err := s.resolveApps(ctx, input.Invoice.Namespace, invoice.Workflow.AppReferences) - if err != nil { - return billingentity.Invoice{}, fmt.Errorf("error resolving apps for invoice [%s]: %w", invoice.ID, err) - } - - invoice.Workflow.Apps = &billingentity.ProfileApps{ - Tax: resolvedApps.Tax.App, - Invoicing: resolvedApps.Invoicing.App, - Payment: resolvedApps.Payment.App, - } + invoice, err = s.addInvoiceFields(ctx, invoice) + if err != nil { + return billingentity.Invoice{}, fmt.Errorf("error adding fields to invoice [%s]: %w", invoice.ID, err) } - return invoice, nil }) } @@ -130,7 +133,7 @@ func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoice Profile: customerProfile.Profile, Currency: currency, - Status: billingentity.InvoiceStatusDraft, + Status: billingentity.InvoiceStatusDraftCreated, Type: billingentity.InvoiceTypeStandard, }) @@ -184,12 +187,26 @@ func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoice for _, invoiceID := range createdInvoices { invoiceWithLines, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ Invoice: invoiceID, - Expand: billing.InvoiceExpandAll, + Expand: billingentity.InvoiceExpandAll, }) if err != nil { return nil, fmt.Errorf("cannot get invoice[%s]: %w", invoiceWithLines.ID, err) } + // let's update any calculated fields on the invoice + err = invoiceWithLines.Calculate() + if err != nil { + return nil, fmt.Errorf("calculating invoice fields: %w", err) + } + + // let's update the invoice in the DB if needed + if invoiceWithLines.Changed { + err = txAdapter.UpdateInvoice(ctx, invoiceWithLines) + if err != nil { + return nil, fmt.Errorf("updating invoice: %w", err) + } + } + out = append(out, invoiceWithLines) } return out, nil @@ -262,3 +279,91 @@ func (s *Service) gatherInscopeLines(ctx context.Context, input billing.CreateIn return lines, nil } + +func (s *Service) getInvoiceStatMachineWithLock(ctx context.Context, txAdapter billing.Adapter, invoiceID billingentity.InvoiceID) (*InvoiceStateMachine, error) { + // let's lock the invoice for update, we are using the dedicated call, so that + // edges won't end up having SELECT FOR UPDATE locks + if err := txAdapter.LockInvoicesForUpdate(ctx, billing.LockInvoicesForUpdateInput{ + Namespace: invoiceID.Namespace, + InvoiceIDs: []string{invoiceID.ID}, + }); err != nil { + return nil, fmt.Errorf("locking invoice: %w", err) + } + + invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: invoiceID, + Expand: billingentity.InvoiceExpandAll, + }) + if err != nil { + return nil, fmt.Errorf("fetching invoice: %w", err) + } + + return NewInvoiceStateMachine(&invoice), nil +} + +func (s *Service) AdvanceInvoice(ctx context.Context, input billing.AdvanceInvoiceInput) (*billingentity.Invoice, error) { + if err := input.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (*billingentity.Invoice, error) { + fsm, err := s.getInvoiceStatMachineWithLock(ctx, txAdapter, input) + if err != nil { + return nil, err + } + + preActivationStatus := fsm.Invoice.Status + + if err := fsm.ActivateUntilStateStable(ctx); err != nil { + return nil, fmt.Errorf("activating invoice: %w", err) + } + + s.logger.Info("invoice advanced", "invoice", input.ID, "from", preActivationStatus, "to", fsm.Invoice.Status) + + // Given the amount of state transitions, we are only saving the invoice after the whole chain + // this means that some of the intermittent states will not be persisted in the DB. + if err := txAdapter.UpdateInvoice(ctx, *fsm.Invoice); err != nil { + return nil, fmt.Errorf("updating invoice: %w", err) + } + + return fsm.Invoice, nil + }) +} + +func (s *Service) ApproveInvoice(ctx context.Context, input billing.ApproveInvoiceInput) (*billingentity.Invoice, error) { + if err := input.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (*billingentity.Invoice, error) { + fsm, err := s.getInvoiceStatMachineWithLock(ctx, txAdapter, input) + if err != nil { + return nil, err + } + + canFire, err := fsm.CanFire(ctx, triggerApprove) + if err != nil { + return nil, fmt.Errorf("checking if can fire: %w", err) + } + + if !canFire { + return nil, billing.ValidationError{ + Err: fmt.Errorf("cannot approve invoice in status [%s]", fsm.Invoice.Status), + } + } + + if err := fsm.FireAndActivate(ctx, triggerApprove); err != nil { + return nil, fmt.Errorf("firing approve: %w", err) + } + + if err := txAdapter.UpdateInvoice(ctx, *fsm.Invoice); err != nil { + return nil, fmt.Errorf("updating invoice: %w", err) + } + + return fsm.Invoice, nil + }) +} diff --git a/openmeter/billing/service/invoiceline.go b/openmeter/billing/service/invoiceline.go index e9cb6fa8a..ba7e2f9d7 100644 --- a/openmeter/billing/service/invoiceline.go +++ b/openmeter/billing/service/invoiceline.go @@ -104,12 +104,12 @@ func (s *Service) upsertLineInvoice(ctx context.Context, txAdapter billing.Adapt PageNumber: 1, PageSize: 10, }, - Customers: []string{input.CustomerID}, - Namespace: input.Namespace, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{line.Currency}, - OrderBy: api.BillingInvoiceOrderByCreatedAt, - Order: sortx.OrderAsc, + Customers: []string{input.CustomerID}, + Namespace: input.Namespace, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{line.Currency}, + OrderBy: api.BillingInvoiceOrderByCreatedAt, + Order: sortx.OrderAsc, }) if err != nil { return line, fmt.Errorf("fetching gathering invoices: %w", err) diff --git a/openmeter/billing/service/invoicestate.go b/openmeter/billing/service/invoicestate.go new file mode 100644 index 000000000..49ec0df18 --- /dev/null +++ b/openmeter/billing/service/invoicestate.go @@ -0,0 +1,234 @@ +package billingservice + +import ( + "context" + "fmt" + "time" + + "github.com/qmuntal/stateless" + + billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + "github.com/openmeterio/openmeter/pkg/clock" +) + +type InvoiceStateMachine struct { + Invoice *billingentity.Invoice + StateMachine *stateless.StateMachine +} + +var ( + // triggerRetry is used to retry a state transition that failed, used by the end user to invoke it manually + triggerRetry stateless.Trigger = "trigger_retry" + // triggerApprove is used to approve a state manually + triggerApprove stateless.Trigger = "trigger_approve" + // triggerNext is used to advance the invoice to the next state if automatically possible + triggerNext stateless.Trigger = "trigger_next" + // triggerFailed is used to trigger the failure state transition associated with the current state + triggerFailed stateless.Trigger = "trigger_failed" + + // TODO[later]: we should have a triggerAsyncNext to signify that a transition should be done asynchronously ( + // e.g. the invoice needs to be synced to an external system such as stripe) +) + +func NewInvoiceStateMachine(invoice *billingentity.Invoice) *InvoiceStateMachine { + out := &InvoiceStateMachine{ + Invoice: invoice, + } + + // TODO[later]: Tax is not captured here for now, as it would require the DB schema too + // TODO[later]: Delete invoice is not implemented yet + // TODO[optimization]: The state machine can be added to sync.Pool to avoid allocations (state is stored in the Invoice entity) + + stateMachine := stateless.NewStateMachineWithExternalStorage( + func(ctx context.Context) (stateless.State, error) { + return out.Invoice.Status, nil + }, + func(ctx context.Context, state stateless.State) error { + invState, ok := state.(billingentity.InvoiceStatus) + if !ok { + return fmt.Errorf("invalid state type: %v", state) + } + + out.Invoice.Status = invState + out.Invoice.StatusDetails = out.StatusDetails(ctx) + return nil + }, + stateless.FiringImmediate, + ) + + // Draft states + + // NOTE: we are not using the substate support of stateless for now, as the + // substate inherits all the parent's state transitions resulting in unexpected behavior. + + stateMachine.Configure(billingentity.InvoiceStatusDraftCreated). + Permit(triggerNext, billingentity.InvoiceStatusDraftValidating) + + stateMachine.Configure(billingentity.InvoiceStatusDraftValidating). + Permit(triggerNext, billingentity.InvoiceStatusDraftSyncing). + Permit(triggerFailed, billingentity.InvoiceStatusDraftInvalid). + OnActive(out.validateDraftInvoice) + + stateMachine.Configure(billingentity.InvoiceStatusDraftInvalid). + Permit(triggerRetry, billingentity.InvoiceStatusDraftValidating) + + stateMachine.Configure(billingentity.InvoiceStatusDraftSyncing). + Permit(triggerNext, + billingentity.InvoiceStatusDraftManualApprovalNeeded, + boolFn(not(out.isAutoAdvanceEnabled))). + Permit(triggerNext, + billingentity.InvoiceStatusDraftWaitingAutoApproval, + boolFn(out.isAutoAdvanceEnabled)). + Permit(triggerFailed, billingentity.InvoiceStatusDraftSyncFailed). + OnActive(out.syncDraftInvoice) + + stateMachine.Configure(billingentity.InvoiceStatusDraftSyncFailed). + Permit(triggerRetry, billingentity.InvoiceStatusDraftValidating) + + stateMachine.Configure(billingentity.InvoiceStatusDraftReadyToIssue). + Permit(triggerNext, billingentity.InvoiceStatusIssuing) + + // Automatic and manual approvals + stateMachine.Configure(billingentity.InvoiceStatusDraftWaitingAutoApproval). + // Manual approval forces the draft invoice to be issued regardless of the review period + Permit(triggerApprove, billingentity.InvoiceStatusDraftReadyToIssue). + Permit(triggerNext, + billingentity.InvoiceStatusDraftReadyToIssue, + boolFn(out.shouldAutoAdvance), + ) + + // This state is a pre-issuing state where we can halt the execution and execute issuing in the background + // if needed + stateMachine.Configure(billingentity.InvoiceStatusDraftManualApprovalNeeded). + Permit(triggerApprove, billingentity.InvoiceStatusDraftReadyToIssue) + + // Issuing state + + stateMachine.Configure(billingentity.InvoiceStatusIssuing). + Permit(triggerNext, billingentity.InvoiceStatusIssued). + Permit(triggerFailed, billingentity.InvoiceStatusIssuingSyncFailed). + OnActive(out.issueInvoice) + + stateMachine.Configure(billingentity.InvoiceStatusIssuingSyncFailed). + Permit(triggerRetry, billingentity.InvoiceStatusIssuing) + + // Issued state (final) + stateMachine.Configure(billingentity.InvoiceStatusIssued) + + out.StateMachine = stateMachine + + return out +} + +func (m *InvoiceStateMachine) StatusDetails(ctx context.Context) billingentity.InvoiceStatusDetails { + actions := make([]billingentity.InvoiceAction, 0, 4) + + if ok, err := m.StateMachine.CanFireCtx(ctx, triggerNext); err == nil && ok { + actions = append(actions, billingentity.InvoiceActionAdvance) + } + + if ok, err := m.StateMachine.CanFireCtx(ctx, triggerRetry); err == nil && ok { + actions = append(actions, billingentity.InvoiceActionRetry) + } + + if ok, err := m.StateMachine.CanFireCtx(ctx, triggerApprove); err == nil && ok { + actions = append(actions, billingentity.InvoiceActionApprove) + } + + // TODO[later]: add more actions (void, delete, etc.) + + return billingentity.InvoiceStatusDetails{ + Immutable: !m.Invoice.Status.IsMutable(), + Failed: m.Invoice.Status.IsFailed(), + AvailableActions: actions, + } +} + +func (m *InvoiceStateMachine) ActivateUntilStateStable(ctx context.Context) error { + for { + canFire, err := m.StateMachine.CanFireCtx(ctx, triggerNext) + if err != nil { + return err + } + + // We have reached a state that requires either manual intervention or that is final + if !canFire { + return nil + } + + if err := m.FireAndActivate(ctx, triggerNext); err != nil { + return fmt.Errorf("cannot transition to the next status [current_status=%s]: %w", m.Invoice.Status, err) + } + } +} + +func (m *InvoiceStateMachine) CanFire(ctx context.Context, trigger stateless.Trigger) (bool, error) { + return m.StateMachine.CanFireCtx(ctx, trigger) +} + +// FireAndActivate fires the trigger and activates the new state, if activation fails it automatically +// transitions to the failed state and activates that. +func (m *InvoiceStateMachine) FireAndActivate(ctx context.Context, trigger stateless.Trigger) error { + if err := m.StateMachine.FireCtx(ctx, trigger); err != nil { + return err + } + + err := m.StateMachine.ActivateCtx(ctx) + if err != nil { + // There was an error activating the state, we should trigger a transition to the failed state + activationError := err + + // TODO[later]: depending on the final implementation, we might want to make this a special error + // that signals that the invoice is in an inconsistent state + if err := m.StateMachine.FireCtx(ctx, triggerFailed); err != nil { + return fmt.Errorf("failed to transition to failed state: %w", err) + } + + if err := m.StateMachine.ActivateCtx(ctx); err != nil { + return fmt.Errorf("failed to activate failed state: %w", err) + } + + return activationError + } + + return nil +} + +// validateDraftInvoice validates the draft invoice using the apps referenced in the invoice. +func (m *InvoiceStateMachine) validateDraftInvoice(ctx context.Context) error { + return nil +} + +// syncDraftInvoice syncs the draft invoice with the external system. +func (m *InvoiceStateMachine) syncDraftInvoice(ctx context.Context) error { + return nil +} + +// issueInvoice issues the invoice using the invoicing app +func (m *InvoiceStateMachine) issueInvoice(ctx context.Context) error { + return nil +} + +func (m *InvoiceStateMachine) isAutoAdvanceEnabled() bool { + return m.Invoice.Workflow.Config.Invoicing.AutoAdvance +} + +func (m *InvoiceStateMachine) shouldAutoAdvance() bool { + if !m.isAutoAdvanceEnabled() || m.Invoice.DraftUntil == nil { + return false + } + + return !clock.Now().In(time.UTC).Before(*m.Invoice.DraftUntil) +} + +func boolFn(fn func() bool) func(context.Context, ...any) bool { + return func(context.Context, ...any) bool { + return fn() + } +} + +func not(fn func() bool) func() bool { + return func() bool { + return !fn() + } +} diff --git a/openmeter/billing/service/profile.go b/openmeter/billing/service/profile.go index 87ede35d5..086841a9d 100644 --- a/openmeter/billing/service/profile.go +++ b/openmeter/billing/service/profile.go @@ -19,8 +19,6 @@ import ( var _ billing.ProfileService = (*Service)(nil) func (s *Service) CreateProfile(ctx context.Context, input billing.CreateProfileInput) (*billingentity.Profile, error) { - input = input.WithDefaults() - if err := input.Validate(); err != nil { return nil, billing.ValidationError{ Err: err, diff --git a/openmeter/ent/db/billinginvoice.go b/openmeter/ent/db/billinginvoice.go index d0e45241d..555198de2 100644 --- a/openmeter/ent/db/billinginvoice.go +++ b/openmeter/ent/db/billinginvoice.go @@ -86,6 +86,8 @@ type BillingInvoice struct { VoidedAt *time.Time `json:"voided_at,omitempty"` // IssuedAt holds the value of the "issued_at" field. IssuedAt *time.Time `json:"issued_at,omitempty"` + // DraftUntil holds the value of the "draft_until" field. + DraftUntil *time.Time `json:"draft_until,omitempty"` // Currency holds the value of the "currency" field. Currency currencyx.Code `json:"currency,omitempty"` // DueAt holds the value of the "due_at" field. @@ -215,7 +217,7 @@ func (*BillingInvoice) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case billinginvoice.FieldID, billinginvoice.FieldNamespace, billinginvoice.FieldSupplierAddressCountry, billinginvoice.FieldSupplierAddressPostalCode, billinginvoice.FieldSupplierAddressState, billinginvoice.FieldSupplierAddressCity, billinginvoice.FieldSupplierAddressLine1, billinginvoice.FieldSupplierAddressLine2, billinginvoice.FieldSupplierAddressPhoneNumber, billinginvoice.FieldCustomerAddressCountry, billinginvoice.FieldCustomerAddressPostalCode, billinginvoice.FieldCustomerAddressState, billinginvoice.FieldCustomerAddressCity, billinginvoice.FieldCustomerAddressLine1, billinginvoice.FieldCustomerAddressLine2, billinginvoice.FieldCustomerAddressPhoneNumber, billinginvoice.FieldSupplierName, billinginvoice.FieldSupplierTaxCode, billinginvoice.FieldCustomerName, billinginvoice.FieldCustomerTimezone, billinginvoice.FieldNumber, billinginvoice.FieldType, billinginvoice.FieldDescription, billinginvoice.FieldCustomerID, billinginvoice.FieldSourceBillingProfileID, billinginvoice.FieldCurrency, billinginvoice.FieldStatus, billinginvoice.FieldWorkflowConfigID, billinginvoice.FieldTaxAppID, billinginvoice.FieldInvoicingAppID, billinginvoice.FieldPaymentAppID: values[i] = new(sql.NullString) - case billinginvoice.FieldCreatedAt, billinginvoice.FieldUpdatedAt, billinginvoice.FieldDeletedAt, billinginvoice.FieldVoidedAt, billinginvoice.FieldIssuedAt, billinginvoice.FieldDueAt, billinginvoice.FieldPeriodStart, billinginvoice.FieldPeriodEnd: + case billinginvoice.FieldCreatedAt, billinginvoice.FieldUpdatedAt, billinginvoice.FieldDeletedAt, billinginvoice.FieldVoidedAt, billinginvoice.FieldIssuedAt, billinginvoice.FieldDraftUntil, billinginvoice.FieldDueAt, billinginvoice.FieldPeriodStart, billinginvoice.FieldPeriodEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -441,6 +443,13 @@ func (bi *BillingInvoice) assignValues(columns []string, values []any) error { bi.IssuedAt = new(time.Time) *bi.IssuedAt = value.Time } + case billinginvoice.FieldDraftUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field draft_until", values[i]) + } else if value.Valid { + bi.DraftUntil = new(time.Time) + *bi.DraftUntil = value.Time + } case billinginvoice.FieldCurrency: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field currency", values[i]) @@ -701,6 +710,11 @@ func (bi *BillingInvoice) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := bi.DraftUntil; v != nil { + builder.WriteString("draft_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("currency=") builder.WriteString(fmt.Sprintf("%v", bi.Currency)) builder.WriteString(", ") diff --git a/openmeter/ent/db/billinginvoice/billinginvoice.go b/openmeter/ent/db/billinginvoice/billinginvoice.go index 641b9ddd1..0b41e6dcb 100644 --- a/openmeter/ent/db/billinginvoice/billinginvoice.go +++ b/openmeter/ent/db/billinginvoice/billinginvoice.go @@ -76,6 +76,8 @@ const ( FieldVoidedAt = "voided_at" // FieldIssuedAt holds the string denoting the issued_at field in the database. FieldIssuedAt = "issued_at" + // FieldDraftUntil holds the string denoting the draft_until field in the database. + FieldDraftUntil = "draft_until" // FieldCurrency holds the string denoting the currency field in the database. FieldCurrency = "currency" // FieldDueAt holds the string denoting the due_at field in the database. @@ -194,6 +196,7 @@ var Columns = []string{ FieldSourceBillingProfileID, FieldVoidedAt, FieldIssuedAt, + FieldDraftUntil, FieldCurrency, FieldDueAt, FieldStatus, @@ -255,7 +258,7 @@ func TypeValidator(_type billingentity.InvoiceType) error { // StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. func StatusValidator(s billingentity.InvoiceStatus) error { switch s { - case "gathering", "created", "draft", "draft_sync", "draft_sync_failed", "issuing", "issued", "issuing_failed", "manual_approval_needed": + case "gathering", "draft_created", "draft_manual_approval_needed", "draft_validating", "draft_invalid", "draft_syncing", "draft_sync_failed", "draft_waiting_auto_approval", "draft_ready_to_issue", "issuing_syncing", "issuing_sync_failed", "issued": return nil default: return fmt.Errorf("billinginvoice: invalid enum value for status field: %q", s) @@ -415,6 +418,11 @@ func ByIssuedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIssuedAt, opts...).ToFunc() } +// ByDraftUntil orders the results by the draft_until field. +func ByDraftUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDraftUntil, opts...).ToFunc() +} + // ByCurrency orders the results by the currency field. func ByCurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCurrency, opts...).ToFunc() diff --git a/openmeter/ent/db/billinginvoice/where.go b/openmeter/ent/db/billinginvoice/where.go index 4f5bee357..8c733df1e 100644 --- a/openmeter/ent/db/billinginvoice/where.go +++ b/openmeter/ent/db/billinginvoice/where.go @@ -212,6 +212,11 @@ func IssuedAt(v time.Time) predicate.BillingInvoice { return predicate.BillingInvoice(sql.FieldEQ(FieldIssuedAt, v)) } +// DraftUntil applies equality check predicate on the "draft_until" field. It's identical to DraftUntilEQ. +func DraftUntil(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldEQ(FieldDraftUntil, v)) +} + // Currency applies equality check predicate on the "currency" field. It's identical to CurrencyEQ. func Currency(v currencyx.Code) predicate.BillingInvoice { vc := string(v) @@ -2255,6 +2260,56 @@ func IssuedAtNotNil() predicate.BillingInvoice { return predicate.BillingInvoice(sql.FieldNotNull(FieldIssuedAt)) } +// DraftUntilEQ applies the EQ predicate on the "draft_until" field. +func DraftUntilEQ(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldEQ(FieldDraftUntil, v)) +} + +// DraftUntilNEQ applies the NEQ predicate on the "draft_until" field. +func DraftUntilNEQ(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNEQ(FieldDraftUntil, v)) +} + +// DraftUntilIn applies the In predicate on the "draft_until" field. +func DraftUntilIn(vs ...time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldIn(FieldDraftUntil, vs...)) +} + +// DraftUntilNotIn applies the NotIn predicate on the "draft_until" field. +func DraftUntilNotIn(vs ...time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNotIn(FieldDraftUntil, vs...)) +} + +// DraftUntilGT applies the GT predicate on the "draft_until" field. +func DraftUntilGT(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldGT(FieldDraftUntil, v)) +} + +// DraftUntilGTE applies the GTE predicate on the "draft_until" field. +func DraftUntilGTE(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldGTE(FieldDraftUntil, v)) +} + +// DraftUntilLT applies the LT predicate on the "draft_until" field. +func DraftUntilLT(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldLT(FieldDraftUntil, v)) +} + +// DraftUntilLTE applies the LTE predicate on the "draft_until" field. +func DraftUntilLTE(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldLTE(FieldDraftUntil, v)) +} + +// DraftUntilIsNil applies the IsNil predicate on the "draft_until" field. +func DraftUntilIsNil() predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldIsNull(FieldDraftUntil)) +} + +// DraftUntilNotNil applies the NotNil predicate on the "draft_until" field. +func DraftUntilNotNil() predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNotNull(FieldDraftUntil)) +} + // CurrencyEQ applies the EQ predicate on the "currency" field. func CurrencyEQ(v currencyx.Code) predicate.BillingInvoice { vc := string(v) diff --git a/openmeter/ent/db/billinginvoice_create.go b/openmeter/ent/db/billinginvoice_create.go index d2cb6382c..64bbc3dbf 100644 --- a/openmeter/ent/db/billinginvoice_create.go +++ b/openmeter/ent/db/billinginvoice_create.go @@ -396,6 +396,20 @@ func (bic *BillingInvoiceCreate) SetNillableIssuedAt(t *time.Time) *BillingInvoi return bic } +// SetDraftUntil sets the "draft_until" field. +func (bic *BillingInvoiceCreate) SetDraftUntil(t time.Time) *BillingInvoiceCreate { + bic.mutation.SetDraftUntil(t) + return bic +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (bic *BillingInvoiceCreate) SetNillableDraftUntil(t *time.Time) *BillingInvoiceCreate { + if t != nil { + bic.SetDraftUntil(*t) + } + return bic +} + // SetCurrency sets the "currency" field. func (bic *BillingInvoiceCreate) SetCurrency(c currencyx.Code) *BillingInvoiceCreate { bic.mutation.SetCurrency(c) @@ -859,6 +873,10 @@ func (bic *BillingInvoiceCreate) createSpec() (*BillingInvoice, *sqlgraph.Create _spec.SetField(billinginvoice.FieldIssuedAt, field.TypeTime, value) _node.IssuedAt = &value } + if value, ok := bic.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + _node.DraftUntil = &value + } if value, ok := bic.mutation.Currency(); ok { _spec.SetField(billinginvoice.FieldCurrency, field.TypeString, value) _node.Currency = value @@ -1493,6 +1511,24 @@ func (u *BillingInvoiceUpsert) ClearIssuedAt() *BillingInvoiceUpsert { return u } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsert) SetDraftUntil(v time.Time) *BillingInvoiceUpsert { + u.Set(billinginvoice.FieldDraftUntil, v) + return u +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsert) UpdateDraftUntil() *BillingInvoiceUpsert { + u.SetExcluded(billinginvoice.FieldDraftUntil) + return u +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsert) ClearDraftUntil() *BillingInvoiceUpsert { + u.SetNull(billinginvoice.FieldDraftUntil) + return u +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsert) SetDueAt(v time.Time) *BillingInvoiceUpsert { u.Set(billinginvoice.FieldDueAt, v) @@ -2161,6 +2197,27 @@ func (u *BillingInvoiceUpsertOne) ClearIssuedAt() *BillingInvoiceUpsertOne { }) } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsertOne) SetDraftUntil(v time.Time) *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.SetDraftUntil(v) + }) +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsertOne) UpdateDraftUntil() *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.UpdateDraftUntil() + }) +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsertOne) ClearDraftUntil() *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.ClearDraftUntil() + }) +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsertOne) SetDueAt(v time.Time) *BillingInvoiceUpsertOne { return u.Update(func(s *BillingInvoiceUpsert) { @@ -3009,6 +3066,27 @@ func (u *BillingInvoiceUpsertBulk) ClearIssuedAt() *BillingInvoiceUpsertBulk { }) } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsertBulk) SetDraftUntil(v time.Time) *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.SetDraftUntil(v) + }) +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsertBulk) UpdateDraftUntil() *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.UpdateDraftUntil() + }) +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsertBulk) ClearDraftUntil() *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.ClearDraftUntil() + }) +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsertBulk) SetDueAt(v time.Time) *BillingInvoiceUpsertBulk { return u.Update(func(s *BillingInvoiceUpsert) { diff --git a/openmeter/ent/db/billinginvoice_update.go b/openmeter/ent/db/billinginvoice_update.go index 1c5e57d18..901b39414 100644 --- a/openmeter/ent/db/billinginvoice_update.go +++ b/openmeter/ent/db/billinginvoice_update.go @@ -513,6 +513,26 @@ func (biu *BillingInvoiceUpdate) ClearIssuedAt() *BillingInvoiceUpdate { return biu } +// SetDraftUntil sets the "draft_until" field. +func (biu *BillingInvoiceUpdate) SetDraftUntil(t time.Time) *BillingInvoiceUpdate { + biu.mutation.SetDraftUntil(t) + return biu +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (biu *BillingInvoiceUpdate) SetNillableDraftUntil(t *time.Time) *BillingInvoiceUpdate { + if t != nil { + biu.SetDraftUntil(*t) + } + return biu +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (biu *BillingInvoiceUpdate) ClearDraftUntil() *BillingInvoiceUpdate { + biu.mutation.ClearDraftUntil() + return biu +} + // SetDueAt sets the "due_at" field. func (biu *BillingInvoiceUpdate) SetDueAt(t time.Time) *BillingInvoiceUpdate { biu.mutation.SetDueAt(t) @@ -909,6 +929,12 @@ func (biu *BillingInvoiceUpdate) sqlSave(ctx context.Context) (n int, err error) if biu.mutation.IssuedAtCleared() { _spec.ClearField(billinginvoice.FieldIssuedAt, field.TypeTime) } + if value, ok := biu.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + } + if biu.mutation.DraftUntilCleared() { + _spec.ClearField(billinginvoice.FieldDraftUntil, field.TypeTime) + } if value, ok := biu.mutation.DueAt(); ok { _spec.SetField(billinginvoice.FieldDueAt, field.TypeTime, value) } @@ -1504,6 +1530,26 @@ func (biuo *BillingInvoiceUpdateOne) ClearIssuedAt() *BillingInvoiceUpdateOne { return biuo } +// SetDraftUntil sets the "draft_until" field. +func (biuo *BillingInvoiceUpdateOne) SetDraftUntil(t time.Time) *BillingInvoiceUpdateOne { + biuo.mutation.SetDraftUntil(t) + return biuo +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (biuo *BillingInvoiceUpdateOne) SetNillableDraftUntil(t *time.Time) *BillingInvoiceUpdateOne { + if t != nil { + biuo.SetDraftUntil(*t) + } + return biuo +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (biuo *BillingInvoiceUpdateOne) ClearDraftUntil() *BillingInvoiceUpdateOne { + biuo.mutation.ClearDraftUntil() + return biuo +} + // SetDueAt sets the "due_at" field. func (biuo *BillingInvoiceUpdateOne) SetDueAt(t time.Time) *BillingInvoiceUpdateOne { biuo.mutation.SetDueAt(t) @@ -1930,6 +1976,12 @@ func (biuo *BillingInvoiceUpdateOne) sqlSave(ctx context.Context) (_node *Billin if biuo.mutation.IssuedAtCleared() { _spec.ClearField(billinginvoice.FieldIssuedAt, field.TypeTime) } + if value, ok := biuo.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + } + if biuo.mutation.DraftUntilCleared() { + _spec.ClearField(billinginvoice.FieldDraftUntil, field.TypeTime) + } if value, ok := biuo.mutation.DueAt(); ok { _spec.SetField(billinginvoice.FieldDueAt, field.TypeTime, value) } diff --git a/openmeter/ent/db/migrate/schema.go b/openmeter/ent/db/migrate/schema.go index 0c82b435d..800b7a0eb 100644 --- a/openmeter/ent/db/migrate/schema.go +++ b/openmeter/ent/db/migrate/schema.go @@ -327,9 +327,10 @@ var ( {Name: "description", Type: field.TypeString, Nullable: true}, {Name: "voided_at", Type: field.TypeTime, Nullable: true}, {Name: "issued_at", Type: field.TypeTime, Nullable: true}, + {Name: "draft_until", Type: field.TypeTime, Nullable: true}, {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, {Name: "due_at", Type: field.TypeTime, Nullable: true}, - {Name: "status", Type: field.TypeEnum, Enums: []string{"gathering", "created", "draft", "draft_sync", "draft_sync_failed", "issuing", "issued", "issuing_failed", "manual_approval_needed"}}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"gathering", "draft_created", "draft_manual_approval_needed", "draft_validating", "draft_invalid", "draft_syncing", "draft_sync_failed", "draft_waiting_auto_approval", "draft_ready_to_issue", "issuing_syncing", "issuing_sync_failed", "issued"}}, {Name: "period_start", Type: field.TypeTime, Nullable: true}, {Name: "period_end", Type: field.TypeTime, Nullable: true}, {Name: "tax_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, @@ -347,37 +348,37 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "billing_invoices_apps_billing_invoice_tax_app", - Columns: []*schema.Column{BillingInvoicesColumns[34]}, + Columns: []*schema.Column{BillingInvoicesColumns[35]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_apps_billing_invoice_invoicing_app", - Columns: []*schema.Column{BillingInvoicesColumns[35]}, + Columns: []*schema.Column{BillingInvoicesColumns[36]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_apps_billing_invoice_payment_app", - Columns: []*schema.Column{BillingInvoicesColumns[36]}, + Columns: []*schema.Column{BillingInvoicesColumns[37]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_billing_profiles_billing_invoices", - Columns: []*schema.Column{BillingInvoicesColumns[37]}, + Columns: []*schema.Column{BillingInvoicesColumns[38]}, RefColumns: []*schema.Column{BillingProfilesColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_billing_workflow_configs_billing_invoices", - Columns: []*schema.Column{BillingInvoicesColumns[38]}, + Columns: []*schema.Column{BillingInvoicesColumns[39]}, RefColumns: []*schema.Column{BillingWorkflowConfigsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_customers_billing_invoice", - Columns: []*schema.Column{BillingInvoicesColumns[39]}, + Columns: []*schema.Column{BillingInvoicesColumns[40]}, RefColumns: []*schema.Column{CustomersColumns[0]}, OnDelete: schema.NoAction, }, @@ -401,7 +402,7 @@ var ( { Name: "billinginvoice_namespace_customer_id", Unique: false, - Columns: []*schema.Column{BillingInvoicesColumns[1], BillingInvoicesColumns[39]}, + Columns: []*schema.Column{BillingInvoicesColumns[1], BillingInvoicesColumns[40]}, }, }, } diff --git a/openmeter/ent/db/mutation.go b/openmeter/ent/db/mutation.go index 651341479..67b5215e0 100644 --- a/openmeter/ent/db/mutation.go +++ b/openmeter/ent/db/mutation.go @@ -6192,6 +6192,7 @@ type BillingInvoiceMutation struct { description *string voided_at *time.Time issued_at *time.Time + draft_until *time.Time currency *currencyx.Code due_at *time.Time status *billingentity.InvoiceStatus @@ -7688,6 +7689,55 @@ func (m *BillingInvoiceMutation) ResetIssuedAt() { delete(m.clearedFields, billinginvoice.FieldIssuedAt) } +// SetDraftUntil sets the "draft_until" field. +func (m *BillingInvoiceMutation) SetDraftUntil(t time.Time) { + m.draft_until = &t +} + +// DraftUntil returns the value of the "draft_until" field in the mutation. +func (m *BillingInvoiceMutation) DraftUntil() (r time.Time, exists bool) { + v := m.draft_until + if v == nil { + return + } + return *v, true +} + +// OldDraftUntil returns the old "draft_until" field's value of the BillingInvoice entity. +// If the BillingInvoice object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BillingInvoiceMutation) OldDraftUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDraftUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDraftUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDraftUntil: %w", err) + } + return oldValue.DraftUntil, nil +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (m *BillingInvoiceMutation) ClearDraftUntil() { + m.draft_until = nil + m.clearedFields[billinginvoice.FieldDraftUntil] = struct{}{} +} + +// DraftUntilCleared returns if the "draft_until" field was cleared in this mutation. +func (m *BillingInvoiceMutation) DraftUntilCleared() bool { + _, ok := m.clearedFields[billinginvoice.FieldDraftUntil] + return ok +} + +// ResetDraftUntil resets all changes to the "draft_until" field. +func (m *BillingInvoiceMutation) ResetDraftUntil() { + m.draft_until = nil + delete(m.clearedFields, billinginvoice.FieldDraftUntil) +} + // SetCurrency sets the "currency" field. func (m *BillingInvoiceMutation) SetCurrency(c currencyx.Code) { m.currency = &c @@ -8327,7 +8377,7 @@ func (m *BillingInvoiceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BillingInvoiceMutation) Fields() []string { - fields := make([]string, 0, 39) + fields := make([]string, 0, 40) if m.namespace != nil { fields = append(fields, billinginvoice.FieldNamespace) } @@ -8418,6 +8468,9 @@ func (m *BillingInvoiceMutation) Fields() []string { if m.issued_at != nil { fields = append(fields, billinginvoice.FieldIssuedAt) } + if m.draft_until != nil { + fields = append(fields, billinginvoice.FieldDraftUntil) + } if m.currency != nil { fields = append(fields, billinginvoice.FieldCurrency) } @@ -8513,6 +8566,8 @@ func (m *BillingInvoiceMutation) Field(name string) (ent.Value, bool) { return m.VoidedAt() case billinginvoice.FieldIssuedAt: return m.IssuedAt() + case billinginvoice.FieldDraftUntil: + return m.DraftUntil() case billinginvoice.FieldCurrency: return m.Currency() case billinginvoice.FieldDueAt: @@ -8600,6 +8655,8 @@ func (m *BillingInvoiceMutation) OldField(ctx context.Context, name string) (ent return m.OldVoidedAt(ctx) case billinginvoice.FieldIssuedAt: return m.OldIssuedAt(ctx) + case billinginvoice.FieldDraftUntil: + return m.OldDraftUntil(ctx) case billinginvoice.FieldCurrency: return m.OldCurrency(ctx) case billinginvoice.FieldDueAt: @@ -8837,6 +8894,13 @@ func (m *BillingInvoiceMutation) SetField(name string, value ent.Value) error { } m.SetIssuedAt(v) return nil + case billinginvoice.FieldDraftUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDraftUntil(v) + return nil case billinginvoice.FieldCurrency: v, ok := value.(currencyx.Code) if !ok { @@ -8996,6 +9060,9 @@ func (m *BillingInvoiceMutation) ClearedFields() []string { if m.FieldCleared(billinginvoice.FieldIssuedAt) { fields = append(fields, billinginvoice.FieldIssuedAt) } + if m.FieldCleared(billinginvoice.FieldDraftUntil) { + fields = append(fields, billinginvoice.FieldDraftUntil) + } if m.FieldCleared(billinginvoice.FieldDueAt) { fields = append(fields, billinginvoice.FieldDueAt) } @@ -9085,6 +9152,9 @@ func (m *BillingInvoiceMutation) ClearField(name string) error { case billinginvoice.FieldIssuedAt: m.ClearIssuedAt() return nil + case billinginvoice.FieldDraftUntil: + m.ClearDraftUntil() + return nil case billinginvoice.FieldDueAt: m.ClearDueAt() return nil @@ -9192,6 +9262,9 @@ func (m *BillingInvoiceMutation) ResetField(name string) error { case billinginvoice.FieldIssuedAt: m.ResetIssuedAt() return nil + case billinginvoice.FieldDraftUntil: + m.ResetDraftUntil() + return nil case billinginvoice.FieldCurrency: m.ResetCurrency() return nil diff --git a/openmeter/ent/db/runtime.go b/openmeter/ent/db/runtime.go index 7845c1c5b..82567cc5a 100644 --- a/openmeter/ent/db/runtime.go +++ b/openmeter/ent/db/runtime.go @@ -292,7 +292,7 @@ func init() { // billinginvoice.SourceBillingProfileIDValidator is a validator for the "source_billing_profile_id" field. It is called by the builders before save. billinginvoice.SourceBillingProfileIDValidator = billinginvoiceDescSourceBillingProfileID.Validators[0].(func(string) error) // billinginvoiceDescCurrency is the schema descriptor for currency field. - billinginvoiceDescCurrency := billinginvoiceFields[11].Descriptor() + billinginvoiceDescCurrency := billinginvoiceFields[12].Descriptor() // billinginvoice.CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. billinginvoice.CurrencyValidator = billinginvoiceDescCurrency.Validators[0].(func(string) error) // billinginvoiceDescID is the schema descriptor for id field. diff --git a/openmeter/ent/db/setorclear.go b/openmeter/ent/db/setorclear.go index 4bd88a69a..051c17c5e 100644 --- a/openmeter/ent/db/setorclear.go +++ b/openmeter/ent/db/setorclear.go @@ -545,6 +545,20 @@ func (u *BillingInvoiceUpdateOne) SetOrClearIssuedAt(value *time.Time) *BillingI return u.SetIssuedAt(*value) } +func (u *BillingInvoiceUpdate) SetOrClearDraftUntil(value *time.Time) *BillingInvoiceUpdate { + if value == nil { + return u.ClearDraftUntil() + } + return u.SetDraftUntil(*value) +} + +func (u *BillingInvoiceUpdateOne) SetOrClearDraftUntil(value *time.Time) *BillingInvoiceUpdateOne { + if value == nil { + return u.ClearDraftUntil() + } + return u.SetDraftUntil(*value) +} + func (u *BillingInvoiceUpdate) SetOrClearDueAt(value *time.Time) *BillingInvoiceUpdate { if value == nil { return u.ClearDueAt() diff --git a/openmeter/ent/schema/billing.go b/openmeter/ent/schema/billing.go index 9d41b479b..a558742f1 100644 --- a/openmeter/ent/schema/billing.go +++ b/openmeter/ent/schema/billing.go @@ -402,6 +402,10 @@ func (BillingInvoice) Fields() []ent.Field { Optional(). Nillable(), + field.Time("draft_until"). + Optional(). + Nillable(), + field.String("currency"). GoType(currencyx.Code("")). NotEmpty(). diff --git a/test/billing/customeroverride_test.go b/test/billing/customeroverride_test.go index 423442b4c..7ed4ca599 100644 --- a/test/billing/customeroverride_test.go +++ b/test/billing/customeroverride_test.go @@ -135,7 +135,7 @@ func (s *CustomerOverrideTestSuite) TestDefaultProfileHandling() { wfConfig := customerProfile.Profile.WorkflowConfig require.Equal(t, wfConfig.Collection.Interval, datex.MustParse(t, "PT1H")) - require.Equal(t, *wfConfig.Invoicing.AutoAdvance, false) + require.Equal(t, wfConfig.Invoicing.AutoAdvance, false) require.Equal(t, wfConfig.Invoicing.DraftPeriod, datex.MustParse(t, "PT2H")) require.Equal(t, wfConfig.Invoicing.DueAfter, datex.MustParse(t, "PT3H")) require.Equal(t, wfConfig.Payment.CollectionMethod, billingentity.CollectionMethodSendInvoice) @@ -210,9 +210,9 @@ func (s *CustomerOverrideTestSuite) TestPinnedProfileHandling() { wfConfig := customerProfile.Profile.WorkflowConfig require.Equal(t, wfConfig.Collection.Interval, datex.MustParse(s.T(), "PT1H")) - require.Equal(t, *wfConfig.Invoicing.AutoAdvance, true) - require.Equal(t, wfConfig.Invoicing.DraftPeriod, billingentity.DefaultWorkflowConfig.Invoicing.DraftPeriod) - require.Equal(t, wfConfig.Invoicing.DueAfter, billingentity.DefaultWorkflowConfig.Invoicing.DueAfter) + require.Equal(t, wfConfig.Invoicing.AutoAdvance, true) + require.Equal(t, wfConfig.Invoicing.DraftPeriod, lo.Must(datex.ISOString("P1D").Parse())) + require.Equal(t, wfConfig.Invoicing.DueAfter, lo.Must(datex.ISOString("P1W").Parse())) require.Equal(t, wfConfig.Payment.CollectionMethod, billingentity.CollectionMethodChargeAutomatically) }) } diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 70cdcc8a2..2f317e082 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -2,6 +2,8 @@ package billing_test import ( "context" + "errors" + "fmt" "testing" "time" @@ -15,6 +17,7 @@ import ( billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/datex" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" ) @@ -141,11 +144,11 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { PageSize: 10, }, - Namespace: namespace, - Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpandAll, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Expand: billingentity.InvoiceExpandAll, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, }) require.NoError(s.T(), err) require.Len(s.T(), usdInvoices.Items, 1) @@ -179,7 +182,7 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { }, } // Let's make sure that the workflow config is cloned - require.NotEqual(s.T(), usdInvoice.Workflow.WorkflowConfig.ID, billingProfile.WorkflowConfig.ID) + require.NotEqual(s.T(), usdInvoice.Workflow.Config.ID, billingProfile.WorkflowConfig.ID) require.Equal(s.T(), usdInvoice, billingentity.Invoice{ Namespace: namespace, @@ -188,15 +191,18 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { Type: billingentity.InvoiceTypeStandard, Currency: currencyx.Code(currency.USD), Status: billingentity.InvoiceStatusGathering, + StatusDetails: billingentity.InvoiceStatusDetails{ + AvailableActions: []billingentity.InvoiceAction{}, + }, CreatedAt: usdInvoice.CreatedAt, UpdatedAt: usdInvoice.UpdatedAt, Workflow: &billingentity.InvoiceWorkflow{ - WorkflowConfig: billingentity.WorkflowConfig{ - ID: usdInvoice.Workflow.WorkflowConfig.ID, - CreatedAt: usdInvoice.Workflow.WorkflowConfig.CreatedAt, - UpdatedAt: usdInvoice.Workflow.WorkflowConfig.UpdatedAt, + Config: billingentity.WorkflowConfig{ + ID: usdInvoice.Workflow.Config.ID, + CreatedAt: usdInvoice.Workflow.Config.CreatedAt, + UpdatedAt: usdInvoice.Workflow.Config.UpdatedAt, Timezone: billingProfile.WorkflowConfig.Timezone, Collection: billingProfile.WorkflowConfig.Collection, @@ -217,6 +223,8 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { Supplier: billingProfile.Supplier, Lines: []billingentity.Line{expectedUSDLine}, + + ExpandedFields: billingentity.InvoiceExpandAll, }) require.Len(s.T(), items, 2) @@ -240,11 +248,11 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { PageSize: 10, }, - Namespace: namespace, - Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpandAll, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.HUF)}, + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Expand: billingentity.InvoiceExpandAll, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{currencyx.Code(currency.HUF)}, }) require.NoError(s.T(), err) require.Len(s.T(), hufInvoices.Items, 1) @@ -260,11 +268,11 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { PageSize: 10, }, - Namespace: namespace, - Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpand{}, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, + Namespace: namespace, + Customers: []string{customerEntity.ID}, + Expand: billingentity.InvoiceExpand{}, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, }) require.NoError(s.T(), err) require.Len(s.T(), invoices.Items, 1) @@ -283,11 +291,11 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { Namespace: namespace, Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpand{ + Expand: billingentity.InvoiceExpand{ Workflow: true, }, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, }) require.NoError(s.T(), err) require.Len(s.T(), invoices.Items, 1) @@ -307,12 +315,12 @@ func (s *InvoicingTestSuite) TestPendingLineCreation() { Namespace: namespace, Customers: []string{customerEntity.ID}, - Expand: billing.InvoiceExpand{ + Expand: billingentity.InvoiceExpand{ Workflow: true, WorkflowApps: true, }, - Statuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, - Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, + ExtendedStatuses: []billingentity.InvoiceStatus{billingentity.InvoiceStatusGathering}, + Currencies: []currencyx.Code{currencyx.Code(currency.USD)}, }) require.NoError(s.T(), err) require.Len(s.T(), invoices.Items, 1) @@ -488,7 +496,7 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Then we expect that the gathering invoice is still present, with item2 gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ Invoice: gatheringInvoiceID, - Expand: billing.InvoiceExpandAll, + Expand: billingentity.InvoiceExpandAll, }) require.NoError(s.T(), err) require.Nil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") @@ -532,10 +540,293 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { // Then we expect that the gathering invoice is deleted and empty gatheringInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ Invoice: gatheringInvoiceID, - Expand: billing.InvoiceExpandAll, + Expand: billingentity.InvoiceExpandAll, }) require.NoError(s.T(), err) require.NotNil(s.T(), gatheringInvoice.DeletedAt, "gathering invoice should be present") require.Len(s.T(), gatheringInvoice.Lines, 0, "deleted gathering invoice is empty") }) } + +type draftInvoiceInput struct { + Namespace string + Customer *customerentity.Customer +} + +func (i draftInvoiceInput) Validate() error { + if i.Namespace == "" { + return errors.New("namespace is required") + } + + if err := i.Customer.Validate(); err != nil { + return err + } + + return nil +} + +func (s *InvoicingTestSuite) createDraftInvoice(t *testing.T, ctx context.Context, in draftInvoiceInput) billingentity.Invoice { + namespace := in.Customer.Namespace + + now := time.Now() + invoiceAt := now.Add(-time.Second) + periodEnd := now.Add(-24 * time.Hour) + periodStart := periodEnd.Add(-24 * 30 * time.Hour) + // Given we have a default profile for the namespace + + res, err := s.BillingService.CreateInvoiceLines(ctx, + billing.CreateInvoiceLinesInput{ + Namespace: in.Customer.Namespace, + CustomerID: in.Customer.ID, + Lines: []billingentity.Line{ + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: invoiceAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item1", + Currency: currencyx.Code(currency.USD), + + Metadata: map[string]string{ + "key": "value", + }, + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + }, + }, + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: invoiceAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item2", + Currency: currencyx.Code(currency.USD), + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(200), + Quantity: alpacadecimal.NewFromFloat(3), + }, + }, + }, + }) + + require.NoError(s.T(), err) + require.Len(s.T(), res.Lines, 2) + line1ID := res.Lines[0].ID + line2ID := res.Lines[1].ID + require.NotEmpty(s.T(), line1ID) + require.NotEmpty(s.T(), line2ID) + + invoice, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: in.Customer.ID, + Namespace: in.Customer.Namespace, + }, + AsOf: lo.ToPtr(now), + }) + + require.NoError(t, err) + require.Len(t, invoice, 1) + require.Len(t, invoice[0].Lines, 2) + + return invoice[0] +} + +func (s *InvoicingTestSuite) TestInvoicingFlowInstantIssue() { + cases := []struct { + name string + workflowConfig billingentity.WorkflowConfig + advance func(t *testing.T, ctx context.Context, invoice billingentity.Invoice) + expectedState billingentity.InvoiceStatus + }{ + { + name: "instant issue", + workflowConfig: billingentity.WorkflowConfig{ + Collection: billingentity.CollectionConfig{ + Alignment: billingentity.AlignmentKindSubscription, + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: true, + DraftPeriod: lo.Must(datex.ISOString("PT0S").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billingentity.PaymentConfig{ + CollectionMethod: billingentity.CollectionMethodChargeAutomatically, + }, + }, + advance: func(t *testing.T, ctx context.Context, invoice billingentity.Invoice) { + _, err := s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }) + + require.NoError(s.T(), err) + }, + expectedState: billingentity.InvoiceStatusIssued, + }, + { + name: "draft period bypass with manual approve", + workflowConfig: billingentity.WorkflowConfig{ + Collection: billingentity.CollectionConfig{ + Alignment: billingentity.AlignmentKindSubscription, + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: true, + DraftPeriod: lo.Must(datex.ISOString("PT1H").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billingentity.PaymentConfig{ + CollectionMethod: billingentity.CollectionMethodChargeAutomatically, + }, + }, + advance: func(t *testing.T, ctx context.Context, invoice billingentity.Invoice) { + advancedInvoice, err := s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusDraftWaitingAutoApproval, advancedInvoice.Status) + + // Approve the invoice, should become DraftReadyToIssue + advancedInvoice, err = s.BillingService.ApproveInvoice(ctx, billing.ApproveInvoiceInput{ + ID: advancedInvoice.ID, + Namespace: advancedInvoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusDraftReadyToIssue, advancedInvoice.Status) + + // Advance the invoice, should become Issued + advancedInvoice, err = s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusIssued, advancedInvoice.Status) + }, + expectedState: billingentity.InvoiceStatusIssued, + }, + { + name: "manual approvement flow", + workflowConfig: billingentity.WorkflowConfig{ + Collection: billingentity.CollectionConfig{ + Alignment: billingentity.AlignmentKindSubscription, + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: false, + DraftPeriod: lo.Must(datex.ISOString("PT0H").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billingentity.PaymentConfig{ + CollectionMethod: billingentity.CollectionMethodChargeAutomatically, + }, + }, + advance: func(t *testing.T, ctx context.Context, invoice billingentity.Invoice) { + advancedInvoice, err := s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusDraftManualApprovalNeeded, advancedInvoice.Status) + require.Equal(s.T(), billingentity.InvoiceStatusDetails{ + AvailableActions: []billingentity.InvoiceAction{billingentity.InvoiceActionApprove}, + }, advancedInvoice.StatusDetails) + + // Approve the invoice, should become DraftReadyToIssue + advancedInvoice, err = s.BillingService.ApproveInvoice(ctx, billing.ApproveInvoiceInput{ + ID: advancedInvoice.ID, + Namespace: advancedInvoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusDraftReadyToIssue, advancedInvoice.Status) + + // Advance the invoice, should become Issued + advancedInvoice, err = s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: invoice.Namespace, + }) + + require.NoError(s.T(), err) + require.Equal(s.T(), billingentity.InvoiceStatusIssued, advancedInvoice.Status) + }, + expectedState: billingentity.InvoiceStatusIssued, + }, + } + + ctx := context.Background() + + for i, tc := range cases { + s.T().Run(tc.name, func(t *testing.T) { + namespace := fmt.Sprintf("ns-invoicing-flow-happy-path-%d", i) + + _ = s.installSandboxApp(s.T(), namespace) + + // Given we have a test customer + customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ + Namespace: namespace, + + Customer: customerentity.Customer{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Name: "Test Customer", + }), + PrimaryEmail: lo.ToPtr("test@test.com"), + BillingAddress: &models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + Currency: lo.ToPtr(currencyx.Code(currency.USD)), + }, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), customerEntity) + require.NotEmpty(s.T(), customerEntity.ID) + + // Given we have a billing profile + minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput.Namespace = namespace + minimalCreateProfileInput.WorkflowConfig = tc.workflowConfig + + profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) + + require.NoError(s.T(), err) + require.NotNil(s.T(), profile) + + invoice := s.createDraftInvoice(s.T(), ctx, draftInvoiceInput{ + Namespace: namespace, + Customer: customerEntity, + }) + require.NotNil(s.T(), invoice) + + // Given we have a draft invoice + require.Equal(s.T(), billingentity.InvoiceStatusDraftCreated, invoice.Status) + + // When we advance the invoice + tc.advance(t, ctx, invoice) + + resultingInvoice, err := s.BillingService.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: billingentity.InvoiceID{ + Namespace: namespace, + ID: invoice.ID, + }, + Expand: billingentity.InvoiceExpandAll, + }) + + require.NoError(s.T(), err) + require.NotNil(s.T(), resultingInvoice) + require.Equal(s.T(), tc.expectedState, resultingInvoice.Status) + }) + } +} diff --git a/test/billing/profile_test.go b/test/billing/profile_test.go index a9a715988..09a6c0994 100644 --- a/test/billing/profile_test.go +++ b/test/billing/profile_test.go @@ -23,6 +23,15 @@ var minimalCreateProfileInputTemplate = billing.CreateProfileInput{ WorkflowConfig: billingentity.WorkflowConfig{ Collection: billingentity.CollectionConfig{ Alignment: billingentity.AlignmentKindSubscription, + Interval: lo.Must(datex.ISOString("PT2H").Parse()), + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: true, + DraftPeriod: lo.Must(datex.ISOString("P1D").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + Payment: billingentity.PaymentConfig{ + CollectionMethod: billingentity.CollectionMethodChargeAutomatically, }, }, @@ -172,7 +181,7 @@ func (s *ProfileTestSuite) TestProfileFieldSetting() { Interval: datex.MustParse(t, "PT30M"), }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.ToPtr(true), + AutoAdvance: true, DraftPeriod: datex.MustParse(t, "PT1H"), DueAfter: datex.MustParse(t, "PT24H"), }, @@ -285,7 +294,7 @@ func (s *ProfileTestSuite) TestProfileUpdates() { Interval: datex.MustParse(s.T(), "PT30M"), }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.ToPtr(true), + AutoAdvance: true, DraftPeriod: datex.MustParse(s.T(), "PT1H"), DueAfter: datex.MustParse(s.T(), "PT24H"), }, @@ -349,7 +358,7 @@ func (s *ProfileTestSuite) TestProfileUpdates() { Interval: datex.MustParse(s.T(), "PT30M"), }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.ToPtr(false), + AutoAdvance: true, DraftPeriod: datex.MustParse(s.T(), "PT2H"), DueAfter: datex.MustParse(s.T(), "PT48H"), }, @@ -424,7 +433,7 @@ func (s *ProfileTestSuite) TestProfileUpdates() { Interval: datex.MustParse(t, "PT30M"), }, Invoicing: billingentity.InvoicingConfig{ - AutoAdvance: lo.ToPtr(false), + AutoAdvance: true, DraftPeriod: datex.MustParse(t, "PT2H"), DueAfter: datex.MustParse(t, "PT48H"), }, diff --git a/tools/migrate/migrations/20241030140919_billing-draft-until.down.sql b/tools/migrate/migrations/20241030140919_billing-draft-until.down.sql new file mode 100644 index 000000000..64b5cbc06 --- /dev/null +++ b/tools/migrate/migrations/20241030140919_billing-draft-until.down.sql @@ -0,0 +1,2 @@ +-- reverse: modify "billing_invoices" table +ALTER TABLE "billing_invoices" DROP COLUMN "draft_until"; diff --git a/tools/migrate/migrations/20241030140919_billing-draft-until.up.sql b/tools/migrate/migrations/20241030140919_billing-draft-until.up.sql new file mode 100644 index 000000000..d3a299708 --- /dev/null +++ b/tools/migrate/migrations/20241030140919_billing-draft-until.up.sql @@ -0,0 +1,2 @@ +-- modify "billing_invoices" table +ALTER TABLE "billing_invoices" ADD COLUMN "draft_until" timestamptz NULL; diff --git a/tools/migrate/migrations/atlas.sum b/tools/migrate/migrations/atlas.sum index f05c7d951..a1bf7a501 100644 --- a/tools/migrate/migrations/atlas.sum +++ b/tools/migrate/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:Ujj5cNFyLFHPRjucUB4a7bSLAjf0ssECl/WWVP4ZdFc= +h1:FZpHWubUybsoKb/tU+Cks/LwabBKAl7n/ewcnsM8eCk= 20240826120919_init.down.sql h1:AIbgwwngjkJEYa3yRZsIXQyBa2+qoZttwMXHxXEbHLI= 20240826120919_init.up.sql h1:/hYHWF3Z3dab8SMKnw99ixVktCuJe2bAw5wstCZIEN8= 20240903155435_entitlement-expired-index.down.sql h1:np2xgYs3KQ2z7qPBcobtGNhqWQ3V8NwEP9E5U3TmpSA= @@ -31,3 +31,5 @@ h1:Ujj5cNFyLFHPRjucUB4a7bSLAjf0ssECl/WWVP4ZdFc= 20241021124045_billing-profile.up.sql h1:9wXAjftZcFqmzmnKBeLFe2hv0FxBaSROC3trvs+sSo4= 20241024122007_line-item-fixes.down.sql h1:NlWHQnpzszNx0NlEDVo5MGBOoLdd+9v+x58CuiX9/e8= 20241024122007_line-item-fixes.up.sql h1:B9Efi4lli4HEgVnmizbVE4s2xr2Uiq8gjkLyJ57a5oM= +20241030140919_billing-draft-until.down.sql h1:SusBStUaGJIn8bDLb4SXZMltokzN7zJYqrVwkCEssps= +20241030140919_billing-draft-until.up.sql h1:NGOQufkJREFrK7uAg9hwy63dcfaR5KYNK4howAISLDU=