Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: invoice lifecycle #1766

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ require (
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/client_model v0.6.1
github.com/prometheus/common v0.60.0
github.com/qmuntal/stateless v1.7.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/redis/go-redis/extra/redisotel/v9 v9.5.3
github.com/redis/go-redis/v9 v9.7.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,8 @@ github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0 h1:sadMIs
github.com/protocolbuffers/txtpbfmt v0.0.0-20230328191034-3462fbc510c0/go.mod h1:jgxiZysxFPM+iWKwQwPR+y+Jvo54ARd4EisXxKYpB5c=
github.com/pusher/pusher-http-go v4.0.1+incompatible h1:4u6tomPG1WhHaST7Wi9mw83Y+MS/j2EplR2YmDh8Xp4=
github.com/pusher/pusher-http-go v4.0.1+incompatible/go.mod h1:XAv1fxRmVTI++2xsfofDhg7whapsLRG/gH/DXbF3a18=
github.com/qmuntal/stateless v1.7.1 h1:dI+BtLHq/nD6u46POkOINTDjY9uE33/4auEzfX3TWp0=
github.com/qmuntal/stateless v1.7.1/go.mod h1:n1HjRBM/cq4uCr3rfUjaMkgeGcd+ykAZwkjLje6jGBM=
github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc h1:hK577yxEJ2f5s8w2iy2KimZmgrdAUZUNftE1ESmg2/Q=
github.com/quipo/dependencysolver v0.0.0-20170801134659-2b009cb4ddcc/go.mod h1:OQt6Zo5B3Zs+C49xul8kcHo+fZ1mCLPvd0LFxiZ2DHc=
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
Expand Down
1 change: 1 addition & 0 deletions openmeter/billing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ type InvoiceAdapter interface {
DeleteInvoices(ctx context.Context, input DeleteInvoicesAdapterInput) error
ListInvoices(ctx context.Context, input ListInvoicesInput) (ListInvoicesResponse, error)
AssociatedLineCounts(ctx context.Context, input AssociatedLineCountsAdapterInput) (AssociatedLineCountsAdapterResponse, error)
UpdateInvoice(ctx context.Context, input UpdateInvoiceAdapterInput) error
}
134 changes: 130 additions & 4 deletions openmeter/billing/adapter/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"entgo.io/ent/dialect/sql"
"github.com/samber/lo"

"github.com/openmeterio/openmeter/api"
Expand Down Expand Up @@ -144,8 +145,18 @@ func (r *adapter) ListInvoices(ctx context.Context, input billing.ListInvoicesIn
query = query.Where(billinginvoice.IssuedAtLTE(*input.IssuedBefore))
}

if len(input.ExtendedStatuses) > 0 {
query = query.Where(billinginvoice.StatusIn(input.ExtendedStatuses...))
}

if len(input.Statuses) > 0 {
query = query.Where(billinginvoice.StatusIn(input.Statuses...))
query = query.Where(func(s *sql.Selector) {
s.Where(sql.Or(
lo.Map(input.Statuses, func(status string, _ int) *sql.Predicate {
return sql.Like(billinginvoice.FieldStatus, status+"%")
})...,
))
})
}

if len(input.Currencies) > 0 {
Expand Down Expand Up @@ -269,7 +280,7 @@ func (r *adapter) CreateInvoice(ctx context.Context, input billing.CreateInvoice
// Let's add required edges for mapping
newInvoice.Edges.BillingWorkflowConfig = clonedWorkflowConfig

return mapInvoiceFromDB(*newInvoice, billing.InvoiceExpandAll)
return mapInvoiceFromDB(*newInvoice, billingentity.InvoiceExpandAll)
}

type lineCountQueryOut struct {
Expand Down Expand Up @@ -316,7 +327,119 @@ func (r *adapter) AssociatedLineCounts(ctx context.Context, input billing.Associ
}, nil
}

func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (billingentity.Invoice, error) {
func (r *adapter) validateUpdateRequest(req billing.UpdateInvoiceAdapterInput, existing *db.BillingInvoice) error {
if !existing.UpdatedAt.Equal(req.UpdatedAt) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be Before/After?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope this ensures that the update object that was posted is based on the exact version in the database.

return billing.ConflictError{
Entity: billing.EntityInvoice,
ID: req.ID,
Err: fmt.Errorf("invoice has been updated since last read"),
}
}

if req.Currency != existing.Currency {
return billing.ValidationError{
Err: fmt.Errorf("currency cannot be changed"),
}
}

if req.Type != existing.Type {
return billing.ValidationError{
Err: fmt.Errorf("type cannot be changed"),
}
}

if req.Customer.CustomerID != existing.CustomerID {
return billing.ValidationError{
Err: fmt.Errorf("customer cannot be changed"),
}
}

return nil
}

// UpdateInvoice updates the specified invoice. It does not return the new invoice, as we would either
// ways need to re-fetch the invoice to get the updated edges.
func (r *adapter) UpdateInvoice(ctx context.Context, in billing.UpdateInvoiceAdapterInput) error {
existingInvoice, err := r.db.BillingInvoice.Query().
Where(billinginvoice.ID(in.ID)).
Where(billinginvoice.Namespace(in.Namespace)).
Only(ctx)
if err != nil {
return err
}

if err := r.validateUpdateRequest(in, existingInvoice); err != nil {
return err
}

updateQuery := r.db.BillingInvoice.UpdateOneID(in.ID).
Where(billinginvoice.Namespace(in.Namespace)).
SetMetadata(in.Metadata).
// Currency is immutable
SetStatus(in.Status).
// Type is immutable
SetOrClearNumber(in.Number).
SetOrClearDescription(in.Description).
SetOrClearDueAt(in.DueAt).
SetOrClearDraftUntil(in.DraftUntil).
SetOrClearIssuedAt(in.IssuedAt)

if in.Period != nil {
updateQuery = updateQuery.
SetPeriodStart(in.Period.Start).
SetPeriodEnd(in.Period.End)
} else {
updateQuery = updateQuery.
ClearPeriodStart().
ClearPeriodEnd()
}

// Supplier
updateQuery = updateQuery.
SetSupplierName(in.Supplier.Name).
SetOrClearSupplierAddressCountry(in.Supplier.Address.Country).
SetOrClearSupplierAddressPostalCode(in.Supplier.Address.PostalCode).
SetOrClearSupplierAddressCity(in.Supplier.Address.City).
SetOrClearSupplierAddressState(in.Supplier.Address.State).
SetOrClearSupplierAddressLine1(in.Supplier.Address.Line1).
SetOrClearSupplierAddressLine2(in.Supplier.Address.Line2).
SetOrClearSupplierAddressPhoneNumber(in.Supplier.Address.PhoneNumber)

// Customer
updateQuery = updateQuery.
// CustomerID is immutable
SetCustomerName(in.Customer.Name).
SetOrClearCustomerAddressCountry(in.Customer.BillingAddress.Country).
SetOrClearCustomerAddressPostalCode(in.Customer.BillingAddress.PostalCode).
SetOrClearCustomerAddressCity(in.Customer.BillingAddress.City).
SetOrClearCustomerAddressState(in.Customer.BillingAddress.State).
SetOrClearCustomerAddressLine1(in.Customer.BillingAddress.Line1).
SetOrClearCustomerAddressLine2(in.Customer.BillingAddress.Line2).
SetOrClearCustomerAddressPhoneNumber(in.Customer.BillingAddress.PhoneNumber).
SetOrClearCustomerTimezone(in.Customer.Timezone)

_, err = updateQuery.Save(ctx)
if err != nil {
return err
}

if in.ExpandedFields.Workflow {
// Update the workflow config
_, err := r.updateWorkflowConfig(ctx, in.Namespace, in.Workflow.Config.ID, in.Workflow.Config)
if err != nil {
return err
}
}

if in.ExpandedFields.Lines {
// TODO[later]: line updates (with changed flag)
r.logger.Warn("line updates are not yet implemented")
}

return nil
}

func mapInvoiceFromDB(invoice db.BillingInvoice, expand billingentity.InvoiceExpand) (billingentity.Invoice, error) {
res := billingentity.Invoice{
ID: invoice.ID,
Namespace: invoice.Namespace,
Expand All @@ -327,6 +450,7 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (
Number: invoice.Number,
Description: invoice.Description,
DueAt: invoice.DueAt,
DraftUntil: invoice.DraftUntil,
Supplier: billingentity.SupplierContact{
Name: invoice.SupplierName,
Address: models.Address{
Expand Down Expand Up @@ -360,6 +484,8 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (
CreatedAt: invoice.CreatedAt,
UpdatedAt: invoice.UpdatedAt,
DeletedAt: invoice.DeletedAt,

ExpandedFields: expand,
}

if expand.Workflow {
Expand All @@ -369,7 +495,7 @@ func mapInvoiceFromDB(invoice db.BillingInvoice, expand billing.InvoiceExpand) (
}

res.Workflow = &billingentity.InvoiceWorkflow{
WorkflowConfig: workflowConfig,
Config: workflowConfig,
SourceBillingProfileID: invoice.SourceBillingProfileID,

AppReferences: billingentity.ProfileAppReferences{
Expand Down
36 changes: 18 additions & 18 deletions openmeter/billing/adapter/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"

"github.com/samber/lo"

"github.com/openmeterio/openmeter/api"
"github.com/openmeterio/openmeter/openmeter/billing"
billingentity "github.com/openmeterio/openmeter/openmeter/billing/entity"
Expand Down Expand Up @@ -68,7 +66,7 @@ func (a *adapter) createWorkflowConfig(ctx context.Context, ns string, input bil
SetNamespace(ns).
SetCollectionAlignment(input.Collection.Alignment).
SetLineCollectionPeriod(input.Collection.Interval.ISOString()).
SetInvoiceAutoAdvance(*input.Invoicing.AutoAdvance).
SetInvoiceAutoAdvance(input.Invoicing.AutoAdvance).
SetInvoiceDraftPeriod(input.Invoicing.DraftPeriod.ISOString()).
SetInvoiceDueAfter(input.Invoicing.DueAfter.ISOString()).
SetInvoiceCollectionMethod(input.Payment.CollectionMethod).
Expand Down Expand Up @@ -239,15 +237,7 @@ func (a adapter) UpdateProfile(ctx context.Context, input billing.UpdateProfileA
return nil, err
}

updatedWorkflowConfig, err := a.db.BillingWorkflowConfig.UpdateOneID(input.WorkflowConfigID).
Where(billingworkflowconfig.Namespace(targetState.Namespace)).
SetCollectionAlignment(targetState.WorkflowConfig.Collection.Alignment).
SetLineCollectionPeriod(targetState.WorkflowConfig.Collection.Interval.ISOString()).
SetInvoiceAutoAdvance(*targetState.WorkflowConfig.Invoicing.AutoAdvance).
SetInvoiceDraftPeriod(targetState.WorkflowConfig.Invoicing.DraftPeriod.ISOString()).
SetInvoiceDueAfter(targetState.WorkflowConfig.Invoicing.DueAfter.ISOString()).
SetInvoiceCollectionMethod(targetState.WorkflowConfig.Payment.CollectionMethod).
Save(ctx)
updatedWorkflowConfig, err := a.updateWorkflowConfig(ctx, targetState.Namespace, input.WorkflowConfigID, targetState.WorkflowConfig)
if err != nil {
return nil, err
}
Expand All @@ -256,6 +246,18 @@ func (a adapter) UpdateProfile(ctx context.Context, input billing.UpdateProfileA
return mapProfileFromDB(updatedProfile)
}

func (a adapter) updateWorkflowConfig(ctx context.Context, ns string, id string, input billingentity.WorkflowConfig) (*db.BillingWorkflowConfig, error) {
return a.db.BillingWorkflowConfig.UpdateOneID(id).
Where(billingworkflowconfig.Namespace(ns)).
SetCollectionAlignment(input.Collection.Alignment).
SetLineCollectionPeriod(input.Collection.Interval.ISOString()).
SetInvoiceAutoAdvance(input.Invoicing.AutoAdvance).
SetInvoiceDraftPeriod(input.Invoicing.DraftPeriod.ISOString()).
SetInvoiceDueAfter(input.Invoicing.DueAfter.ISOString()).
SetInvoiceCollectionMethod(input.Payment.CollectionMethod).
Save(ctx)
}

func mapProfileFromDB(dbProfile *db.BillingProfile) (*billingentity.BaseProfile, error) {
if dbProfile == nil {
return nil, nil
Expand Down Expand Up @@ -310,11 +312,9 @@ func mapWorkflowConfigToDB(wc billingentity.WorkflowConfig) *db.BillingWorkflowC
UpdatedAt: wc.UpdatedAt,
DeletedAt: wc.DeletedAt,

CollectionAlignment: wc.Collection.Alignment,
LineCollectionPeriod: wc.Collection.Interval.ISOString(),
InvoiceAutoAdvance: lo.FromPtrOr(
wc.Invoicing.AutoAdvance,
*billingentity.DefaultWorkflowConfig.Invoicing.AutoAdvance),
CollectionAlignment: wc.Collection.Alignment,
LineCollectionPeriod: wc.Collection.Interval.ISOString(),
InvoiceAutoAdvance: wc.Invoicing.AutoAdvance,
InvoiceDraftPeriod: wc.Invoicing.DraftPeriod.ISOString(),
InvoiceDueAfter: wc.Invoicing.DueAfter.ISOString(),
InvoiceCollectionMethod: wc.Payment.CollectionMethod,
Expand Down Expand Up @@ -350,7 +350,7 @@ func mapWorkflowConfigFromDB(dbWC *db.BillingWorkflowConfig) (billingentity.Work
},

Invoicing: billingentity.InvoicingConfig{
AutoAdvance: lo.ToPtr(dbWC.InvoiceAutoAdvance),
AutoAdvance: dbWC.InvoiceAutoAdvance,
DraftPeriod: draftPeriod,
DueAfter: dueAfter,
},
Expand Down
Loading