diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index ef7e4b7fd..339e4e442 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -31,6 +31,12 @@ import (
"net"
)
+// cmcUDPConn can read and write cmsg.
+type cmcUDPConn interface {
+ readFrom(b []byte) (n int, cm any, src net.Addr, err error)
+ writeTo(b []byte, cm any, dst net.Addr) (n int, err error)
+}
+
func (s *Server) ServeUDP(c net.PacketConn) error {
defer c.Close()
@@ -52,8 +58,20 @@ func (s *Server) ServeUDP(c net.PacketConn) error {
defer readBuf.Release()
rb := readBuf.Bytes()
+ var cmc cmcUDPConn
+ var err error
+ uc, ok := c.(*net.UDPConn)
+ if ok {
+ cmc, err = newUDPConn(uc)
+ if err != nil {
+ return fmt.Errorf("failed to control socket cmsg, %w", err)
+ }
+ } else {
+ cmc = newDummyUDPConn(c)
+ }
+
for {
- n, clientNetAddr, err := c.ReadFrom(rb)
+ n, cm, clientNetAddr, err := cmc.readFrom(rb)
if err != nil {
if s.Closed() {
return ErrServerClosed
@@ -87,7 +105,7 @@ func (s *Server) ServeUDP(c net.PacketConn) error {
return
}
defer buf.Release()
- if _, err := c.WriteTo(b, clientNetAddr); err != nil {
+ if _, err := cmc.writeTo(b, cm, clientNetAddr); err != nil {
s.opts.Logger.Warn("failed to write response", zap.Stringer("client", clientNetAddr), zap.Error(err))
}
}
@@ -105,3 +123,23 @@ func getUDPSize(m *dns.Msg) int {
}
return int(s)
}
+
+// newDummyUDPConn returns a dummyWrapper.
+func newDummyUDPConn(c net.PacketConn) cmcUDPConn {
+ return dummyWrapper{c: c}
+}
+
+// dummyWrapper is just a wrapper that implements cmcUDPConn but does not
+// write or read any control msg.
+type dummyWrapper struct {
+ c net.PacketConn
+}
+
+func (w dummyWrapper) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
+ n, src, err = w.c.ReadFrom(b)
+ return
+}
+
+func (w dummyWrapper) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
+ return w.c.WriteTo(b, dst)
+}
diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go
new file mode 100644
index 000000000..c8e70b1fe
--- /dev/null
+++ b/pkg/server/udp_linux.go
@@ -0,0 +1,113 @@
+//go:build linux
+
+/*
+ * Copyright (C) 2020-2022, IrineSistiana
+ *
+ * This file is part of mosdns.
+ *
+ * mosdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * mosdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package server
+
+import (
+ "fmt"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "golang.org/x/sys/unix"
+ "net"
+ "os"
+)
+
+type protocol int
+
+const (
+ invalid protocol = iota
+ v4
+ v6
+)
+
+type ipv4PacketConn struct {
+ c *ipv4.PacketConn
+}
+
+func (i ipv4PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
+ return i.c.ReadFrom(b)
+}
+
+func (i ipv4PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
+ return i.c.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
+}
+
+type ipv6PacketConn struct {
+ c *ipv6.PacketConn
+}
+
+func (i ipv6PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
+ return i.c.ReadFrom(b)
+}
+
+func (i ipv6PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
+ return i.c.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
+}
+
+func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) {
+ p, err := getSocketIPProtocol(c)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get socket ip protocol, %w", err)
+ }
+ switch p {
+ case v4:
+ c := ipv4.NewPacketConn(c)
+ if err := c.SetControlMessage(ipv4.FlagSrc|ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
+ return nil, fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
+ }
+ return ipv4PacketConn{c: c}, nil
+ case v6:
+ c := ipv6.NewPacketConn(c)
+ if err := c.SetControlMessage(ipv6.FlagSrc|ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
+ return nil, fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
+ }
+ return ipv6PacketConn{c: c}, nil
+ default:
+ return nil, fmt.Errorf("unknow protocol %d", p)
+ }
+}
+
+func getSocketIPProtocol(c *net.UDPConn) (protocol, error) {
+ sc, err := c.SyscallConn()
+ if err != nil {
+ return 0, err
+ }
+ proto := invalid
+ var syscallErr error
+ if controlErr := sc.Control(func(fd uintptr) {
+ v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
+ if err != nil {
+ syscallErr = os.NewSyscallError("failed to get SO_PROTOCOL", err)
+ return
+ }
+ switch v {
+ case unix.AF_INET:
+ proto = v4
+ case unix.AF_INET6:
+ proto = v6
+ default:
+ syscallErr = fmt.Errorf("socket protocol %d is not supported", v)
+ }
+ }); err != nil {
+ return 0, fmt.Errorf("control fd err, %w", controlErr)
+ }
+ return proto, syscallErr
+}
diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go
new file mode 100644
index 000000000..4842e742e
--- /dev/null
+++ b/pkg/server/udp_others.go
@@ -0,0 +1,28 @@
+//go:build !linux
+
+/*
+ * Copyright (C) 2020-2022, IrineSistiana
+ *
+ * This file is part of mosdns.
+ *
+ * mosdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * mosdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package server
+
+import "net"
+
+func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) {
+ return newDummyUDPConn(c), nil
+}