diff --git a/cmd/commands/cmd_payments.go b/cmd/commands/cmd_payments.go index c8787a6cc15..b9a9eef95a4 100644 --- a/cmd/commands/cmd_payments.go +++ b/cmd/commands/cmd_payments.go @@ -26,6 +26,8 @@ import ( "github.com/lightningnetwork/lnd/routing/route" "github.com/urfave/cli" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( @@ -488,10 +490,47 @@ func routerRPCSendPayment(ctx context.Context, payConn grpc.ClientConnInterface, return routerrpc.NewRouterClient(payConn).SendPaymentV2(ctx, req) } +// maybeValidateMPPParams validates that the MPP parameters are compatible with the +// payment amount. It returns an error if the parameters don't allow the full +// payment amount to be sent. +// maybeValidateMPPParams could be enhanced with additional checks +func maybeValidateMPPParams(amt lnwire.MilliSatoshi, maxParts uint32, + maxShardSize lnwire.MilliSatoshi) error { + // Add validation for negative or zero amounts + if amt <= 0 { + return fmt.Errorf("payment amount must be positive, got %v", amt) + } + + // Early return if MPP is not being used + if maxParts == 0 || maxShardSize == 0 { + return nil + } + + // Validate maxShardSize + if maxShardSize <= 0 { + return fmt.Errorf("max_shard_size must be positive, got %v", maxShardSize) + } + + // Validate maxParts + if maxParts <= 0 { + return fmt.Errorf("max_parts must be positive, got %v", maxParts) + } + + // Calculate maximum possible amount + maxPossibleAmount := lnwire.MilliSatoshi(maxParts) * maxShardSize + + if amt > maxPossibleAmount { + return fmt.Errorf("payment amount %v exceeds maximum possible "+ + "amount %v with max_parts=%v and max_shard_size_msat=%v", + amt, maxPossibleAmount, maxParts, maxShardSize) + } + + return nil +} + func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest, lnConn, paymentConn grpc.ClientConnInterface, callSendPayment SendPaymentFn) error { - ctxc := getContext() lnClient := lnrpc.NewLightningClient(lnConn) @@ -586,7 +625,13 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest, if invoiceAmt != 0 { amt = invoiceAmt } - + err = maybeValidateMPPParams( + lnwire.MilliSatoshi(amt*1000), req.MaxParts, lnwire.MilliSatoshi(req.MaxShardSizeMsat), + ) + if err != nil { + return status.Errorf(codes.InvalidArgument, + "invalid MPP parameters: %v", err) + } // Calculate fee limit based on the determined amount. feeLimit, err = retrieveFeeLimit(ctx, amt) if err != nil { @@ -603,6 +648,13 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest, } } else { var err error + err = maybeValidateMPPParams( + lnwire.MilliSatoshi(req.Amt*1000), req.MaxParts, lnwire.MilliSatoshi(req.MaxShardSizeMsat), + ) + if err != nil { + return status.Errorf(codes.InvalidArgument, + "invalid MPP parameters: %v", err) + } feeLimit, err = retrieveFeeLimit(ctx, req.Amt) if err != nil { return err