diff --git a/README.md b/README.md index 9aa1d0a..7c1e768 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Configuration example: ```yaml listen: - - 239.82.71.65:8271 + - 239.82.71.65:8271 # or "239.82.71.65:8271@eth0" or "239.82.71.65:8271@192.168.0.0/16" - 127.0.0.1:8282 groups: diff --git a/agent/agent.go b/agent/agent.go index d403ef3..cde1454 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -5,11 +5,14 @@ import ( "fmt" "log" "net" + "net/netip" + "strings" "sync" "time" "github.com/Snawoot/rgap/config" "github.com/Snawoot/rgap/protocol" + "github.com/Snawoot/rgap/util" "github.com/hashicorp/go-multierror" ) @@ -92,7 +95,12 @@ func (a *Agent) singleRun(ctx context.Context, t time.Time) error { } func (a *Agent) sendSingle(ctx context.Context, msg []byte, dst string) error { - conn, err := a.cfg.Dialer.DialContext(ctx, "udp", dst) + dstAddr, iface, err := util.SplitAndResolveAddrSpec(dst) + if err != nil { + return fmt.Errorf("destination %s: interface resolving failed: %w", dst, err) + } + + conn, err := a.dialInterfaceContext(ctx, "udp", dstAddr, iface) if err != nil { return fmt.Errorf("Agent.sendSingle dial failed: %w", err) } @@ -111,3 +119,28 @@ func (a *Agent) sendSingle(ctx context.Context, msg []byte, dst string) error { } return nil } + +func (a *Agent) dialInterfaceContext(ctx context.Context, network, addr string, iif *net.Interface) (net.Conn, error) { + if iif == nil { + return a.cfg.Dialer.DialContext(ctx, network, addr) + } + + var hints []string + addrs, err := iif.Addrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + return nil, fmt.Errorf("unexpected type returned as address interface: %T", addr) + } + netipAddr, ok := netip.AddrFromSlice(ipnet.IP) + if !ok { + return nil, fmt.Errorf("interface %v has invalid address %s", iif.Name, ipnet.IP) + } + hints = append(hints, netipAddr.Unmap().String()) + } + boundDialer := util.NewBoundDialer(a.cfg.Dialer, strings.Join(hints, ",")) + return boundDialer.DialContext(ctx, network, addr) +} diff --git a/listener/udpsource.go b/listener/udpsource.go index fe51a92..a2e680f 100644 --- a/listener/udpsource.go +++ b/listener/udpsource.go @@ -7,6 +7,7 @@ import ( "net" "github.com/Snawoot/rgap/protocol" + "github.com/Snawoot/rgap/util" ) type UDPSource struct { @@ -33,7 +34,12 @@ func (s *UDPSource) Start() error { s.ctxCancel = cancel s.loopDone = make(chan struct{}) - udpAddr, err := net.ResolveUDPAddr("udp", s.address) + listenAddr, iface, err := util.SplitAndResolveAddrSpec(s.address) + if err != nil { + return fmt.Errorf("UDP source %s: interface resolving failed: %w", s.address, err) + } + + udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { return fmt.Errorf("bad UDP listen address: %w", err) } @@ -41,7 +47,7 @@ func (s *UDPSource) Start() error { var conn *net.UDPConn if udpAddr.IP.IsMulticast() { - conn, err = net.ListenMulticastUDP("udp4", nil, udpAddr) + conn, err = net.ListenMulticastUDP("udp", iface, udpAddr) if err != nil { return fmt.Errorf("UDP listen failed: %w", err) } diff --git a/util/hintdialer.go b/util/hintdialer.go new file mode 100644 index 0000000..7fc7cea --- /dev/null +++ b/util/hintdialer.go @@ -0,0 +1,186 @@ +package util + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "strings" + + "github.com/hashicorp/go-multierror" +) + +var ( + ErrNoSuitableAddress = errors.New("no suitable address") + ErrBadIPAddressLength = errors.New("bad IP address length") + ErrUnknownNetwork = errors.New("unknown network") +) + +type BoundDialerContextKey struct{} + +type BoundDialerContextValue struct { + Hints *string + LocalAddr string +} + +type BoundDialerDefaultSink interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type BoundDialer struct { + defaultDialer BoundDialerDefaultSink + defaultHints string +} + +func NewBoundDialer(defaultDialer BoundDialerDefaultSink, defaultHints string) *BoundDialer { + if defaultDialer == nil { + defaultDialer = &net.Dialer{} + } + return &BoundDialer{ + defaultDialer: defaultDialer, + defaultHints: defaultHints, + } +} + +func (d *BoundDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + hints := d.defaultHints + lAddr := "" + if hintsOverride := ctx.Value(BoundDialerContextKey{}); hintsOverride != nil { + if hintsOverrideValue, ok := hintsOverride.(BoundDialerContextValue); ok { + if hintsOverrideValue.Hints != nil { + hints = *hintsOverrideValue.Hints + } + lAddr = hintsOverrideValue.LocalAddr + } + } + + parsedHints, err := parseHints(hints, lAddr) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) + } + + if len(parsedHints) == 0 { + return d.defaultDialer.DialContext(ctx, network, address) + } + + var netBase string + switch network { + case "tcp", "tcp4", "tcp6": + netBase = "tcp" + case "udp", "udp4", "udp6": + netBase = "udp" + case "ip", "ip4", "ip6": + netBase = "ip" + default: + return d.defaultDialer.DialContext(ctx, network, address) + } + + var resErr error + for _, lIP := range parsedHints { + lAddr, restrictedNetwork, err := ipToLAddr(netBase, lIP) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("ipToLAddr(%q) failed: %w", lIP.String(), err)) + continue + } + if network != netBase && network != restrictedNetwork { + continue + } + + conn, err := (&net.Dialer{ + LocalAddr: lAddr, + }).DialContext(ctx, restrictedNetwork, address) + if err != nil { + resErr = multierror.Append(resErr, fmt.Errorf("dial failed: %w", err)) + } else { + return conn, nil + } + } + + if resErr == nil { + resErr = ErrNoSuitableAddress + } + return nil, resErr +} + +func (d *BoundDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func ipToLAddr(network string, ip net.IP) (net.Addr, string, error) { + v6 := true + if ip4 := ip.To4(); len(ip4) == net.IPv4len { + ip = ip4 + v6 = false + } else if len(ip) != net.IPv6len { + return nil, "", ErrBadIPAddressLength + } + + var lAddr net.Addr + var lNetwork string + switch network { + case "tcp", "tcp4", "tcp6": + lAddr = &net.TCPAddr{ + IP: ip, + } + if v6 { + lNetwork = "tcp6" + } else { + lNetwork = "tcp4" + } + case "udp", "udp4", "udp6": + lAddr = &net.UDPAddr{ + IP: ip, + } + if v6 { + lNetwork = "udp6" + } else { + lNetwork = "udp4" + } + case "ip", "ip4", "ip6": + lAddr = &net.IPAddr{ + IP: ip, + } + if v6 { + lNetwork = "ip6" + } else { + lNetwork = "ip4" + } + default: + return nil, "", ErrUnknownNetwork + } + + return lAddr, lNetwork, nil +} + +func parseHints(hints, lAddr string) ([]net.IP, error) { + hints = os.Expand(hints, func(key string) string { + switch key { + case "lAddr": + return lAddr + default: + return fmt.Sprintf("", key) + } + }) + res, err := parseIPList(hints) + if err != nil { + return nil, fmt.Errorf("unable to parse source IP hints %q: %w", hints, err) + } + return res, nil +} + +func parseIPList(list string) ([]net.IP, error) { + res := make([]net.IP, 0) + for _, elem := range strings.Split(list, ",") { + elem = strings.TrimSpace(elem) + if len(elem) == 0 { + continue + } + if parsed := net.ParseIP(elem); parsed == nil { + return nil, fmt.Errorf("unable to parse IP address %q", elem) + } else { + res = append(res, parsed) + } + } + return res, nil +} diff --git a/util/util.go b/util/util.go index 7100030..405942b 100644 --- a/util/util.go +++ b/util/util.go @@ -2,8 +2,12 @@ package util import ( "bytes" + "errors" "fmt" + "log" + "net" "net/netip" + "strings" "gopkg.in/yaml.v3" ) @@ -58,3 +62,58 @@ func CheckedUnmarshal(doc *yaml.Node, dst interface{}) error { } return nil } + +func SplitAndResolveAddrSpec(spec string) (string, *net.Interface, error) { + addrSpec, ifaceSpec, found := strings.Cut(spec, "@") + if !found { + return addrSpec, nil, nil + } + iface, err := ResolveInterface(ifaceSpec) + if err != nil { + return addrSpec, nil, fmt.Errorf("unable to resolve interface spec %q: %w", ifaceSpec, err) + } + return addrSpec, iface, nil +} + +func ResolveInterface(spec string) (*net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("unable to enumerate interfaces: %w", err) + } + if pfx, err := netip.ParsePrefix(spec); err == nil { + // look for address + for i := range ifaces { + addrs, err := ifaces[i].Addrs() + if err != nil { + // may be a problem with some interface, + // but we still probably can find the right one + log.Printf("WARNING: interface %s is failing to report its addresses: %v", ifaces[i].Name, err) + continue + } + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + return nil, fmt.Errorf("unexpected type returned as address interface: %T", addr) + } + netipAddr, ok := netip.AddrFromSlice(ipnet.IP) + if !ok { + return nil, fmt.Errorf("interface %v has invalid address %s", ifaces[i].Name, ipnet.IP) + } + netipAddr = netipAddr.Unmap() + if pfx.Contains(netipAddr) { + res := ifaces[i] + return &res, nil + } + } + } + } else { + // look for iface name + for i := range ifaces { + if ifaces[i].Name == spec { + res := ifaces[i] + return &res, nil + } + } + } + return nil, errors.New("specified interface not found") +}