diff --git a/pkg/server/udp.go b/pkg/server/udp.go index ef7e4b7fd..339e4e442 100644 --- a/pkg/server/udp.go +++ b/pkg/server/udp.go @@ -31,6 +31,12 @@ import ( "net" ) +// cmcUDPConn can read and write cmsg. +type cmcUDPConn interface { + readFrom(b []byte) (n int, cm any, src net.Addr, err error) + writeTo(b []byte, cm any, dst net.Addr) (n int, err error) +} + func (s *Server) ServeUDP(c net.PacketConn) error { defer c.Close() @@ -52,8 +58,20 @@ func (s *Server) ServeUDP(c net.PacketConn) error { defer readBuf.Release() rb := readBuf.Bytes() + var cmc cmcUDPConn + var err error + uc, ok := c.(*net.UDPConn) + if ok { + cmc, err = newUDPConn(uc) + if err != nil { + return fmt.Errorf("failed to control socket cmsg, %w", err) + } + } else { + cmc = newDummyUDPConn(c) + } + for { - n, clientNetAddr, err := c.ReadFrom(rb) + n, cm, clientNetAddr, err := cmc.readFrom(rb) if err != nil { if s.Closed() { return ErrServerClosed @@ -87,7 +105,7 @@ func (s *Server) ServeUDP(c net.PacketConn) error { return } defer buf.Release() - if _, err := c.WriteTo(b, clientNetAddr); err != nil { + if _, err := cmc.writeTo(b, cm, clientNetAddr); err != nil { s.opts.Logger.Warn("failed to write response", zap.Stringer("client", clientNetAddr), zap.Error(err)) } } @@ -105,3 +123,23 @@ func getUDPSize(m *dns.Msg) int { } return int(s) } + +// newDummyUDPConn returns a dummyWrapper. +func newDummyUDPConn(c net.PacketConn) cmcUDPConn { + return dummyWrapper{c: c} +} + +// dummyWrapper is just a wrapper that implements cmcUDPConn but does not +// write or read any control msg. +type dummyWrapper struct { + c net.PacketConn +} + +func (w dummyWrapper) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { + n, src, err = w.c.ReadFrom(b) + return +} + +func (w dummyWrapper) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { + return w.c.WriteTo(b, dst) +} diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go new file mode 100644 index 000000000..c8e70b1fe --- /dev/null +++ b/pkg/server/udp_linux.go @@ -0,0 +1,113 @@ +//go:build linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import ( + "fmt" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + "net" + "os" +) + +type protocol int + +const ( + invalid protocol = iota + v4 + v6 +) + +type ipv4PacketConn struct { + c *ipv4.PacketConn +} + +func (i ipv4PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { + return i.c.ReadFrom(b) +} + +func (i ipv4PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { + return i.c.WriteTo(b, cm.(*ipv4.ControlMessage), dst) +} + +type ipv6PacketConn struct { + c *ipv6.PacketConn +} + +func (i ipv6PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) { + return i.c.ReadFrom(b) +} + +func (i ipv6PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) { + return i.c.WriteTo(b, cm.(*ipv6.ControlMessage), dst) +} + +func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) { + p, err := getSocketIPProtocol(c) + if err != nil { + return nil, fmt.Errorf("failed to get socket ip protocol, %w", err) + } + switch p { + case v4: + c := ipv4.NewPacketConn(c) + if err := c.SetControlMessage(ipv4.FlagSrc|ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { + return nil, fmt.Errorf("failed to set ipv4 cmsg flags, %w", err) + } + return ipv4PacketConn{c: c}, nil + case v6: + c := ipv6.NewPacketConn(c) + if err := c.SetControlMessage(ipv6.FlagSrc|ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { + return nil, fmt.Errorf("failed to set ipv6 cmsg flags, %w", err) + } + return ipv6PacketConn{c: c}, nil + default: + return nil, fmt.Errorf("unknow protocol %d", p) + } +} + +func getSocketIPProtocol(c *net.UDPConn) (protocol, error) { + sc, err := c.SyscallConn() + if err != nil { + return 0, err + } + proto := invalid + var syscallErr error + if controlErr := sc.Control(func(fd uintptr) { + v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN) + if err != nil { + syscallErr = os.NewSyscallError("failed to get SO_PROTOCOL", err) + return + } + switch v { + case unix.AF_INET: + proto = v4 + case unix.AF_INET6: + proto = v6 + default: + syscallErr = fmt.Errorf("socket protocol %d is not supported", v) + } + }); err != nil { + return 0, fmt.Errorf("control fd err, %w", controlErr) + } + return proto, syscallErr +} diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go new file mode 100644 index 000000000..4842e742e --- /dev/null +++ b/pkg/server/udp_others.go @@ -0,0 +1,28 @@ +//go:build !linux + +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package server + +import "net" + +func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) { + return newDummyUDPConn(c), nil +}