diff --git a/go.mod b/go.mod index e7f39ee88..f1d486c00 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.60.0 + github.com/qmuntal/stateless v1.7.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/redis/go-redis/extra/redisotel/v9 v9.5.3 github.com/redis/go-redis/v9 v9.7.0 diff --git a/go.sum b/go.sum index 2182e253b..3b8a263b6 100644 --- a/go.sum +++ b/go.sum @@ -1161,6 +1161,8 @@ github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0 h1:sadMIs github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0/go.mod h1:jgxiZysxFPM+iWKwQwPR+y+Jvo54ARd4EisXxKYpB5c= github.com/pusher/pusher-http-go v4.0.1+incompatible h1:4u6tomPG1WhHaST7Wi9mw83Y+MS/j2EplR2YmDh8Xp4= github.com/pusher/pusher-http-go v4.0.1+incompatible/go.mod h1:XAv1fxRmVTI++2xsfofDhg7whapsLRG/gH/DXbF3a18= +github.com/qmuntal/stateless v1.7.1 h1:dI+BtLHq/nD6u46POkOINTDjY9uE33/4auEzfX3TWp0= +github.com/qmuntal/stateless v1.7.1/go.mod h1:n1HjRBM/cq4uCr3rfUjaMkgeGcd+ykAZwkjLje6jGBM= github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc h1:hK577yxEJ2f5s8w2iy2KimZmgrdAUZUNftE1ESmg2/Q= github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc/go.mod h1:OQt6Zo5B3Zs+C49xul8kcHo+fZ1mCLPvd0LFxiZ2DHc= github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= diff --git a/openmeter/billing/adapter.go b/openmeter/billing/adapter.go index 83cc8f5f1..e08b1b35c 100644 --- a/openmeter/billing/adapter.go +++ b/openmeter/billing/adapter.go @@ -69,4 +69,5 @@ type InvoiceAdapter interface { DeleteInvoices(ctx context.Context, input DeleteInvoicesAdapterInput) error ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error) + UpdateInvoice(ctx context.Context, input UpdateInvoiceAdapterInput) (billingentity.Invoice, error) } diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index c629f5b79..84e45ee8d 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -316,6 +316,139 @@ func (r *adapter) AssociatedLineCounts(ctx context.Context, input billing.Associ }, nil } +func (r *adapter) validateUpdateRequest(req billing.UpdateInvoiceAdapterInput, existing *db.BillingInvoice) error { + if !existing.UpdatedAt.Equal(req.UpdatedAt) { + return billing.ConflictError{ + Entity: billing.EntityInvoice, + ID: req.ID, + Err: fmt.Errorf("invoice has been updated since last read"), + } + } + + if req.Currency != existing.Currency { + return billing.ValidationError{ + Err: fmt.Errorf("currency cannot be changed"), + } + } + + if req.Type != existing.Type { + return billing.ValidationError{ + Err: fmt.Errorf("type cannot be changed"), + } + } + + if req.Customer.CustomerID != existing.CustomerID { + return billing.ValidationError{ + Err: fmt.Errorf("customer cannot be changed"), + } + } + + return nil +} + +func (r *adapter) UpdateInvoice(ctx context.Context, in billing.UpdateInvoiceAdapterInput) (billingentity.Invoice, error) { + existingInvoice, err := r.db.BillingInvoice.Query(). + Where(billinginvoice.ID(in.ID)). + Where(billinginvoice.Namespace(in.Namespace)). + Only(ctx) + if err != nil { + return billingentity.Invoice{}, err + } + + if err := r.validateUpdateRequest(in, existingInvoice); err != nil { + return billingentity.Invoice{}, err + } + + updateQuery := r.db.BillingInvoice.UpdateOneID(in.ID). + Where(billinginvoice.Namespace(in.Namespace)). + SetMetadata(in.Metadata). + // Currency is immutable + SetStatus(in.Status). + // Type is immutable + SetOrClearNumber(in.Number). + SetOrClearDescription(in.Description). + SetOrClearDueAt(in.DueAt). + SetOrClearDraftUntil(in.DraftUntil). + SetOrClearIssuedAt(in.IssuedAt) + + if in.Period != nil { + updateQuery = updateQuery. + SetPeriodStart(in.Period.Start). + SetPeriodEnd(in.Period.End) + } else { + updateQuery = updateQuery. + ClearPeriodStart(). + ClearPeriodEnd() + } + + // Supplier + updateQuery = updateQuery. + SetSupplierName(in.Supplier.Name). + SetOrClearSupplierAddressCountry(in.Supplier.Address.Country). + SetOrClearSupplierAddressPostalCode(in.Supplier.Address.PostalCode). + SetOrClearSupplierAddressCity(in.Supplier.Address.City). + SetOrClearSupplierAddressState(in.Supplier.Address.State). + SetOrClearSupplierAddressLine1(in.Supplier.Address.Line1). + SetOrClearSupplierAddressLine2(in.Supplier.Address.Line2). + SetOrClearSupplierAddressPhoneNumber(in.Supplier.Address.PhoneNumber) + + // Customer + updateQuery = updateQuery. + // CustomerID is immutable + SetCustomerName(in.Customer.Name). + SetOrClearCustomerAddressCountry(in.Customer.BillingAddress.Country). + SetOrClearCustomerAddressPostalCode(in.Customer.BillingAddress.PostalCode). + SetOrClearCustomerAddressCity(in.Customer.BillingAddress.City). + SetOrClearCustomerAddressState(in.Customer.BillingAddress.State). + SetOrClearCustomerAddressLine1(in.Customer.BillingAddress.Line1). + SetOrClearCustomerAddressLine2(in.Customer.BillingAddress.Line2). + SetOrClearCustomerAddressPhoneNumber(in.Customer.BillingAddress.PhoneNumber). + SetOrClearCustomerTimezone(in.Customer.Timezone) + + if in.Workflow != nil { + // Update the workflow config + // TODO: let's have a test for this + updateQuery = updateQuery.SetBillingWorkflowConfig( + mapWorkflowConfigToDB(in.Workflow.WorkflowConfig), + ) + } + + _, err = updateQuery.Save(ctx) + if err != nil { + return billingentity.Invoice{}, err + } + + // TODO: store this as part of the invoice + expandFromIn := expandFromInvoice(in) + // We need to re-fetch the invoice to get the updated edges + + return r.GetInvoiceById(ctx, billing.GetInvoiceByIdInput{ + Invoice: billingentity.InvoiceID{ + ID: in.ID, + Namespace: in.Namespace, + }, + Expand: expandFromIn, + }) +} + +func expandFromInvoice(invoice billingentity.Invoice) billing.InvoiceExpand { + expand := billing.InvoiceExpand{} + + if invoice.Workflow != nil { + expand.Workflow = true + } + + if len(invoice.Lines) > 0 { + expand.Lines = true + } + + if invoice.Workflow.Apps != nil { + expand.WorkflowApps = true + } + + return expand +} + func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (billingentity.Invoice, error) { res := billingentity.Invoice{ ID: invoice.ID, @@ -327,6 +460,7 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) ( Number: invoice.Number, Description: invoice.Description, DueAt: invoice.DueAt, + DraftUntil: invoice.DraftUntil, Supplier: billingentity.SupplierContact{ Name: invoice.SupplierName, Address: models.Address{ diff --git a/openmeter/billing/entity/invoice.go b/openmeter/billing/entity/invoice.go index fbc26f312..acc3d1019 100644 --- a/openmeter/billing/entity/invoice.go +++ b/openmeter/billing/entity/invoice.go @@ -1,9 +1,12 @@ package billingentity import ( + "database/sql/driver" "fmt" + "strings" "time" + "entgo.io/ent/schema/field" "github.com/invopop/gobl/bill" "github.com/invopop/gobl/cbc" "github.com/samber/lo" @@ -42,78 +45,167 @@ func (t InvoiceType) CBCKey() cbc.Key { return cbc.Key(t) } -type InvoiceStatus string +var _ field.ValueScanner = (*InvoiceStatus)(nil) + +type InvoiceStatus struct { + Status string + Sub string +} + +func (s InvoiceStatus) String() string { + if s.Sub == "" { + return s.Status + } + + return fmt.Sprintf("%s_%s", s.Status, s.Sub) +} + +func ParseInvoiceStatus(status string) InvoiceStatus { + parts := strings.SplitN(status, "_", 2) + if len(parts) == 1 { + return InvoiceStatus{ + Status: parts[0], + } + } + + return InvoiceStatus{ + Status: parts[0], + Sub: parts[1], + } +} + +func (s InvoiceStatus) Value() (driver.Value, error) { + return s.String(), nil +} + +func (s *InvoiceStatus) Scan(value interface{}) error { + if value == nil { + return nil + } + + switch v := value.(type) { + case string: + *s = ParseInvoiceStatus(v) + return nil + default: + return fmt.Errorf("unsupported type: %T", value) + } +} const ( + StatusNameGathering = "gathering" + StatusNameDraft = "draft" + StatusNameIssuing = "issuing" + StatusNameIssued = "issued" + + SubStatusNameCreated = "created" + SubStatusNameValidating = "validating" + SubStatusNameInvalid = "invalid" + SubStatusManualApprovePending = "manual_approve_pending" + SubStatusWaitingAutoApproval = "waiting_auto_approval" + SubStatusNameSync = "sync" + SubStatusNameSyncFailed = "sync_failed" +) + +// TODO: get rid of this crap +var ( // InvoiceStatusGathering is the status of an invoice that is gathering the items to be invoiced. - InvoiceStatusGathering InvoiceStatus = "gathering" - // InvoiceStatusPendingCreation is the status of an invoice summarizing the pending items. - InvoiceStatusPendingCreation InvoiceStatus = "pending_creation" - // InvoiceStatusCreated is the status of an invoice that has been created. - InvoiceStatusCreated InvoiceStatus = "created" - // InvoiceStatusValidationFailed is the status of an invoice that failed validation. - InvoiceStatusValidationFailed InvoiceStatus = "validation_failed" + InvoiceStatusGathering InvoiceStatus = InvoiceStatus{ + Status: StatusNameGathering, + } + // InvoiceStatusDraft is the status of an invoice that is in draft both on OpenMeter and the provider side. - InvoiceStatusDraft InvoiceStatus = "draft" - // InvoiceStatusDraftSync is the status of an invoice that is being synced with the provider. - InvoiceStatusDraftSync InvoiceStatus = "draft_sync" - // InvoiceStatusDraftSyncFailed is the status of an invoice that failed to sync with the provider. - InvoiceStatusDraftSyncFailed InvoiceStatus = "draft_sync_failed" + InvoiceStatusDraft InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + } + + InvoiceStatusDraftCreated InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusNameCreated, + } + + InvoiceStatusDraftManualApprovalNeeded InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusManualApprovePending, + } + + InvoiceStatusDraftValidating InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusNameValidating, + } + + InvoiceStatusDraftInvalid InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusNameInvalid, + } + + InvoiceStatusDraftSync InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusNameSync, + } + + InvoiceStatusDraftSyncFailed InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusNameSyncFailed, + } + + InvoiceStatusDraftWaitAutoApproval InvoiceStatus = InvoiceStatus{ + Status: StatusNameDraft, + Sub: SubStatusWaitingAutoApproval, + } + // InvoiceStatusIssuing is the status of an invoice that is being issued. - InvoiceStatusIssuing InvoiceStatus = "issuing" - // InvoiceStatusIssued is the status of an invoice that has been issued both on OpenMeter and provider side. - InvoiceStatusIssued InvoiceStatus = "issued" - // InvoiceStatusIssuingFailed is the status of an invoice that failed to issue on the provider or OpenMeter side. - InvoiceStatusIssuingFailed InvoiceStatus = "issuing_failed" - // InvoiceStatusManualApprovalNeeded is the status of an invoice that needs manual approval. (due to AutoApprove is disabled) - InvoiceStatusManualApprovalNeeded InvoiceStatus = "manual_approval_needed" - // InvoiceStatusDeleted is the status of an invoice that has been deleted (e.g. removed from the database before being issued). - InvoiceStatusDeleted InvoiceStatus = "deleted" -) + InvoiceStatusIssuing InvoiceStatus = InvoiceStatus{ + Status: StatusNameIssuing, + } -// InvoiceImmutableStatuses are the statuses that forbid any changes to the invoice. -var InvoiceImmutableStatuses = []InvoiceStatus{ - InvoiceStatusIssued, - InvoiceStatusDeleted, -} + InvoiceStatusIssuingSyncFailed InvoiceStatus = InvoiceStatus{ + Status: StatusNameIssuing, + Sub: SubStatusNameSyncFailed, + } + + // InvoiceStatusIssued is the status of an invoice that has been issued. + InvoiceStatusIssued InvoiceStatus = InvoiceStatus{ + Status: StatusNameIssued, + } + + validStatuses = []InvoiceStatus{ + InvoiceStatusGathering, + InvoiceStatusDraft, + InvoiceStatusDraftCreated, + InvoiceStatusDraftManualApprovalNeeded, + InvoiceStatusDraftValidating, + InvoiceStatusDraftInvalid, + InvoiceStatusDraftSync, + InvoiceStatusDraftSyncFailed, + InvoiceStatusDraftWaitAutoApproval, + InvoiceStatusIssuing, + InvoiceStatusIssuingSyncFailed, + InvoiceStatusIssued, + } +) func (s InvoiceStatus) Values() []string { return lo.Map( - []InvoiceStatus{ - InvoiceStatusGathering, - InvoiceStatusCreated, - InvoiceStatusDraft, - InvoiceStatusDraftSync, - InvoiceStatusDraftSyncFailed, - InvoiceStatusIssuing, - InvoiceStatusIssued, - InvoiceStatusIssuingFailed, - InvoiceStatusManualApprovalNeeded, - }, + validStatuses, func(item InvoiceStatus, _ int) string { - return string(item) + return item.String() }, ) } -func (s InvoiceStatus) Validate() error { - for _, status := range s.Values() { - if string(s) == status { - return nil - } - } +var immutableStatuses = []string{StatusNameIssued} - return fmt.Errorf("invalid invoice status: %s", s) +func (s InvoiceStatus) IsMutable() bool { + return lo.Contains(immutableStatuses, s.Status) } -func (s InvoiceStatus) IsMutable() bool { - for _, status := range InvoiceImmutableStatuses { - if s == status { - return false - } +func (s InvoiceStatus) Validate() error { + if !lo.Contains(validStatuses, s) { + return fmt.Errorf("invalid invoice status: %s", s) } - return true + return nil } type InvoiceID models.NamespacedID @@ -141,11 +233,12 @@ type Invoice struct { DueAt *time.Time `json:"dueDate,omitempty"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` - VoidedAt *time.Time `json:"voidedAt,omitempty"` - IssuedAt *time.Time `json:"issuedAt,omitempty"` - DeletedAt *time.Time `json:"deletedAt,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + VoidedAt *time.Time `json:"voidedAt,omitempty"` + DraftUntil *time.Time `json:"draftUntil,omitempty"` + IssuedAt *time.Time `json:"issuedAt,omitempty"` + DeletedAt *time.Time `json:"deletedAt,omitempty"` // Customer is either a snapshot of the contact information of the customer at the time of invoice being sent // or the data from the customer entity (draft state) @@ -156,6 +249,31 @@ type Invoice struct { // Line items Lines []Line `json:"lines,omitempty"` + + changed bool `json:"-"` +} + +func (i *Invoice) Calculate() error { + for _, calc := range InvoiceCalculations { + changed, err := calc(i) + if err != nil { + return err + } + + if changed { + i.SetChanged() + } + } + + return nil +} + +func (i Invoice) Changed() bool { + return i.changed +} + +func (i *Invoice) SetChanged() { + i.changed = true } type InvoiceWithValidation struct { diff --git a/openmeter/billing/entity/invoicecalc.go b/openmeter/billing/entity/invoicecalc.go new file mode 100644 index 000000000..937dc2aa0 --- /dev/null +++ b/openmeter/billing/entity/invoicecalc.go @@ -0,0 +1,21 @@ +package billingentity + +type InvoiceCalculation func(*Invoice) (bool, error) + +var InvoiceCalculations = []InvoiceCalculation{ + CalculateDraftUntilIfMissing, +} + +// CalculateDraftUntilIfMissing calculates the draft until date if it is missing. +// If it's set we are not updating it as the user should update that instead of manipulating the +// workflow config. +func CalculateDraftUntilIfMissing(i *Invoice) (bool, error) { + if i.DraftUntil != nil || !i.Workflow.WorkflowConfig.Invoicing.GetAutoAdvance() { + return false, nil + } + + draftUntil, _ := i.Workflow.WorkflowConfig.Invoicing.DraftPeriod.AddTo(i.CreatedAt) + i.DraftUntil = &draftUntil + + return true, nil +} diff --git a/openmeter/billing/entity/profile.go b/openmeter/billing/entity/profile.go index 987c6f6f6..a36befb7e 100644 --- a/openmeter/billing/entity/profile.go +++ b/openmeter/billing/entity/profile.go @@ -131,6 +131,10 @@ func (c *InvoicingConfig) Validate() error { return nil } +func (c *InvoicingConfig) GetAutoAdvance() bool { + return lo.FromPtrOr(c.AutoAdvance, false) +} + type GranularityResolution string const ( diff --git a/openmeter/billing/errors.go b/openmeter/billing/errors.go index c95c70c16..00b04ce9f 100644 --- a/openmeter/billing/errors.go +++ b/openmeter/billing/errors.go @@ -67,6 +67,24 @@ func (e NotFoundError) Unwrap() error { return e.Err } +type ConflictError struct { + ID string + Entity string + Err error +} + +func (e ConflictError) Error() string { + if e.ID == "" { + return e.Err.Error() + } + + return fmt.Sprintf("%s [%s/%s]", e.Err, e.Entity, e.ID) +} + +func (e ConflictError) Unwrap() error { + return e.Err +} + type genericError struct { Err error } diff --git a/openmeter/billing/httpdriver/invoice.go b/openmeter/billing/httpdriver/invoice.go index 97536635e..2ddb5968e 100644 --- a/openmeter/billing/httpdriver/invoice.go +++ b/openmeter/billing/httpdriver/invoice.go @@ -31,6 +31,7 @@ func (h *handler) ListInvoices() ListInvoicesHandler { return ListInvoicesRequest{}, fmt.Errorf("failed to resolve namespace: %w", err) } + // TODO: allow filtering by state and substate too return ListInvoicesRequest{ Namespace: ns, @@ -38,7 +39,7 @@ func (h *handler) ListInvoices() ListInvoicesHandler { Statuses: lo.Map( lo.FromPtrOr(input.Statuses, nil), func(status api.BillingInvoiceStatus, _ int) billingentity.InvoiceStatus { - return billingentity.InvoiceStatus(status) + return billingentity.ParseInvoiceStatus(string(status)) }, ), @@ -114,7 +115,8 @@ func mapInvoiceToAPI(invoice billingentity.Invoice) (api.BillingInvoice, error) IssuedAt: invoice.IssuedAt, VoidedAt: invoice.VoidedAt, DueAt: invoice.DueAt, - Period: mapPeriodToAPI(invoice.Period), + // TODO: add draft until (as soon as typespec freeze is lifted) + Period: mapPeriodToAPI(invoice.Period), Currency: string(invoice.Currency), Customer: mapInvoiceCustomerToAPI(invoice.Customer), @@ -123,7 +125,7 @@ func mapInvoiceToAPI(invoice billingentity.Invoice) (api.BillingInvoice, error) Description: invoice.Description, Metadata: lo.EmptyableToPtr(invoice.Metadata), - Status: api.BillingInvoiceStatus(invoice.Status), + Status: api.BillingInvoiceStatus(invoice.Status.String()), Supplier: mapSupplierContactToAPI(invoice.Supplier), // TODO[OM-942]: This needs to be (re)implemented Totals: api.BillingInvoiceTotals{}, diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index 3536b62a2..e7b4207b3 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -177,3 +177,7 @@ func (i CreateInvoiceInput) Validate() error { type AssociatedLineCountsAdapterResponse struct { Counts map[billingentity.InvoiceID]int64 } + +type AdvanceInvoiceInput = billingentity.InvoiceID + +type UpdateInvoiceAdapterInput = billingentity.Invoice diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index ffe2edddc..55341dcce 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -39,4 +39,9 @@ type InvoiceService interface { ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error) GetInvoiceByID(ctx context.Context, input GetInvoiceByIdInput) (billingentity.Invoice, error) CreateInvoice(ctx context.Context, input CreateInvoiceInput) ([]billingentity.Invoice, error) + // AdvanceInvoice advances the invoice to the next stage, the advancement is stopped until: + // - an error is occured + // - the invoice is in a state that cannot be advanced (e.g. waiting for draft period to expire) + // - the invoice is advanced to the final state + AdvanceInvoice(ctx context.Context, input AdvanceInvoiceInput) (*billingentity.Invoice, error) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index a4502b47a..c6e543201 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -130,7 +130,7 @@ func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoice Profile: customerProfile.Profile, Currency: currency, - Status: billingentity.InvoiceStatusDraft, + Status: billingentity.InvoiceStatusDraftCreated, Type: billingentity.InvoiceTypeStandard, }) @@ -148,6 +148,23 @@ func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoice if err != nil { return nil, fmt.Errorf("associating lines to invoice: %w", err) } + + // TODO: we should have the lines added too for any line-specific calculations + // TODO: let's use the Invoice object as input maybe? + + // let's update any calculated fields on the invoice + err = invoice.Calculate() + if err != nil { + return nil, fmt.Errorf("calculating invoice fields: %w", err) + } + + // let's update the invoice if needed + if invoice.Changed() { + invoice, err = txAdapter.UpdateInvoice(ctx, invoice) + if err != nil { + return nil, fmt.Errorf("updating invoice: %w", err) + } + } } // Let's check if we need to remove any empty gathering invoices (e.g. if they don't have any line items) @@ -180,6 +197,7 @@ func (s *Service) CreateInvoice(ctx context.Context, input billing.CreateInvoice // Assemble output: we need to refetch as the association call will have side-effects of updating // invoice objects (e.g. totals, period, etc.) + // TODO: this would not be required if we allow creating the invoice from in-memory objects out := make([]billingentity.Invoice, 0, len(createdInvoices)) for _, invoiceID := range createdInvoices { invoiceWithLines, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ @@ -262,3 +280,47 @@ func (s *Service) gatherInscopeLines(ctx context.Context, input billing.CreateIn return lines, nil } + +func (s *Service) AdvanceInvoice(ctx context.Context, input billing.AdvanceInvoiceInput) (*billingentity.Invoice, error) { + if err := input.Validate(); err != nil { + return nil, billing.ValidationError{ + Err: err, + } + } + + return entutils.TransactingRepo(ctx, s.adapter, func(ctx context.Context, txAdapter billing.Adapter) (*billingentity.Invoice, error) { + // let's lock the invoice for update, we are using the dedicated call, so that + // edges won't end up having SELECT FOR UPDATE locks + if err := txAdapter.LockInvoicesForUpdate(ctx, billing.LockInvoicesForUpdateInput{ + Namespace: input.Namespace, + InvoiceIDs: []string{input.ID}, + }); err != nil { + return nil, fmt.Errorf("locking invoice: %w", err) + } + + invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: input, + Expand: billing.InvoiceExpandAll, + }) + if err != nil { + return nil, fmt.Errorf("fetching invoice: %w", err) + } + + fsm := NewInvoiceStateMachine(&invoice, s.logger) + + preActivationState, err := fsm.FSM.State(ctx) + if err != nil { + return nil, fmt.Errorf("fetching state: %w", err) + } + + postActivationState, err := fsm.ActiveUntilStateStable(ctx) + if err != nil { + return nil, fmt.Errorf("activating invoice: %w", err) + } + + s.logger.Info("invoice advanced", "invoice", invoice.ID, "from", preActivationState, "to", postActivationState) + + // TODO: save invoice + return &invoice, nil + }) +} diff --git a/openmeter/billing/service/invoicestate.go b/openmeter/billing/service/invoicestate.go new file mode 100644 index 000000000..dfd351f65 --- /dev/null +++ b/openmeter/billing/service/invoicestate.go @@ -0,0 +1,240 @@ +package billingservice + +import ( + "context" + "fmt" + "log/slog" + + billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/qmuntal/stateless" +) + +type InvoiceStateMachine struct { + Invoice *billingentity.Invoice + FSM *stateless.StateMachine + Logger *slog.Logger +} + +const ( + triggerDraftPeriodExpired = "trigger_draft_period_expired" + triggerDraftValidationSucceeded = "trigger_draft_validation_succeeded" + triggerDraftValidationFailed = "trigger_draft_validation_failed" + triggerDraftValidation = "trigger_draft_validation" + triggerDraftSyncComplete = "trigger_draft_sync_complete" + triggerDraftSyncFailed = "trigger_draft_sync_failed" + + triggerIssuingFailed = "trigger_issuing_failed" + triggerIssuingComplete = "trigger_issuing_complete" + + triggerRetry = "trigger_retry" + triggerApproved = "trigger_approved" +) + +// TODO: config! +func NewInvoiceStateMachine(invoice *billingentity.Invoice, logger *slog.Logger) *InvoiceStateMachine { + out := &InvoiceStateMachine{ + Invoice: invoice, + Logger: logger, + } + + // TODO[later]: Tax is not captured here for now, as it would require the DB schema too + // TODO[optimization]: The state machine can be added to sync.Pool to avoid allocations (state is stored in the Invoice entity) + + fsm := stateless.NewStateMachineWithExternalStorage( + func(ctx context.Context) (stateless.State, error) { + return out.Invoice.Status, nil + }, + func(ctx context.Context, state stateless.State) error { + invState, ok := state.(billingentity.InvoiceStatus) + if !ok { + return fmt.Errorf("invalid state type: %v", state) + } + + out.Logger.Info("setting invoice state", "invoice", out.Invoice.ID, "state", invState) + out.Invoice.Status = invState + return nil + }, + stateless.FiringImmediate, + ) + + // Draft states + + // Note: this is for grouping purposes only, the invoice should never be in this state + // and this state must never have any Actions defined or they will be always executed for + // substate transitions. + fsm.Configure(billingentity.InvoiceStatusDraft) + + // TODO: do we need both? + fsm.Configure(billingentity.InvoiceStatusDraftCreated). + Permit(triggerDraftValidation, billingentity.InvoiceStatusDraftValidating). + OnActive(func(ctx context.Context) error { + if err := out.populateWorkflowFields(ctx); err != nil { + return err + } + + return fsm.Fire(triggerDraftValidation) + }) + + fsm.Configure(billingentity.InvoiceStatusDraftValidating). + SubstateOf(billingentity.InvoiceStatusDraft). + Permit(triggerDraftValidationSucceeded, billingentity.InvoiceStatusDraftSync). + Permit(triggerDraftValidationFailed, billingentity.InvoiceStatusDraftInvalid). + OnActive(func(ctx context.Context) error { + if err := out.validateDraftInvoice(ctx); err != nil { + fsm.Fire(triggerDraftValidationFailed, err) + } + + return fsm.Fire(triggerDraftValidationSucceeded) + }) + + fsm.Configure(billingentity.InvoiceStatusDraftInvalid). + SubstateOf(billingentity.InvoiceStatusDraft). + Permit(triggerRetry, billingentity.InvoiceStatusDraftValidating) + + fsm.Configure(billingentity.InvoiceStatusDraftSync). + SubstateOf(billingentity.InvoiceStatusDraft). + Permit(triggerDraftSyncComplete, billingentity.InvoiceStatusDraftManualApprovalNeeded, boolFn(not(out.isAutoAdvanceEnabled))). + Permit(triggerDraftSyncComplete, billingentity.InvoiceStatusDraftWaitAutoApproval, boolFn(out.isAutoAdvanceEnabled)). + Permit(triggerDraftSyncFailed, billingentity.InvoiceStatusDraftSyncFailed). + OnActive(func(ctx context.Context) error { + if err := out.syncDraftInvoice(ctx); err != nil { + return fsm.Fire(triggerDraftSyncFailed, err) + } + + return fsm.Fire(triggerDraftSyncComplete) + }) + + fsm.Configure(billingentity.InvoiceStatusDraftSyncFailed). + SubstateOf(billingentity.InvoiceStatusDraft). + Permit(triggerRetry, billingentity.InvoiceStatusDraftValidating) + + // Automatic and manual approvals + fsm.Configure(billingentity.InvoiceStatusDraftWaitAutoApproval). + SubstateOf(billingentity.InvoiceStatusDraft). + // Manual approval forces the draft invoice to be issued regardless of the review period + Permit(triggerApproved, billingentity.InvoiceStatusIssuing). + OnActive(func(ctx context.Context) error { + if out.shouldAutoAdvance() { + return fsm.Fire(triggerApproved) + } + + return nil + }) + + fsm.Configure(billingentity.InvoiceStatusDraftManualApprovalNeeded). + SubstateOf(billingentity.InvoiceStatusDraft). + Permit(triggerApproved, billingentity.InvoiceStatusIssuing) + + // Issuing state + + fsm.Configure(billingentity.InvoiceStatusIssuing). + Permit(triggerIssuingComplete, billingentity.InvoiceStatusIssued). + Permit(triggerIssuingFailed, billingentity.InvoiceStatusIssuingSyncFailed). + OnActive(func(ctx context.Context) error { + if err := out.issueInvoice(ctx); err != nil { + return fsm.Fire(triggerIssuingFailed, err) + } + + return fsm.Fire(triggerIssuingComplete) + }) + + fsm.Configure(billingentity.InvoiceStatusIssuingSyncFailed). + SubstateOf(billingentity.InvoiceStatusIssuing). + Permit(triggerRetry, billingentity.InvoiceStatusIssuing) + + // Issued state + fsm.Configure(billingentity.InvoiceStatusIssued) + + out.FSM = fsm + + return out +} + +func (m *InvoiceStateMachine) ActiveUntilStateStable(ctx context.Context) (stateless.State, error) { + previousState, err := m.FSM.State(ctx) + if err != nil { + return nil, err + } + + for { + if err := m.FSM.ActivateCtx(ctx); err != nil { + return nil, err + } + + currentState, err := m.FSM.State(ctx) + if err != nil { + return nil, err + } + + m.Logger.Info("invoice advanced", "invoice", m.Invoice.ID, "from", previousState, "to", currentState) + + if currentState == previousState { + return currentState, nil + } + + previousState = currentState + } +} + +func (m *InvoiceStateMachine) populateWorkflowFields(ctx context.Context) error { + // TODO: this should be calculated on save/create + if m.isAutoAdvanceEnabled() { + } + return nil +} + +// validateDraftInvoice validates the draft invoice using the apps referenced in the invoice. +func (m *InvoiceStateMachine) validateDraftInvoice(ctx context.Context) error { + return nil +} + +// syncDraftInvoice syncs the draft invoice with the external system. +func (m *InvoiceStateMachine) syncDraftInvoice(ctx context.Context) error { + return nil +} + +// issueInvoice issues the invoice using the invoicing app +func (m *InvoiceStateMachine) issueInvoice(ctx context.Context) error { + return nil +} + +func (m *InvoiceStateMachine) isAutoAdvanceEnabled() bool { + return m.Invoice.Workflow.WorkflowConfig.Invoicing.GetAutoAdvance() +} + +func (m *InvoiceStateMachine) shouldAutoAdvance() bool { + if !m.isAutoAdvanceEnabled() || m.Invoice.DraftUntil == nil { + return false + } + + return !m.Invoice.DraftUntil.Before(clock.Now()) +} + +func boolFn(fn func() bool) func(context.Context, ...any) bool { + return func(context.Context, ...any) bool { + return fn() + } +} + +func not(fn func() bool) func() bool { + return func() bool { + return !fn() + } +} + +func trigger(fsm *stateless.StateMachine, trigger string) func(context.Context) error { + return func(ctx context.Context) error { + return fsm.FireCtx(ctx, trigger) + } +} + +func invoke(fsm *stateless.StateMachine, f func(context.Context) error, okTrigger, failTrigger string) func(context.Context) error { + return func(ctx context.Context) error { + if err := f(ctx); err != nil { + return fsm.Fire(failTrigger, err) + } + + return fsm.Fire(okTrigger) + } +} diff --git a/openmeter/ent/db/billinginvoice.go b/openmeter/ent/db/billinginvoice.go index d0e45241d..d84fa35f5 100644 --- a/openmeter/ent/db/billinginvoice.go +++ b/openmeter/ent/db/billinginvoice.go @@ -86,6 +86,8 @@ type BillingInvoice struct { VoidedAt *time.Time `json:"voided_at,omitempty"` // IssuedAt holds the value of the "issued_at" field. IssuedAt *time.Time `json:"issued_at,omitempty"` + // DraftUntil holds the value of the "draft_until" field. + DraftUntil *time.Time `json:"draft_until,omitempty"` // Currency holds the value of the "currency" field. Currency currencyx.Code `json:"currency,omitempty"` // DueAt holds the value of the "due_at" field. @@ -213,9 +215,11 @@ func (*BillingInvoice) scanValues(columns []string) ([]any, error) { switch columns[i] { case billinginvoice.FieldMetadata: values[i] = new([]byte) - case billinginvoice.FieldID, billinginvoice.FieldNamespace, billinginvoice.FieldSupplierAddressCountry, billinginvoice.FieldSupplierAddressPostalCode, billinginvoice.FieldSupplierAddressState, billinginvoice.FieldSupplierAddressCity, billinginvoice.FieldSupplierAddressLine1, billinginvoice.FieldSupplierAddressLine2, billinginvoice.FieldSupplierAddressPhoneNumber, billinginvoice.FieldCustomerAddressCountry, billinginvoice.FieldCustomerAddressPostalCode, billinginvoice.FieldCustomerAddressState, billinginvoice.FieldCustomerAddressCity, billinginvoice.FieldCustomerAddressLine1, billinginvoice.FieldCustomerAddressLine2, billinginvoice.FieldCustomerAddressPhoneNumber, billinginvoice.FieldSupplierName, billinginvoice.FieldSupplierTaxCode, billinginvoice.FieldCustomerName, billinginvoice.FieldCustomerTimezone, billinginvoice.FieldNumber, billinginvoice.FieldType, billinginvoice.FieldDescription, billinginvoice.FieldCustomerID, billinginvoice.FieldSourceBillingProfileID, billinginvoice.FieldCurrency, billinginvoice.FieldStatus, billinginvoice.FieldWorkflowConfigID, billinginvoice.FieldTaxAppID, billinginvoice.FieldInvoicingAppID, billinginvoice.FieldPaymentAppID: + case billinginvoice.FieldStatus: + values[i] = new(billingentity.InvoiceStatus) + case billinginvoice.FieldID, billinginvoice.FieldNamespace, billinginvoice.FieldSupplierAddressCountry, billinginvoice.FieldSupplierAddressPostalCode, billinginvoice.FieldSupplierAddressState, billinginvoice.FieldSupplierAddressCity, billinginvoice.FieldSupplierAddressLine1, billinginvoice.FieldSupplierAddressLine2, billinginvoice.FieldSupplierAddressPhoneNumber, billinginvoice.FieldCustomerAddressCountry, billinginvoice.FieldCustomerAddressPostalCode, billinginvoice.FieldCustomerAddressState, billinginvoice.FieldCustomerAddressCity, billinginvoice.FieldCustomerAddressLine1, billinginvoice.FieldCustomerAddressLine2, billinginvoice.FieldCustomerAddressPhoneNumber, billinginvoice.FieldSupplierName, billinginvoice.FieldSupplierTaxCode, billinginvoice.FieldCustomerName, billinginvoice.FieldCustomerTimezone, billinginvoice.FieldNumber, billinginvoice.FieldType, billinginvoice.FieldDescription, billinginvoice.FieldCustomerID, billinginvoice.FieldSourceBillingProfileID, billinginvoice.FieldCurrency, billinginvoice.FieldWorkflowConfigID, billinginvoice.FieldTaxAppID, billinginvoice.FieldInvoicingAppID, billinginvoice.FieldPaymentAppID: values[i] = new(sql.NullString) - case billinginvoice.FieldCreatedAt, billinginvoice.FieldUpdatedAt, billinginvoice.FieldDeletedAt, billinginvoice.FieldVoidedAt, billinginvoice.FieldIssuedAt, billinginvoice.FieldDueAt, billinginvoice.FieldPeriodStart, billinginvoice.FieldPeriodEnd: + case billinginvoice.FieldCreatedAt, billinginvoice.FieldUpdatedAt, billinginvoice.FieldDeletedAt, billinginvoice.FieldVoidedAt, billinginvoice.FieldIssuedAt, billinginvoice.FieldDraftUntil, billinginvoice.FieldDueAt, billinginvoice.FieldPeriodStart, billinginvoice.FieldPeriodEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -441,6 +445,13 @@ func (bi *BillingInvoice) assignValues(columns []string, values []any) error { bi.IssuedAt = new(time.Time) *bi.IssuedAt = value.Time } + case billinginvoice.FieldDraftUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field draft_until", values[i]) + } else if value.Valid { + bi.DraftUntil = new(time.Time) + *bi.DraftUntil = value.Time + } case billinginvoice.FieldCurrency: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field currency", values[i]) @@ -455,10 +466,10 @@ func (bi *BillingInvoice) assignValues(columns []string, values []any) error { *bi.DueAt = value.Time } case billinginvoice.FieldStatus: - if value, ok := values[i].(*sql.NullString); !ok { + if value, ok := values[i].(*billingentity.InvoiceStatus); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) - } else if value.Valid { - bi.Status = billingentity.InvoiceStatus(value.String) + } else if value != nil { + bi.Status = *value } case billinginvoice.FieldWorkflowConfigID: if value, ok := values[i].(*sql.NullString); !ok { @@ -701,6 +712,11 @@ func (bi *BillingInvoice) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := bi.DraftUntil; v != nil { + builder.WriteString("draft_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("currency=") builder.WriteString(fmt.Sprintf("%v", bi.Currency)) builder.WriteString(", ") diff --git a/openmeter/ent/db/billinginvoice/billinginvoice.go b/openmeter/ent/db/billinginvoice/billinginvoice.go index 641b9ddd1..18bbea4e5 100644 --- a/openmeter/ent/db/billinginvoice/billinginvoice.go +++ b/openmeter/ent/db/billinginvoice/billinginvoice.go @@ -76,6 +76,8 @@ const ( FieldVoidedAt = "voided_at" // FieldIssuedAt holds the string denoting the issued_at field in the database. FieldIssuedAt = "issued_at" + // FieldDraftUntil holds the string denoting the draft_until field in the database. + FieldDraftUntil = "draft_until" // FieldCurrency holds the string denoting the currency field in the database. FieldCurrency = "currency" // FieldDueAt holds the string denoting the due_at field in the database. @@ -194,6 +196,7 @@ var Columns = []string{ FieldSourceBillingProfileID, FieldVoidedAt, FieldIssuedAt, + FieldDraftUntil, FieldCurrency, FieldDueAt, FieldStatus, @@ -254,8 +257,8 @@ func TypeValidator(_type billingentity.InvoiceType) error { // StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. func StatusValidator(s billingentity.InvoiceStatus) error { - switch s { - case "gathering", "created", "draft", "draft_sync", "draft_sync_failed", "issuing", "issued", "issuing_failed", "manual_approval_needed": + switch s.String() { + case "gathering", "draft", "draft_created", "draft_manual_approve_pending", "draft_validating", "draft_invalid", "draft_sync", "draft_sync_failed", "draft_waiting_auto_approval", "issuing", "issuing_sync_failed", "issued": return nil default: return fmt.Errorf("billinginvoice: invalid enum value for status field: %q", s) @@ -415,6 +418,11 @@ func ByIssuedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIssuedAt, opts...).ToFunc() } +// ByDraftUntil orders the results by the draft_until field. +func ByDraftUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDraftUntil, opts...).ToFunc() +} + // ByCurrency orders the results by the currency field. func ByCurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCurrency, opts...).ToFunc() diff --git a/openmeter/ent/db/billinginvoice/where.go b/openmeter/ent/db/billinginvoice/where.go index 4f5bee357..cdb2648e6 100644 --- a/openmeter/ent/db/billinginvoice/where.go +++ b/openmeter/ent/db/billinginvoice/where.go @@ -212,6 +212,11 @@ func IssuedAt(v time.Time) predicate.BillingInvoice { return predicate.BillingInvoice(sql.FieldEQ(FieldIssuedAt, v)) } +// DraftUntil applies equality check predicate on the "draft_until" field. It's identical to DraftUntilEQ. +func DraftUntil(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldEQ(FieldDraftUntil, v)) +} + // Currency applies equality check predicate on the "currency" field. It's identical to CurrencyEQ. func Currency(v currencyx.Code) predicate.BillingInvoice { vc := string(v) @@ -2255,6 +2260,56 @@ func IssuedAtNotNil() predicate.BillingInvoice { return predicate.BillingInvoice(sql.FieldNotNull(FieldIssuedAt)) } +// DraftUntilEQ applies the EQ predicate on the "draft_until" field. +func DraftUntilEQ(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldEQ(FieldDraftUntil, v)) +} + +// DraftUntilNEQ applies the NEQ predicate on the "draft_until" field. +func DraftUntilNEQ(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNEQ(FieldDraftUntil, v)) +} + +// DraftUntilIn applies the In predicate on the "draft_until" field. +func DraftUntilIn(vs ...time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldIn(FieldDraftUntil, vs...)) +} + +// DraftUntilNotIn applies the NotIn predicate on the "draft_until" field. +func DraftUntilNotIn(vs ...time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNotIn(FieldDraftUntil, vs...)) +} + +// DraftUntilGT applies the GT predicate on the "draft_until" field. +func DraftUntilGT(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldGT(FieldDraftUntil, v)) +} + +// DraftUntilGTE applies the GTE predicate on the "draft_until" field. +func DraftUntilGTE(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldGTE(FieldDraftUntil, v)) +} + +// DraftUntilLT applies the LT predicate on the "draft_until" field. +func DraftUntilLT(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldLT(FieldDraftUntil, v)) +} + +// DraftUntilLTE applies the LTE predicate on the "draft_until" field. +func DraftUntilLTE(v time.Time) predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldLTE(FieldDraftUntil, v)) +} + +// DraftUntilIsNil applies the IsNil predicate on the "draft_until" field. +func DraftUntilIsNil() predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldIsNull(FieldDraftUntil)) +} + +// DraftUntilNotNil applies the NotNil predicate on the "draft_until" field. +func DraftUntilNotNil() predicate.BillingInvoice { + return predicate.BillingInvoice(sql.FieldNotNull(FieldDraftUntil)) +} + // CurrencyEQ applies the EQ predicate on the "currency" field. func CurrencyEQ(v currencyx.Code) predicate.BillingInvoice { vc := string(v) @@ -2391,32 +2446,22 @@ func DueAtNotNil() predicate.BillingInvoice { // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v billingentity.InvoiceStatus) predicate.BillingInvoice { - vc := v - return predicate.BillingInvoice(sql.FieldEQ(FieldStatus, vc)) + return predicate.BillingInvoice(sql.FieldEQ(FieldStatus, v)) } // StatusNEQ applies the NEQ predicate on the "status" field. func StatusNEQ(v billingentity.InvoiceStatus) predicate.BillingInvoice { - vc := v - return predicate.BillingInvoice(sql.FieldNEQ(FieldStatus, vc)) + return predicate.BillingInvoice(sql.FieldNEQ(FieldStatus, v)) } // StatusIn applies the In predicate on the "status" field. func StatusIn(vs ...billingentity.InvoiceStatus) predicate.BillingInvoice { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.BillingInvoice(sql.FieldIn(FieldStatus, v...)) + return predicate.BillingInvoice(sql.FieldIn(FieldStatus, vs...)) } // StatusNotIn applies the NotIn predicate on the "status" field. func StatusNotIn(vs ...billingentity.InvoiceStatus) predicate.BillingInvoice { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.BillingInvoice(sql.FieldNotIn(FieldStatus, v...)) + return predicate.BillingInvoice(sql.FieldNotIn(FieldStatus, vs...)) } // WorkflowConfigIDEQ applies the EQ predicate on the "workflow_config_id" field. diff --git a/openmeter/ent/db/billinginvoice_create.go b/openmeter/ent/db/billinginvoice_create.go index d2cb6382c..64bbc3dbf 100644 --- a/openmeter/ent/db/billinginvoice_create.go +++ b/openmeter/ent/db/billinginvoice_create.go @@ -396,6 +396,20 @@ func (bic *BillingInvoiceCreate) SetNillableIssuedAt(t *time.Time) *BillingInvoi return bic } +// SetDraftUntil sets the "draft_until" field. +func (bic *BillingInvoiceCreate) SetDraftUntil(t time.Time) *BillingInvoiceCreate { + bic.mutation.SetDraftUntil(t) + return bic +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (bic *BillingInvoiceCreate) SetNillableDraftUntil(t *time.Time) *BillingInvoiceCreate { + if t != nil { + bic.SetDraftUntil(*t) + } + return bic +} + // SetCurrency sets the "currency" field. func (bic *BillingInvoiceCreate) SetCurrency(c currencyx.Code) *BillingInvoiceCreate { bic.mutation.SetCurrency(c) @@ -859,6 +873,10 @@ func (bic *BillingInvoiceCreate) createSpec() (*BillingInvoice, *sqlgraph.Create _spec.SetField(billinginvoice.FieldIssuedAt, field.TypeTime, value) _node.IssuedAt = &value } + if value, ok := bic.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + _node.DraftUntil = &value + } if value, ok := bic.mutation.Currency(); ok { _spec.SetField(billinginvoice.FieldCurrency, field.TypeString, value) _node.Currency = value @@ -1493,6 +1511,24 @@ func (u *BillingInvoiceUpsert) ClearIssuedAt() *BillingInvoiceUpsert { return u } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsert) SetDraftUntil(v time.Time) *BillingInvoiceUpsert { + u.Set(billinginvoice.FieldDraftUntil, v) + return u +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsert) UpdateDraftUntil() *BillingInvoiceUpsert { + u.SetExcluded(billinginvoice.FieldDraftUntil) + return u +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsert) ClearDraftUntil() *BillingInvoiceUpsert { + u.SetNull(billinginvoice.FieldDraftUntil) + return u +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsert) SetDueAt(v time.Time) *BillingInvoiceUpsert { u.Set(billinginvoice.FieldDueAt, v) @@ -2161,6 +2197,27 @@ func (u *BillingInvoiceUpsertOne) ClearIssuedAt() *BillingInvoiceUpsertOne { }) } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsertOne) SetDraftUntil(v time.Time) *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.SetDraftUntil(v) + }) +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsertOne) UpdateDraftUntil() *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.UpdateDraftUntil() + }) +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsertOne) ClearDraftUntil() *BillingInvoiceUpsertOne { + return u.Update(func(s *BillingInvoiceUpsert) { + s.ClearDraftUntil() + }) +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsertOne) SetDueAt(v time.Time) *BillingInvoiceUpsertOne { return u.Update(func(s *BillingInvoiceUpsert) { @@ -3009,6 +3066,27 @@ func (u *BillingInvoiceUpsertBulk) ClearIssuedAt() *BillingInvoiceUpsertBulk { }) } +// SetDraftUntil sets the "draft_until" field. +func (u *BillingInvoiceUpsertBulk) SetDraftUntil(v time.Time) *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.SetDraftUntil(v) + }) +} + +// UpdateDraftUntil sets the "draft_until" field to the value that was provided on create. +func (u *BillingInvoiceUpsertBulk) UpdateDraftUntil() *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.UpdateDraftUntil() + }) +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (u *BillingInvoiceUpsertBulk) ClearDraftUntil() *BillingInvoiceUpsertBulk { + return u.Update(func(s *BillingInvoiceUpsert) { + s.ClearDraftUntil() + }) +} + // SetDueAt sets the "due_at" field. func (u *BillingInvoiceUpsertBulk) SetDueAt(v time.Time) *BillingInvoiceUpsertBulk { return u.Update(func(s *BillingInvoiceUpsert) { diff --git a/openmeter/ent/db/billinginvoice_update.go b/openmeter/ent/db/billinginvoice_update.go index 1c5e57d18..901b39414 100644 --- a/openmeter/ent/db/billinginvoice_update.go +++ b/openmeter/ent/db/billinginvoice_update.go @@ -513,6 +513,26 @@ func (biu *BillingInvoiceUpdate) ClearIssuedAt() *BillingInvoiceUpdate { return biu } +// SetDraftUntil sets the "draft_until" field. +func (biu *BillingInvoiceUpdate) SetDraftUntil(t time.Time) *BillingInvoiceUpdate { + biu.mutation.SetDraftUntil(t) + return biu +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (biu *BillingInvoiceUpdate) SetNillableDraftUntil(t *time.Time) *BillingInvoiceUpdate { + if t != nil { + biu.SetDraftUntil(*t) + } + return biu +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (biu *BillingInvoiceUpdate) ClearDraftUntil() *BillingInvoiceUpdate { + biu.mutation.ClearDraftUntil() + return biu +} + // SetDueAt sets the "due_at" field. func (biu *BillingInvoiceUpdate) SetDueAt(t time.Time) *BillingInvoiceUpdate { biu.mutation.SetDueAt(t) @@ -909,6 +929,12 @@ func (biu *BillingInvoiceUpdate) sqlSave(ctx context.Context) (n int, err error) if biu.mutation.IssuedAtCleared() { _spec.ClearField(billinginvoice.FieldIssuedAt, field.TypeTime) } + if value, ok := biu.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + } + if biu.mutation.DraftUntilCleared() { + _spec.ClearField(billinginvoice.FieldDraftUntil, field.TypeTime) + } if value, ok := biu.mutation.DueAt(); ok { _spec.SetField(billinginvoice.FieldDueAt, field.TypeTime, value) } @@ -1504,6 +1530,26 @@ func (biuo *BillingInvoiceUpdateOne) ClearIssuedAt() *BillingInvoiceUpdateOne { return biuo } +// SetDraftUntil sets the "draft_until" field. +func (biuo *BillingInvoiceUpdateOne) SetDraftUntil(t time.Time) *BillingInvoiceUpdateOne { + biuo.mutation.SetDraftUntil(t) + return biuo +} + +// SetNillableDraftUntil sets the "draft_until" field if the given value is not nil. +func (biuo *BillingInvoiceUpdateOne) SetNillableDraftUntil(t *time.Time) *BillingInvoiceUpdateOne { + if t != nil { + biuo.SetDraftUntil(*t) + } + return biuo +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (biuo *BillingInvoiceUpdateOne) ClearDraftUntil() *BillingInvoiceUpdateOne { + biuo.mutation.ClearDraftUntil() + return biuo +} + // SetDueAt sets the "due_at" field. func (biuo *BillingInvoiceUpdateOne) SetDueAt(t time.Time) *BillingInvoiceUpdateOne { biuo.mutation.SetDueAt(t) @@ -1930,6 +1976,12 @@ func (biuo *BillingInvoiceUpdateOne) sqlSave(ctx context.Context) (_node *Billin if biuo.mutation.IssuedAtCleared() { _spec.ClearField(billinginvoice.FieldIssuedAt, field.TypeTime) } + if value, ok := biuo.mutation.DraftUntil(); ok { + _spec.SetField(billinginvoice.FieldDraftUntil, field.TypeTime, value) + } + if biuo.mutation.DraftUntilCleared() { + _spec.ClearField(billinginvoice.FieldDraftUntil, field.TypeTime) + } if value, ok := biuo.mutation.DueAt(); ok { _spec.SetField(billinginvoice.FieldDueAt, field.TypeTime, value) } diff --git a/openmeter/ent/db/migrate/schema.go b/openmeter/ent/db/migrate/schema.go index 0c82b435d..374ab071c 100644 --- a/openmeter/ent/db/migrate/schema.go +++ b/openmeter/ent/db/migrate/schema.go @@ -327,9 +327,10 @@ var ( {Name: "description", Type: field.TypeString, Nullable: true}, {Name: "voided_at", Type: field.TypeTime, Nullable: true}, {Name: "issued_at", Type: field.TypeTime, Nullable: true}, + {Name: "draft_until", Type: field.TypeTime, Nullable: true}, {Name: "currency", Type: field.TypeString, SchemaType: map[string]string{"postgres": "varchar(3)"}}, {Name: "due_at", Type: field.TypeTime, Nullable: true}, - {Name: "status", Type: field.TypeEnum, Enums: []string{"gathering", "created", "draft", "draft_sync", "draft_sync_failed", "issuing", "issued", "issuing_failed", "manual_approval_needed"}}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"gathering", "draft", "draft_created", "draft_manual_approve_pending", "draft_validating", "draft_invalid", "draft_sync", "draft_sync_failed", "draft_waiting_auto_approval", "issuing", "issuing_sync_failed", "issued"}}, {Name: "period_start", Type: field.TypeTime, Nullable: true}, {Name: "period_end", Type: field.TypeTime, Nullable: true}, {Name: "tax_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "char(26)"}}, @@ -347,37 +348,37 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "billing_invoices_apps_billing_invoice_tax_app", - Columns: []*schema.Column{BillingInvoicesColumns[34]}, + Columns: []*schema.Column{BillingInvoicesColumns[35]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_apps_billing_invoice_invoicing_app", - Columns: []*schema.Column{BillingInvoicesColumns[35]}, + Columns: []*schema.Column{BillingInvoicesColumns[36]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_apps_billing_invoice_payment_app", - Columns: []*schema.Column{BillingInvoicesColumns[36]}, + Columns: []*schema.Column{BillingInvoicesColumns[37]}, RefColumns: []*schema.Column{AppsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_billing_profiles_billing_invoices", - Columns: []*schema.Column{BillingInvoicesColumns[37]}, + Columns: []*schema.Column{BillingInvoicesColumns[38]}, RefColumns: []*schema.Column{BillingProfilesColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_billing_workflow_configs_billing_invoices", - Columns: []*schema.Column{BillingInvoicesColumns[38]}, + Columns: []*schema.Column{BillingInvoicesColumns[39]}, RefColumns: []*schema.Column{BillingWorkflowConfigsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "billing_invoices_customers_billing_invoice", - Columns: []*schema.Column{BillingInvoicesColumns[39]}, + Columns: []*schema.Column{BillingInvoicesColumns[40]}, RefColumns: []*schema.Column{CustomersColumns[0]}, OnDelete: schema.NoAction, }, @@ -401,7 +402,7 @@ var ( { Name: "billinginvoice_namespace_customer_id", Unique: false, - Columns: []*schema.Column{BillingInvoicesColumns[1], BillingInvoicesColumns[39]}, + Columns: []*schema.Column{BillingInvoicesColumns[1], BillingInvoicesColumns[40]}, }, }, } diff --git a/openmeter/ent/db/mutation.go b/openmeter/ent/db/mutation.go index 651341479..67b5215e0 100644 --- a/openmeter/ent/db/mutation.go +++ b/openmeter/ent/db/mutation.go @@ -6192,6 +6192,7 @@ type BillingInvoiceMutation struct { description *string voided_at *time.Time issued_at *time.Time + draft_until *time.Time currency *currencyx.Code due_at *time.Time status *billingentity.InvoiceStatus @@ -7688,6 +7689,55 @@ func (m *BillingInvoiceMutation) ResetIssuedAt() { delete(m.clearedFields, billinginvoice.FieldIssuedAt) } +// SetDraftUntil sets the "draft_until" field. +func (m *BillingInvoiceMutation) SetDraftUntil(t time.Time) { + m.draft_until = &t +} + +// DraftUntil returns the value of the "draft_until" field in the mutation. +func (m *BillingInvoiceMutation) DraftUntil() (r time.Time, exists bool) { + v := m.draft_until + if v == nil { + return + } + return *v, true +} + +// OldDraftUntil returns the old "draft_until" field's value of the BillingInvoice entity. +// If the BillingInvoice object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BillingInvoiceMutation) OldDraftUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDraftUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDraftUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDraftUntil: %w", err) + } + return oldValue.DraftUntil, nil +} + +// ClearDraftUntil clears the value of the "draft_until" field. +func (m *BillingInvoiceMutation) ClearDraftUntil() { + m.draft_until = nil + m.clearedFields[billinginvoice.FieldDraftUntil] = struct{}{} +} + +// DraftUntilCleared returns if the "draft_until" field was cleared in this mutation. +func (m *BillingInvoiceMutation) DraftUntilCleared() bool { + _, ok := m.clearedFields[billinginvoice.FieldDraftUntil] + return ok +} + +// ResetDraftUntil resets all changes to the "draft_until" field. +func (m *BillingInvoiceMutation) ResetDraftUntil() { + m.draft_until = nil + delete(m.clearedFields, billinginvoice.FieldDraftUntil) +} + // SetCurrency sets the "currency" field. func (m *BillingInvoiceMutation) SetCurrency(c currencyx.Code) { m.currency = &c @@ -8327,7 +8377,7 @@ func (m *BillingInvoiceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BillingInvoiceMutation) Fields() []string { - fields := make([]string, 0, 39) + fields := make([]string, 0, 40) if m.namespace != nil { fields = append(fields, billinginvoice.FieldNamespace) } @@ -8418,6 +8468,9 @@ func (m *BillingInvoiceMutation) Fields() []string { if m.issued_at != nil { fields = append(fields, billinginvoice.FieldIssuedAt) } + if m.draft_until != nil { + fields = append(fields, billinginvoice.FieldDraftUntil) + } if m.currency != nil { fields = append(fields, billinginvoice.FieldCurrency) } @@ -8513,6 +8566,8 @@ func (m *BillingInvoiceMutation) Field(name string) (ent.Value, bool) { return m.VoidedAt() case billinginvoice.FieldIssuedAt: return m.IssuedAt() + case billinginvoice.FieldDraftUntil: + return m.DraftUntil() case billinginvoice.FieldCurrency: return m.Currency() case billinginvoice.FieldDueAt: @@ -8600,6 +8655,8 @@ func (m *BillingInvoiceMutation) OldField(ctx context.Context, name string) (ent return m.OldVoidedAt(ctx) case billinginvoice.FieldIssuedAt: return m.OldIssuedAt(ctx) + case billinginvoice.FieldDraftUntil: + return m.OldDraftUntil(ctx) case billinginvoice.FieldCurrency: return m.OldCurrency(ctx) case billinginvoice.FieldDueAt: @@ -8837,6 +8894,13 @@ func (m *BillingInvoiceMutation) SetField(name string, value ent.Value) error { } m.SetIssuedAt(v) return nil + case billinginvoice.FieldDraftUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDraftUntil(v) + return nil case billinginvoice.FieldCurrency: v, ok := value.(currencyx.Code) if !ok { @@ -8996,6 +9060,9 @@ func (m *BillingInvoiceMutation) ClearedFields() []string { if m.FieldCleared(billinginvoice.FieldIssuedAt) { fields = append(fields, billinginvoice.FieldIssuedAt) } + if m.FieldCleared(billinginvoice.FieldDraftUntil) { + fields = append(fields, billinginvoice.FieldDraftUntil) + } if m.FieldCleared(billinginvoice.FieldDueAt) { fields = append(fields, billinginvoice.FieldDueAt) } @@ -9085,6 +9152,9 @@ func (m *BillingInvoiceMutation) ClearField(name string) error { case billinginvoice.FieldIssuedAt: m.ClearIssuedAt() return nil + case billinginvoice.FieldDraftUntil: + m.ClearDraftUntil() + return nil case billinginvoice.FieldDueAt: m.ClearDueAt() return nil @@ -9192,6 +9262,9 @@ func (m *BillingInvoiceMutation) ResetField(name string) error { case billinginvoice.FieldIssuedAt: m.ResetIssuedAt() return nil + case billinginvoice.FieldDraftUntil: + m.ResetDraftUntil() + return nil case billinginvoice.FieldCurrency: m.ResetCurrency() return nil diff --git a/openmeter/ent/db/runtime.go b/openmeter/ent/db/runtime.go index 7845c1c5b..82567cc5a 100644 --- a/openmeter/ent/db/runtime.go +++ b/openmeter/ent/db/runtime.go @@ -292,7 +292,7 @@ func init() { // billinginvoice.SourceBillingProfileIDValidator is a validator for the "source_billing_profile_id" field. It is called by the builders before save. billinginvoice.SourceBillingProfileIDValidator = billinginvoiceDescSourceBillingProfileID.Validators[0].(func(string) error) // billinginvoiceDescCurrency is the schema descriptor for currency field. - billinginvoiceDescCurrency := billinginvoiceFields[11].Descriptor() + billinginvoiceDescCurrency := billinginvoiceFields[12].Descriptor() // billinginvoice.CurrencyValidator is a validator for the "currency" field. It is called by the builders before save. billinginvoice.CurrencyValidator = billinginvoiceDescCurrency.Validators[0].(func(string) error) // billinginvoiceDescID is the schema descriptor for id field. diff --git a/openmeter/ent/db/setorclear.go b/openmeter/ent/db/setorclear.go index 4bd88a69a..051c17c5e 100644 --- a/openmeter/ent/db/setorclear.go +++ b/openmeter/ent/db/setorclear.go @@ -545,6 +545,20 @@ func (u *BillingInvoiceUpdateOne) SetOrClearIssuedAt(value *time.Time) *BillingI return u.SetIssuedAt(*value) } +func (u *BillingInvoiceUpdate) SetOrClearDraftUntil(value *time.Time) *BillingInvoiceUpdate { + if value == nil { + return u.ClearDraftUntil() + } + return u.SetDraftUntil(*value) +} + +func (u *BillingInvoiceUpdateOne) SetOrClearDraftUntil(value *time.Time) *BillingInvoiceUpdateOne { + if value == nil { + return u.ClearDraftUntil() + } + return u.SetDraftUntil(*value) +} + func (u *BillingInvoiceUpdate) SetOrClearDueAt(value *time.Time) *BillingInvoiceUpdate { if value == nil { return u.ClearDueAt() diff --git a/openmeter/ent/schema/billing.go b/openmeter/ent/schema/billing.go index 9d41b479b..b8eb7f03d 100644 --- a/openmeter/ent/schema/billing.go +++ b/openmeter/ent/schema/billing.go @@ -402,6 +402,10 @@ func (BillingInvoice) Fields() []ent.Field { Optional(). Nillable(), + field.Time("draft_until"). + Optional(). + Nillable(), + field.String("currency"). GoType(currencyx.Code("")). NotEmpty(). @@ -415,7 +419,7 @@ func (BillingInvoice) Fields() []ent.Field { Nillable(), field.Enum("status"). - GoType(billingentity.InvoiceStatus("")), + GoType(billingentity.InvoiceStatus{}), // Cloned profile settings field.String("workflow_config_id"). diff --git a/test/billing/invoice_test.go b/test/billing/invoice_test.go index 70cdcc8a2..cb161ca00 100644 --- a/test/billing/invoice_test.go +++ b/test/billing/invoice_test.go @@ -2,6 +2,7 @@ package billing_test import ( "context" + "errors" "testing" "time" @@ -15,6 +16,7 @@ import ( billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity" customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/datex" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/pagination" ) @@ -539,3 +541,161 @@ func (s *InvoicingTestSuite) TestCreateInvoice() { require.Len(s.T(), gatheringInvoice.Lines, 0, "deleted gathering invoice is empty") }) } + +type draftInvoiceInput struct { + Namespace string + Customer *customerentity.Customer +} + +func (i draftInvoiceInput) Validate() error { + if i.Namespace == "" { + return errors.New("namespace is required") + } + + if err := i.Customer.Validate(); err != nil { + return err + } + + return nil +} + +func (s *InvoicingTestSuite) createDraftInvoice(t *testing.T, ctx context.Context, in draftInvoiceInput) billingentity.Invoice { + namespace := in.Customer.Namespace + + now := time.Now() + invoiceAt := now.Add(-time.Second) + periodEnd := now.Add(-24 * time.Hour) + periodStart := periodEnd.Add(-24 * 30 * time.Hour) + // Given we have a default profile for the namespace + + res, err := s.BillingService.CreateInvoiceLines(ctx, + billing.CreateInvoiceLinesInput{ + Namespace: in.Customer.Namespace, + CustomerID: in.Customer.ID, + Lines: []billingentity.Line{ + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: invoiceAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item1", + Currency: currencyx.Code(currency.USD), + + Metadata: map[string]string{ + "key": "value", + }, + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(100), + Quantity: alpacadecimal.NewFromFloat(1), + }, + }, + { + LineBase: billingentity.LineBase{ + Namespace: namespace, + Period: billingentity.Period{Start: periodStart, End: periodEnd}, + + InvoiceAt: invoiceAt, + + Type: billingentity.InvoiceLineTypeManualFee, + + Name: "Test item2", + Currency: currencyx.Code(currency.USD), + }, + ManualFee: &billingentity.ManualFeeLine{ + Price: alpacadecimal.NewFromFloat(200), + Quantity: alpacadecimal.NewFromFloat(3), + }, + }, + }, + }) + + require.NoError(s.T(), err) + require.Len(s.T(), res.Lines, 2) + line1ID := res.Lines[0].ID + line2ID := res.Lines[1].ID + require.NotEmpty(s.T(), line1ID) + require.NotEmpty(s.T(), line2ID) + + invoice, err := s.BillingService.CreateInvoice(ctx, billing.CreateInvoiceInput{ + Customer: customerentity.CustomerID{ + ID: in.Customer.ID, + Namespace: in.Customer.Namespace, + }, + AsOf: lo.ToPtr(now), + }) + + require.NoError(t, err) + require.Len(t, invoice, 1) + require.Len(t, invoice[0].Lines, 2) + + return invoice[0] +} + +func (s *InvoicingTestSuite) TestInvoicingFlowHappyPath() { + namespace := "ns-invoicing-flow-happy-path" + + _ = s.installSandboxApp(s.T(), namespace) + + ctx := context.Background() + + // Given we have a test customer + + customerEntity, err := s.CustomerService.CreateCustomer(ctx, customerentity.CreateCustomerInput{ + Namespace: namespace, + + Customer: customerentity.Customer{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Name: "Test Customer", + }), + PrimaryEmail: lo.ToPtr("test@test.com"), + BillingAddress: &models.Address{ + Country: lo.ToPtr(models.CountryCode("US")), + }, + Currency: lo.ToPtr(currencyx.Code(currency.USD)), + }, + }) + require.NoError(s.T(), err) + require.NotNil(s.T(), customerEntity) + require.NotEmpty(s.T(), customerEntity.ID) + + // Given we have a billing profile + minimalCreateProfileInput := minimalCreateProfileInputTemplate + minimalCreateProfileInput.Namespace = namespace + minimalCreateProfileInput.WorkflowConfig = billingentity.WorkflowConfig{ + Collection: billingentity.CollectionConfig{ + Alignment: billingentity.AlignmentKindSubscription, + }, + Invoicing: billingentity.InvoicingConfig{ + AutoAdvance: lo.ToPtr(true), + DraftPeriod: lo.Must(datex.ISOString("PT0S").Parse()), + DueAfter: lo.Must(datex.ISOString("P1W").Parse()), + }, + } + + profile, err := s.BillingService.CreateProfile(ctx, minimalCreateProfileInput) + + require.NoError(s.T(), err) + require.NotNil(s.T(), profile) + + invoice := s.createDraftInvoice(s.T(), ctx, draftInvoiceInput{ + Namespace: namespace, + Customer: customerEntity, + }) + require.NotNil(s.T(), invoice) + + // Given we have a draft invoice + require.Equal(s.T(), billingentity.InvoiceStatusDraftCreated, invoice.Status) + + // When we advance the invoice + _, err = s.BillingService.AdvanceInvoice(ctx, billing.AdvanceInvoiceInput{ + ID: invoice.ID, + Namespace: namespace, + }) + + require.NoError(s.T(), err) +}