From cc84c1346e716c351a7498c436edcbc94c9cfb69 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Wed, 19 Feb 2025 21:00:43 +0200 Subject: [PATCH] Expose hooks for SIP trunk and rule matching. --- .changeset/sip-match-report.md | 5 + sip/dispatchruleconflictreason_string.go | 23 +++ sip/sip.go | 195 ++++++++++++++++++++--- sip/sip_test.go | 13 +- sip/trunkconflictreason_string.go | 25 +++ sip/trunkfilteredreason_string.go | 26 +++ 6 files changed, 260 insertions(+), 27 deletions(-) create mode 100644 .changeset/sip-match-report.md create mode 100644 sip/dispatchruleconflictreason_string.go create mode 100644 sip/trunkconflictreason_string.go create mode 100644 sip/trunkfilteredreason_string.go diff --git a/.changeset/sip-match-report.md b/.changeset/sip-match-report.md new file mode 100644 index 00000000..6d3bfefc --- /dev/null +++ b/.changeset/sip-match-report.md @@ -0,0 +1,5 @@ +--- +"github.com/livekit/protocol": minor +--- + +Expose hooks for SIP trunk and rule matching. diff --git a/sip/dispatchruleconflictreason_string.go b/sip/dispatchruleconflictreason_string.go new file mode 100644 index 00000000..2ac265f2 --- /dev/null +++ b/sip/dispatchruleconflictreason_string.go @@ -0,0 +1,23 @@ +// Code generated by "stringer -type DispatchRuleConflictReason -trimprefix DispatchRuleConflict"; DO NOT EDIT. + +package sip + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DispatchRuleConflictGeneric-0] +} + +const _DispatchRuleConflictReason_name = "Generic" + +var _DispatchRuleConflictReason_index = [...]uint8{0, 7} + +func (i DispatchRuleConflictReason) String() string { + if i < 0 || i >= DispatchRuleConflictReason(len(_DispatchRuleConflictReason_index)-1) { + return "DispatchRuleConflictReason(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DispatchRuleConflictReason_name[_DispatchRuleConflictReason_index[i]:_DispatchRuleConflictReason_index[i+1]] +} diff --git a/sip/sip.go b/sip/sip.go index b730e21d..3b4a4197 100644 --- a/sip/sip.go +++ b/sip/sip.go @@ -36,6 +36,10 @@ import ( "github.com/livekit/protocol/utils/guid" ) +//go:generate stringer -type TrunkFilteredReason -trimprefix TrunkFiltered +//go:generate stringer -type TrunkConflictReason -trimprefix TrunkConflict +//go:generate stringer -type DispatchRuleConflictReason -trimprefix DispatchRuleConflict + func NewCallID() string { return guid.New(utils.SIPCallPrefix) } @@ -130,14 +134,14 @@ func printID(s string) string { // ValidateDispatchRules checks a set of dispatch rules for conflicts. // // Deprecated: use ValidateDispatchRulesIter -func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo) error { - _, err := ValidateDispatchRulesIter(iters.Slice(rules)) +func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo, opts ...MatchDispatchRuleOpt) error { + _, err := ValidateDispatchRulesIter(iters.Slice(rules), opts...) return err } // ValidateDispatchRulesIter checks a set of dispatch rules for conflicts. -func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (best *livekit.SIPDispatchRuleInfo, _ error) { - it = NewDispatchRuleValidator().ValidateIter(it) +func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo], opts ...MatchDispatchRuleOpt) (best *livekit.SIPDispatchRuleInfo, _ error) { + it = NewDispatchRuleValidator(opts...).ValidateIter(it) defer it.Close() for { r, err := it.Next() @@ -153,8 +157,14 @@ func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (bes return best, nil } -func NewDispatchRuleValidator() *DispatchRuleValidator { +func NewDispatchRuleValidator(opts ...MatchDispatchRuleOpt) *DispatchRuleValidator { + var opt matchDispatchRuleOpts + for _, fnc := range opts { + fnc(&opt) + } + opt.defaults() return &DispatchRuleValidator{ + opt: opt, byRuleKey: make(map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo), } } @@ -166,6 +176,7 @@ type dispatchRuleKey struct { } type DispatchRuleValidator struct { + opt matchDispatchRuleOpts byRuleKey map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo } @@ -193,6 +204,10 @@ func (v *DispatchRuleValidator) Validate(r *livekit.SIPDispatchRuleInfo) error { key := dispatchRuleKey{Pin: pin, Trunk: trunk, Number: normalizeNumber(number)} r2 := v.byRuleKey[key] if r2 != nil { + v.opt.Conflict(r, r2, DispatchRuleConflictGeneric) + if v.opt.AllowConflicts { + continue + } return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q", printID(r.SipDispatchRuleId), printID(r2.SipDispatchRuleId)) } @@ -226,10 +241,10 @@ func (v *dispatchRuleValidatorIter) Close() { // It returns an error if there are conflicting rules. Returns nil if no rules match. // // Deprecated: use MatchDispatchRuleIter -func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { +func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) { // Sorting will do the selection for us. We already filtered out irrelevant ones in MatchDispatchRule and above. // Nil is fine here. We will report "no rules matched" later. - return ValidateDispatchRulesIter(iters.Slice(rules)) + return ValidateDispatchRulesIter(iters.Slice(rules), opts...) } // GetPinAndRoom returns a room name/prefix and the pin for a dispatch rule. Just a convenience wrapper. @@ -270,9 +285,13 @@ func normalizeNumber(num string) string { return num } -func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo) error { +func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo, opt *matchTrunkOpts) error { if len(t.AllowedNumbers) == 0 { if t2 := byInbound[""]; t2 != nil { + opt.Conflict(t, t2, TrunkConflictCalledNumber) + if opt.AllowConflicts { + return nil + } return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s without AllowedNumbers set", printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers)) } @@ -282,6 +301,10 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t * inboundKey := normalizeNumber(num) t2 := byInbound[inboundKey] if t2 != nil { + opt.Conflict(t, t2, TrunkConflictCallingNumber) + if opt.AllowConflicts { + continue + } return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s and AllowedNumber %q", printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers), num) } @@ -294,13 +317,18 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t * // ValidateTrunks checks a set of trunks for conflicts. // // Deprecated: use ValidateTrunksIter -func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error { - return ValidateTrunksIter(iters.Slice(trunks)) +func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo, opts ...MatchTrunkOpt) error { + return ValidateTrunksIter(iters.Slice(trunks), opts...) } // ValidateTrunksIter checks a set of trunks for conflicts. -func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error { +func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], opts ...MatchTrunkOpt) error { defer it.Close() + var opt matchTrunkOpts + for _, fnc := range opts { + fnc(&opt) + } + opt.defaults() byOutboundAndInbound := make(map[string]map[string]*livekit.SIPInboundTrunkInfo) for { t, err := it.Next() @@ -315,7 +343,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error { byInbound = make(map[string]*livekit.SIPInboundTrunkInfo) byOutboundAndInbound[""] = byInbound } - if err := validateTrunkInbound(byInbound, t); err != nil { + if err := validateTrunkInbound(byInbound, t, &opt); err != nil { return err } } else { @@ -325,7 +353,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error { byInbound = make(map[string]*livekit.SIPInboundTrunkInfo) byOutboundAndInbound[num] = byInbound } - if err := validateTrunkInbound(byInbound, t); err != nil { + if err := validateTrunkInbound(byInbound, t, &opt); err != nil { return err } } @@ -410,18 +438,85 @@ func matchNumbers(num string, allowed []string) bool { // Returns nil if no rules matched or an error if there are conflicting definitions. // // Deprecated: use MatchTrunkIter -func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) { - return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called) +func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) { + return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called, opts...) +} + +type matchTrunkOpts struct { + AllowConflicts bool + Filtered TrunkFilteredFunc + Conflict TrunkConflictFunc +} + +func (opt *matchTrunkOpts) defaults() { + if opt.Filtered == nil { + opt.Filtered = func(_ *livekit.SIPInboundTrunkInfo, _ TrunkFilteredReason) {} + } + if opt.Conflict == nil { + opt.Conflict = func(_, _ *livekit.SIPInboundTrunkInfo, _ TrunkConflictReason) {} + } +} + +type MatchTrunkOpt func(opt *matchTrunkOpts) + +type TrunkFilteredReason int + +const ( + TrunkFilteredInvalid = TrunkFilteredReason(iota) + TrunkFilteredCallingNumberDisallowed + TrunkFilteredCalledNumberDisallowed + TrunkFilteredSourceAddressDisallowed +) + +type TrunkFilteredFunc func(tr *livekit.SIPInboundTrunkInfo, reason TrunkFilteredReason) + +// WithTrunkFiltered sets a callback that is called when selected Trunk(s) doesn't match the call. +func WithTrunkFiltered(fnc TrunkFilteredFunc) MatchTrunkOpt { + return func(opt *matchTrunkOpts) { + opt.Filtered = fnc + } +} + +type TrunkConflictReason int + +const ( + TrunkConflictDefault = TrunkConflictReason(iota) + TrunkConflictCalledNumber + TrunkConflictCallingNumber +) + +type TrunkConflictFunc func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason) + +// WithAllowTrunkConflicts allows conflicting Trunk definitions by picking the first match. +// +// Using this option will prevent TrunkConflictFunc from firing, since the first match will be returned immediately. +func WithAllowTrunkConflicts() MatchTrunkOpt { + return func(opt *matchTrunkOpts) { + opt.AllowConflicts = true + } +} + +// WithTrunkConflict sets a callback that is called when two Trunks conflict. +func WithTrunkConflict(fnc TrunkConflictFunc) MatchTrunkOpt { + return func(opt *matchTrunkOpts) { + opt.Conflict = fnc + } } // MatchTrunkIter finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. -func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) { +func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) { defer it.Close() + var opt matchTrunkOpts + for _, fnc := range opts { + fnc(&opt) + } + opt.defaults() var ( - selectedTrunk *livekit.SIPInboundTrunkInfo - defaultTrunk *livekit.SIPInboundTrunkInfo - defaultTrunkCnt int // to error in case there are multiple ones + selectedTrunk *livekit.SIPInboundTrunkInfo + defaultTrunk *livekit.SIPInboundTrunkInfo + defaultTrunkPrev *livekit.SIPInboundTrunkInfo + defaultTrunkCnt int // to error in case there are multiple ones ) calledNorm := normalizeNumber(called) for { @@ -433,13 +528,16 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add } // Do not consider it if number doesn't match. if !matchNumbers(calling, tr.AllowedNumbers) { + opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed) continue } if !matchAddrMasks(srcIP, tr.AllowedAddresses) { + opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed) continue } if len(tr.Numbers) == 0 { // Default/wildcard trunk. + defaultTrunkPrev = defaultTrunk defaultTrunk = tr defaultTrunkCnt++ } else { @@ -447,10 +545,21 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add if normalizeNumber(num) == calledNorm { // Trunk specific to the number. if selectedTrunk != nil { + opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber) + if opt.AllowConflicts { + // This path is unreachable, since we pick the first trunk. Kept for completeness. + continue + } return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", called) } selectedTrunk = tr + if opt.AllowConflicts { + // Pick the first match as soon as it's found. We don't care about conflicts. + return selectedTrunk, nil + } // Keep searching! We want to know if there are any conflicting Trunk definitions. + } else { + opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed) } } } @@ -459,7 +568,10 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add return selectedTrunk, nil } if defaultTrunkCnt > 1 { - return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called) + opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault) + if !opt.AllowConflicts { + return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called) + } } // Could still be nil here. return defaultTrunk, nil @@ -469,14 +581,49 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add // Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs). // // Deprecated: use MatchDispatchRuleIter -func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { - return MatchDispatchRuleIter(trunk, iters.Slice(rules), req) +func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) { + return MatchDispatchRuleIter(trunk, iters.Slice(rules), req, opts...) +} + +type matchDispatchRuleOpts struct { + AllowConflicts bool + Conflict DispatchRuleConflictFunc +} + +func (opt *matchDispatchRuleOpts) defaults() { + if opt.Conflict == nil { + opt.Conflict = func(_, _ *livekit.SIPDispatchRuleInfo, _ DispatchRuleConflictReason) {} + } +} + +type MatchDispatchRuleOpt func(opt *matchDispatchRuleOpts) + +type DispatchRuleConflictReason int + +const ( + DispatchRuleConflictGeneric = DispatchRuleConflictReason(iota) +) + +type DispatchRuleConflictFunc func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason) + +// WithAllowDispatchRuleConflicts allows conflicting DispatchRule definitions. +func WithAllowDispatchRuleConflicts() MatchDispatchRuleOpt { + return func(opt *matchDispatchRuleOpts) { + opt.AllowConflicts = true + } +} + +// WithDispatchRuleConflict sets a callback that is called when two DispatchRules conflict. +func WithDispatchRuleConflict(fnc DispatchRuleConflictFunc) MatchDispatchRuleOpt { + return func(opt *matchDispatchRuleOpts) { + opt.Conflict = fnc + } } // MatchDispatchRuleIter finds the best dispatch rule matching the request parameters. Returns an error if no rule matched. // Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs). -func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { - rules = NewDispatchRuleValidator().ValidateIter(rules) +func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) { + rules = NewDispatchRuleValidator(opts...).ValidateIter(rules) defer rules.Close() // Trunk can still be nil here in case none matched or were defined. // This is still fine, but only in case we'll match exactly one wildcard dispatch rule. diff --git a/sip/sip_test.go b/sip/sip_test.go index 74ef08df..eccacfaf 100644 --- a/sip/sip_test.go +++ b/sip/sip_test.go @@ -16,6 +16,7 @@ package sip import ( "fmt" + "github.com/dennwc/iters" "net/netip" "strconv" "testing" @@ -226,7 +227,9 @@ func TestSIPMatchTrunk(t *testing.T) { srcIP, err = netip.ParseAddr(src) require.NoError(t, err) } - got, err := MatchTrunk(trunks, srcIP, from, to) + got, err := MatchTrunkIter(iters.Slice(trunks), srcIP, from, to, WithTrunkConflict(func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason) { + t.Logf("conflict: %v\n%v\nvs\n%v", reason, t1, t2) + })) if c.expErr { require.Error(t, err) require.Nil(t, got) @@ -547,7 +550,9 @@ func TestSIPMatchDispatchRule(t *testing.T) { name = "no pin" } t.Run(name, func(t *testing.T) { - got, err := MatchDispatchRule(c.trunk.AsInbound(), c.rules, newSIPReqDispatch(pin, c.noPin)) + got, err := MatchDispatchRuleIter(c.trunk.AsInbound(), iters.Slice(c.rules), newSIPReqDispatch(pin, c.noPin), WithDispatchRuleConflict(func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason) { + t.Logf("conflict: %v\n%v\nvs\n%v", reason, r1, r2) + })) if c.expErr { require.Error(t, err) require.Nil(t, got) @@ -575,7 +580,9 @@ func TestSIPValidateDispatchRules(t *testing.T) { r.SipDispatchRuleId = strconv.Itoa(i) } } - err := ValidateDispatchRules(c.rules) + _, err := ValidateDispatchRulesIter(iters.Slice(c.rules), WithDispatchRuleConflict(func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason) { + t.Logf("conflict: %v\n%v\nvs\n%v", reason, r1, r2) + })) if c.invalid { require.Error(t, err) } else { diff --git a/sip/trunkconflictreason_string.go b/sip/trunkconflictreason_string.go new file mode 100644 index 00000000..1802d1a3 --- /dev/null +++ b/sip/trunkconflictreason_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type TrunkConflictReason -trimprefix TrunkConflict"; DO NOT EDIT. + +package sip + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[TrunkConflictDefault-0] + _ = x[TrunkConflictCalledNumber-1] + _ = x[TrunkConflictCallingNumber-2] +} + +const _TrunkConflictReason_name = "DefaultCalledNumberCallingNumber" + +var _TrunkConflictReason_index = [...]uint8{0, 7, 19, 32} + +func (i TrunkConflictReason) String() string { + if i < 0 || i >= TrunkConflictReason(len(_TrunkConflictReason_index)-1) { + return "TrunkConflictReason(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TrunkConflictReason_name[_TrunkConflictReason_index[i]:_TrunkConflictReason_index[i+1]] +} diff --git a/sip/trunkfilteredreason_string.go b/sip/trunkfilteredreason_string.go new file mode 100644 index 00000000..45a39a5d --- /dev/null +++ b/sip/trunkfilteredreason_string.go @@ -0,0 +1,26 @@ +// Code generated by "stringer -type TrunkFilteredReason -trimprefix TrunkFiltered"; DO NOT EDIT. + +package sip + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[TrunkFilteredInvalid-0] + _ = x[TrunkFilteredCallingNumberDisallowed-1] + _ = x[TrunkFilteredCalledNumberDisallowed-2] + _ = x[TrunkFilteredSourceAddressDisallowed-3] +} + +const _TrunkFilteredReason_name = "InvalidCallingNumberDisallowedCalledNumberDisallowedSourceAddressDisallowed" + +var _TrunkFilteredReason_index = [...]uint8{0, 7, 30, 52, 75} + +func (i TrunkFilteredReason) String() string { + if i < 0 || i >= TrunkFilteredReason(len(_TrunkFilteredReason_index)-1) { + return "TrunkFilteredReason(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TrunkFilteredReason_name[_TrunkFilteredReason_index[i]:_TrunkFilteredReason_index[i+1]] +}