Skip to content

Commit

Permalink
Expose hooks for SIP trunk and rule matching.
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc committed Feb 19, 2025
1 parent 0a12e2c commit cc84c13
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 27 deletions.
5 changes: 5 additions & 0 deletions .changeset/sip-match-report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"github.com/livekit/protocol": minor
---

Expose hooks for SIP trunk and rule matching.
23 changes: 23 additions & 0 deletions sip/dispatchruleconflictreason_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

195 changes: 171 additions & 24 deletions sip/sip.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import (
"github.com/livekit/protocol/utils/guid"
)

//go:generate stringer -type TrunkFilteredReason -trimprefix TrunkFiltered
//go:generate stringer -type TrunkConflictReason -trimprefix TrunkConflict
//go:generate stringer -type DispatchRuleConflictReason -trimprefix DispatchRuleConflict

func NewCallID() string {
return guid.New(utils.SIPCallPrefix)
}
Expand Down Expand Up @@ -130,14 +134,14 @@ func printID(s string) string {
// ValidateDispatchRules checks a set of dispatch rules for conflicts.
//
// Deprecated: use ValidateDispatchRulesIter
func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo) error {
_, err := ValidateDispatchRulesIter(iters.Slice(rules))
func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo, opts ...MatchDispatchRuleOpt) error {
_, err := ValidateDispatchRulesIter(iters.Slice(rules), opts...)
return err
}

// ValidateDispatchRulesIter checks a set of dispatch rules for conflicts.
func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (best *livekit.SIPDispatchRuleInfo, _ error) {
it = NewDispatchRuleValidator().ValidateIter(it)
func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo], opts ...MatchDispatchRuleOpt) (best *livekit.SIPDispatchRuleInfo, _ error) {
it = NewDispatchRuleValidator(opts...).ValidateIter(it)
defer it.Close()
for {
r, err := it.Next()
Expand All @@ -153,8 +157,14 @@ func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) (bes
return best, nil
}

func NewDispatchRuleValidator() *DispatchRuleValidator {
func NewDispatchRuleValidator(opts ...MatchDispatchRuleOpt) *DispatchRuleValidator {
var opt matchDispatchRuleOpts
for _, fnc := range opts {
fnc(&opt)
}
opt.defaults()
return &DispatchRuleValidator{
opt: opt,
byRuleKey: make(map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo),
}
}
Expand All @@ -166,6 +176,7 @@ type dispatchRuleKey struct {
}

type DispatchRuleValidator struct {
opt matchDispatchRuleOpts
byRuleKey map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo
}

Expand Down Expand Up @@ -193,6 +204,10 @@ func (v *DispatchRuleValidator) Validate(r *livekit.SIPDispatchRuleInfo) error {
key := dispatchRuleKey{Pin: pin, Trunk: trunk, Number: normalizeNumber(number)}
r2 := v.byRuleKey[key]
if r2 != nil {
v.opt.Conflict(r, r2, DispatchRuleConflictGeneric)
if v.opt.AllowConflicts {
continue
}
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q",
printID(r.SipDispatchRuleId), printID(r2.SipDispatchRuleId))
}
Expand Down Expand Up @@ -226,10 +241,10 @@ func (v *dispatchRuleValidatorIter) Close() {
// It returns an error if there are conflicting rules. Returns nil if no rules match.
//
// Deprecated: use MatchDispatchRuleIter
func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
// Sorting will do the selection for us. We already filtered out irrelevant ones in MatchDispatchRule and above.
// Nil is fine here. We will report "no rules matched" later.
return ValidateDispatchRulesIter(iters.Slice(rules))
return ValidateDispatchRulesIter(iters.Slice(rules), opts...)
}

// GetPinAndRoom returns a room name/prefix and the pin for a dispatch rule. Just a convenience wrapper.
Expand Down Expand Up @@ -270,9 +285,13 @@ func normalizeNumber(num string) string {
return num
}

func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo) error {
func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo, opt *matchTrunkOpts) error {
if len(t.AllowedNumbers) == 0 {
if t2 := byInbound[""]; t2 != nil {
opt.Conflict(t, t2, TrunkConflictCalledNumber)
if opt.AllowConflicts {
return nil
}
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s without AllowedNumbers set",
printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers))
}
Expand All @@ -282,6 +301,10 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *
inboundKey := normalizeNumber(num)
t2 := byInbound[inboundKey]
if t2 != nil {
opt.Conflict(t, t2, TrunkConflictCallingNumber)
if opt.AllowConflicts {
continue
}
return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s and AllowedNumber %q",
printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers), num)
}
Expand All @@ -294,13 +317,18 @@ func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *
// ValidateTrunks checks a set of trunks for conflicts.
//
// Deprecated: use ValidateTrunksIter
func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo) error {
return ValidateTrunksIter(iters.Slice(trunks))
func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo, opts ...MatchTrunkOpt) error {
return ValidateTrunksIter(iters.Slice(trunks), opts...)
}

// ValidateTrunksIter checks a set of trunks for conflicts.
func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], opts ...MatchTrunkOpt) error {
defer it.Close()
var opt matchTrunkOpts
for _, fnc := range opts {
fnc(&opt)
}
opt.defaults()
byOutboundAndInbound := make(map[string]map[string]*livekit.SIPInboundTrunkInfo)
for {
t, err := it.Next()
Expand All @@ -315,7 +343,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
byOutboundAndInbound[""] = byInbound
}
if err := validateTrunkInbound(byInbound, t); err != nil {
if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
return err
}
} else {
Expand All @@ -325,7 +353,7 @@ func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo]) error {
byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
byOutboundAndInbound[num] = byInbound
}
if err := validateTrunkInbound(byInbound, t); err != nil {
if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
return err
}
}
Expand Down Expand Up @@ -410,18 +438,85 @@ func matchNumbers(num string, allowed []string) bool {
// Returns nil if no rules matched or an error if there are conflicting definitions.
//
// Deprecated: use MatchTrunkIter
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called)
func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
return MatchTrunkIter(iters.Slice(trunks), srcIP, calling, called, opts...)
}

type matchTrunkOpts struct {
AllowConflicts bool
Filtered TrunkFilteredFunc
Conflict TrunkConflictFunc
}

func (opt *matchTrunkOpts) defaults() {
if opt.Filtered == nil {
opt.Filtered = func(_ *livekit.SIPInboundTrunkInfo, _ TrunkFilteredReason) {}
}
if opt.Conflict == nil {
opt.Conflict = func(_, _ *livekit.SIPInboundTrunkInfo, _ TrunkConflictReason) {}
}
}

type MatchTrunkOpt func(opt *matchTrunkOpts)

type TrunkFilteredReason int

const (
TrunkFilteredInvalid = TrunkFilteredReason(iota)
TrunkFilteredCallingNumberDisallowed
TrunkFilteredCalledNumberDisallowed
TrunkFilteredSourceAddressDisallowed
)

type TrunkFilteredFunc func(tr *livekit.SIPInboundTrunkInfo, reason TrunkFilteredReason)

// WithTrunkFiltered sets a callback that is called when selected Trunk(s) doesn't match the call.
func WithTrunkFiltered(fnc TrunkFilteredFunc) MatchTrunkOpt {
return func(opt *matchTrunkOpts) {
opt.Filtered = fnc
}
}

type TrunkConflictReason int

const (
TrunkConflictDefault = TrunkConflictReason(iota)
TrunkConflictCalledNumber
TrunkConflictCallingNumber
)

type TrunkConflictFunc func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason)

// WithAllowTrunkConflicts allows conflicting Trunk definitions by picking the first match.
//
// Using this option will prevent TrunkConflictFunc from firing, since the first match will be returned immediately.
func WithAllowTrunkConflicts() MatchTrunkOpt {
return func(opt *matchTrunkOpts) {
opt.AllowConflicts = true
}
}

// WithTrunkConflict sets a callback that is called when two Trunks conflict.
func WithTrunkConflict(fnc TrunkConflictFunc) MatchTrunkOpt {
return func(opt *matchTrunkOpts) {
opt.Conflict = fnc
}
}

// MatchTrunkIter finds a SIP Trunk definition matching the request.
// Returns nil if no rules matched or an error if there are conflicting definitions.
func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string) (*livekit.SIPInboundTrunkInfo, error) {
func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Addr, calling, called string, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
defer it.Close()
var opt matchTrunkOpts
for _, fnc := range opts {
fnc(&opt)
}
opt.defaults()
var (
selectedTrunk *livekit.SIPInboundTrunkInfo
defaultTrunk *livekit.SIPInboundTrunkInfo
defaultTrunkCnt int // to error in case there are multiple ones
selectedTrunk *livekit.SIPInboundTrunkInfo
defaultTrunk *livekit.SIPInboundTrunkInfo
defaultTrunkPrev *livekit.SIPInboundTrunkInfo
defaultTrunkCnt int // to error in case there are multiple ones
)
calledNorm := normalizeNumber(called)
for {
Expand All @@ -433,24 +528,38 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
}
// Do not consider it if number doesn't match.
if !matchNumbers(calling, tr.AllowedNumbers) {
opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed)
continue
}
if !matchAddrMasks(srcIP, tr.AllowedAddresses) {
opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed)
continue
}
if len(tr.Numbers) == 0 {
// Default/wildcard trunk.
defaultTrunkPrev = defaultTrunk
defaultTrunk = tr
defaultTrunkCnt++
} else {
for _, num := range tr.Numbers {
if normalizeNumber(num) == calledNorm {
// Trunk specific to the number.
if selectedTrunk != nil {
opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber)
if opt.AllowConflicts {
// This path is unreachable, since we pick the first trunk. Kept for completeness.
continue
}
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", called)
}
selectedTrunk = tr
if opt.AllowConflicts {
// Pick the first match as soon as it's found. We don't care about conflicts.
return selectedTrunk, nil
}
// Keep searching! We want to know if there are any conflicting Trunk definitions.
} else {
opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed)
}
}
}
Expand All @@ -459,7 +568,10 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
return selectedTrunk, nil
}
if defaultTrunkCnt > 1 {
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called)
opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault)
if !opt.AllowConflicts {
return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", called)
}
}
// Could still be nil here.
return defaultTrunk, nil
Expand All @@ -469,14 +581,49 @@ func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], srcIP netip.Add
// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
//
// Deprecated: use MatchDispatchRuleIter
func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
return MatchDispatchRuleIter(trunk, iters.Slice(rules), req)
func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
return MatchDispatchRuleIter(trunk, iters.Slice(rules), req, opts...)
}

type matchDispatchRuleOpts struct {
AllowConflicts bool
Conflict DispatchRuleConflictFunc
}

func (opt *matchDispatchRuleOpts) defaults() {
if opt.Conflict == nil {
opt.Conflict = func(_, _ *livekit.SIPDispatchRuleInfo, _ DispatchRuleConflictReason) {}
}
}

type MatchDispatchRuleOpt func(opt *matchDispatchRuleOpts)

type DispatchRuleConflictReason int

const (
DispatchRuleConflictGeneric = DispatchRuleConflictReason(iota)
)

type DispatchRuleConflictFunc func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason)

// WithAllowDispatchRuleConflicts allows conflicting DispatchRule definitions.
func WithAllowDispatchRuleConflicts() MatchDispatchRuleOpt {
return func(opt *matchDispatchRuleOpts) {
opt.AllowConflicts = true
}
}

// WithDispatchRuleConflict sets a callback that is called when two DispatchRules conflict.
func WithDispatchRuleConflict(fnc DispatchRuleConflictFunc) MatchDispatchRuleOpt {
return func(opt *matchDispatchRuleOpts) {
opt.Conflict = fnc
}
}

// MatchDispatchRuleIter finds the best dispatch rule matching the request parameters. Returns an error if no rule matched.
// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) {
rules = NewDispatchRuleValidator().ValidateIter(rules)
func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
rules = NewDispatchRuleValidator(opts...).ValidateIter(rules)
defer rules.Close()
// Trunk can still be nil here in case none matched or were defined.
// This is still fine, but only in case we'll match exactly one wildcard dispatch rule.
Expand Down
Loading

0 comments on commit cc84c13

Please sign in to comment.