From d8ab1932563b73ff54777e5718e70c9e1f5b08bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Nov 2024 17:09:04 +0800 Subject: [PATCH] Split set route to Start() --- tun.go | 13 ++++--- tun_darwin.go | 89 ++++++++++++++++++++++++++++++++++---------- tun_darwin_gvisor.go | 4 +- tun_linux.go | 9 +++++ tun_windows.go | 50 +++++++++++++------------ 5 files changed, 114 insertions(+), 51 deletions(-) diff --git a/tun.go b/tun.go index 43d7eba..4643c12 100644 --- a/tun.go +++ b/tun.go @@ -24,6 +24,7 @@ type Handler interface { type Tun interface { io.ReadWriter N.VectorisedWriter + Start() error Close() error } @@ -89,10 +90,10 @@ func (o *Options) Inet4GatewayAddr() netip.Addr { return o.Inet4Gateway } if len(o.Inet4Address) > 0 { - if HasNextAddress(o.Inet4Address[0], 1) { - return o.Inet4Address[0].Addr().Next() - } else if runtime.GOOS != "linux" { + if runtime.GOOS == "darwin" { return o.Inet4Address[0].Addr() + } else if HasNextAddress(o.Inet4Address[0], 1) { + return o.Inet4Address[0].Addr().Next() } } return netip.IPv4Unspecified() @@ -103,10 +104,10 @@ func (o *Options) Inet6GatewayAddr() netip.Addr { return o.Inet6Gateway } if len(o.Inet6Address) > 0 { - if HasNextAddress(o.Inet6Address[0], 1) { - return o.Inet6Address[0].Addr().Next() - } else if runtime.GOOS != "linux" { + if runtime.GOOS == "darwin" { return o.Inet6Address[0].Addr() + } else if HasNextAddress(o.Inet6Address[0], 1) { + return o.Inet6Address[0].Addr().Next() } } return netip.IPv6Unspecified() diff --git a/tun_darwin.go b/tun_darwin.go index e7267d4..16d463d 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -1,6 +1,7 @@ package tun import ( + "errors" "fmt" "net" "net/netip" @@ -24,9 +25,10 @@ const PacketOffset = 4 type NativeTun struct { tunFile *os.File tunWriter N.VectorisedWriter - mtu uint32 + options Options inet4Address [4]byte inet6Address [16]byte + routerSet bool } func New(options Options) (Tun, error) { @@ -54,7 +56,7 @@ func New(options Options) (Tun, error) { nativeTun := &NativeTun{ tunFile: os.NewFile(uintptr(tunFd), "utun"), - mtu: options.MTU, + options: options, } if len(options.Inet4Address) > 0 { nativeTun.inet4Address = options.Inet4Address[0].Addr().As4() @@ -70,6 +72,15 @@ func New(options Options) (Tun, error) { return nativeTun, nil } +func (t *NativeTun) Start() error { + return t.setRoutes() +} + +func (t *NativeTun) Close() error { + defer flushDNSCache() + return E.Errors(t.unsetRoutes(), t.tunFile.Close()) +} + func (t *NativeTun) Read(p []byte) (n int, err error) { return t.tunFile.Read(p) } @@ -93,11 +104,6 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) } -func (t *NativeTun) Close() error { - flushDNSCache() - return t.tunFile.Close() -} - const utunControlName = "com.apple.net.utun_control" const ( @@ -239,28 +245,68 @@ func configure(tunFd int, ifIndex int, name string, options Options) error { } } } - if options.AutoRoute { - var routeRanges []netip.Prefix - routeRanges, err = options.BuildAutoRouteRanges(false) + return nil +} + +func (t *NativeTun) setRoutes() error { + if t.options.AutoRoute { + routeRanges, err := t.options.BuildAutoRouteRanges(false) if err != nil { return err } - gateway4, gateway6 := options.Inet4GatewayAddr(), options.Inet6GatewayAddr() - for _, routeRange := range routeRanges { - if routeRange.Addr().Is4() { - err = addRoute(routeRange, gateway4) + gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() + for _, destination := range routeRanges { + var gateway netip.Addr + if destination.Addr().Is4() { + gateway = gateway4 } else { - err = addRoute(routeRange, gateway6) + gateway = gateway6 } + err = execRoute(unix.RTM_ADD, destination, gateway) if err != nil { - return E.Cause(err, "add route: ", routeRange) + if errors.Is(err, unix.EEXIST) { + err = execRoute(unix.RTM_DELETE, destination, gateway) + if err != nil { + return E.Cause(err, "remove existing route: ", destination) + } + err = execRoute(unix.RTM_ADD, destination, gateway) + if err != nil { + return E.Cause(err, "re-add route: ", destination) + } + } + return E.Cause(err, "add route: ", destination) } } flushDNSCache() + t.routerSet = true } return nil } +func (t *NativeTun) unsetRoutes() error { + if !t.routerSet { + return nil + } + routeRanges, err := t.options.BuildAutoRouteRanges(false) + if err != nil { + return err + } + gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() + for _, destination := range routeRanges { + var gateway netip.Addr + if destination.Addr().Is4() { + gateway = gateway4 + } else { + gateway = gateway6 + } + err = execRoute(unix.RTM_DELETE, destination, gateway) + if err != nil { + err = E.Errors(err, E.Cause(err, "delete route: ", destination)) + } + } + return err +} + func useSocket(domain, typ, proto int, block func(socketFd int) error) error { socketFd, err := unix.Socket(domain, typ, proto) if err != nil { @@ -270,13 +316,16 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error { return block(socketFd) } -func addRoute(destination netip.Prefix, gateway netip.Addr) error { +func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error { routeMessage := route.RouteMessage{ - Type: unix.RTM_ADD, - Flags: unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY, + Type: rtmType, Version: unix.RTM_VERSION, + Flags: unix.RTF_STATIC | unix.RTF_GATEWAY, Seq: 1, } + if rtmType == unix.RTM_ADD { + routeMessage.Flags |= unix.RTF_UP + } if gateway.Is4() { routeMessage.Addrs = []route.Addr{ syscall.RTAX_DST: &route.Inet4Addr{IP: destination.Addr().As4()}, @@ -300,5 +349,5 @@ func addRoute(destination netip.Prefix, gateway netip.Addr) error { } func flushDNSCache() { - shell.Exec("dscacheutil", "-flushcache").Start() + go shell.Exec("dscacheutil", "-flushcache").Run() } diff --git a/tun_darwin_gvisor.go b/tun_darwin_gvisor.go index e693470..df46bf1 100644 --- a/tun_darwin_gvisor.go +++ b/tun_darwin_gvisor.go @@ -24,7 +24,7 @@ type DarwinEndpoint struct { } func (e *DarwinEndpoint) MTU() uint32 { - return e.tun.mtu + return e.tun.options.MTU } func (e *DarwinEndpoint) SetMTU(mtu uint32) { @@ -57,7 +57,7 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } func (e *DarwinEndpoint) dispatchLoop() { - packetBuffer := make([]byte, e.tun.mtu+PacketOffset) + packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset) for { n, err := e.tun.tunFile.Read(packetBuffer) if err != nil { diff --git a/tun_linux.go b/tun_linux.go index 0c64b15..6eb15d4 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -288,6 +288,15 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { t.txChecksumOffload = true } + return nil +} + +func (t *NativeTun) Start() error { + tunLink, err := netlink.LinkByName(t.options.Name) + if err != nil { + return err + } + err = netlink.LinkSetUp(tunLink) if err != nil { return err diff --git a/tun_windows.go b/tun_windows.go index 45d5b00..392b78c 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -106,27 +106,6 @@ func (t *NativeTun) configure() error { if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { _ = luid.DisableDNSRegistration() } - if t.options.AutoRoute { - gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() - routeRanges, err := t.options.BuildAutoRouteRanges(false) - if err != nil { - return err - } - for _, routeRange := range routeRanges { - if routeRange.Addr().Is4() { - err = luid.AddRoute(routeRange, gateway4, 0) - } else { - err = luid.AddRoute(routeRange, gateway6, 0) - } - } - if err != nil { - return err - } - err = windnsapi.FlushResolverCache() - if err != nil { - return err - } - } if len(t.options.Inet4Address) > 0 { inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) if err != nil { @@ -166,8 +145,34 @@ func (t *NativeTun) configure() error { return E.Cause(err, "set ipv6 options") } } + return nil +} - if t.options.AutoRoute && t.options.StrictRoute { +func (t *NativeTun) Start() error { + if !t.options.AutoRoute { + return nil + } + luid := winipcfg.LUID(t.adapter.LUID()) + gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr() + routeRanges, err := t.options.BuildAutoRouteRanges(false) + if err != nil { + return err + } + for _, routeRange := range routeRanges { + if routeRange.Addr().Is4() { + err = luid.AddRoute(routeRange, gateway4, 0) + } else { + err = luid.AddRoute(routeRange, gateway6, 0) + } + } + if err != nil { + return err + } + err = windnsapi.FlushResolverCache() + if err != nil { + return err + } + if t.options.StrictRoute { var engine uintptr session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) @@ -340,7 +345,6 @@ func (t *NativeTun) configure() error { } } } - return nil }