Skip to content

Commit

Permalink
Split set route to Start()
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 6, 2024
1 parent c35b14a commit d8ab193
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 51 deletions.
13 changes: 7 additions & 6 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
Start() error
Close() error
}

Expand Down Expand Up @@ -89,10 +90,10 @@ func (o *Options) Inet4GatewayAddr() netip.Addr {
return o.Inet4Gateway
}
if len(o.Inet4Address) > 0 {
if HasNextAddress(o.Inet4Address[0], 1) {
return o.Inet4Address[0].Addr().Next()
} else if runtime.GOOS != "linux" {
if runtime.GOOS == "darwin" {
return o.Inet4Address[0].Addr()
} else if HasNextAddress(o.Inet4Address[0], 1) {
return o.Inet4Address[0].Addr().Next()
}
}
return netip.IPv4Unspecified()
Expand All @@ -103,10 +104,10 @@ func (o *Options) Inet6GatewayAddr() netip.Addr {
return o.Inet6Gateway
}
if len(o.Inet6Address) > 0 {
if HasNextAddress(o.Inet6Address[0], 1) {
return o.Inet6Address[0].Addr().Next()
} else if runtime.GOOS != "linux" {
if runtime.GOOS == "darwin" {
return o.Inet6Address[0].Addr()
} else if HasNextAddress(o.Inet6Address[0], 1) {
return o.Inet6Address[0].Addr().Next()
}
}
return netip.IPv6Unspecified()
Expand Down
89 changes: 69 additions & 20 deletions tun_darwin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tun

import (
"errors"
"fmt"
"net"
"net/netip"
Expand All @@ -24,9 +25,10 @@ const PacketOffset = 4
type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter
mtu uint32
options Options
inet4Address [4]byte
inet6Address [16]byte
routerSet bool
}

func New(options Options) (Tun, error) {
Expand Down Expand Up @@ -54,7 +56,7 @@ func New(options Options) (Tun, error) {

nativeTun := &NativeTun{
tunFile: os.NewFile(uintptr(tunFd), "utun"),
mtu: options.MTU,
options: options,
}
if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
Expand All @@ -70,6 +72,15 @@ func New(options Options) (Tun, error) {
return nativeTun, nil
}

func (t *NativeTun) Start() error {
return t.setRoutes()
}

func (t *NativeTun) Close() error {
defer flushDNSCache()
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
}

func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p)
}
Expand All @@ -93,11 +104,6 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
}

func (t *NativeTun) Close() error {
flushDNSCache()
return t.tunFile.Close()
}

const utunControlName = "com.apple.net.utun_control"

const (
Expand Down Expand Up @@ -239,28 +245,68 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
}
}
}
if options.AutoRoute {
var routeRanges []netip.Prefix
routeRanges, err = options.BuildAutoRouteRanges(false)
return nil
}

func (t *NativeTun) setRoutes() error {
if t.options.AutoRoute {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := options.Inet4GatewayAddr(), options.Inet6GatewayAddr()
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = addRoute(routeRange, gateway4)
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
err = addRoute(routeRange, gateway6)
gateway = gateway6
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "add route: ", routeRange)
if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
}
return E.Cause(err, "add route: ", destination)
}
}
flushDNSCache()
t.routerSet = true
}
return nil
}

func (t *NativeTun) unsetRoutes() error {
if !t.routerSet {
return nil
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
}
}
return err
}

func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
socketFd, err := unix.Socket(domain, typ, proto)
if err != nil {
Expand All @@ -270,13 +316,16 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
return block(socketFd)
}

func addRoute(destination netip.Prefix, gateway netip.Addr) error {
func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
routeMessage := route.RouteMessage{
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY,
Type: rtmType,
Version: unix.RTM_VERSION,
Flags: unix.RTF_STATIC | unix.RTF_GATEWAY,
Seq: 1,
}
if rtmType == unix.RTM_ADD {
routeMessage.Flags |= unix.RTF_UP
}
if gateway.Is4() {
routeMessage.Addrs = []route.Addr{
syscall.RTAX_DST: &route.Inet4Addr{IP: destination.Addr().As4()},
Expand All @@ -300,5 +349,5 @@ func addRoute(destination netip.Prefix, gateway netip.Addr) error {
}

func flushDNSCache() {
shell.Exec("dscacheutil", "-flushcache").Start()
go shell.Exec("dscacheutil", "-flushcache").Run()
}
4 changes: 2 additions & 2 deletions tun_darwin_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type DarwinEndpoint struct {
}

func (e *DarwinEndpoint) MTU() uint32 {
return e.tun.mtu
return e.tun.options.MTU
}

func (e *DarwinEndpoint) SetMTU(mtu uint32) {
Expand Down Expand Up @@ -57,7 +57,7 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}

func (e *DarwinEndpoint) dispatchLoop() {
packetBuffer := make([]byte, e.tun.mtu+PacketOffset)
packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset)
for {
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
t.txChecksumOffload = true
}

return nil
}

func (t *NativeTun) Start() error {
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil {
return err
}

err = netlink.LinkSetUp(tunLink)
if err != nil {
return err
Expand Down
50 changes: 27 additions & 23 deletions tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,6 @@ func (t *NativeTun) configure() error {
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
_ = luid.DisableDNSRegistration()
}
if t.options.AutoRoute {
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
}
if len(t.options.Inet4Address) > 0 {
inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
if err != nil {
Expand Down Expand Up @@ -166,8 +145,34 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 options")
}
}
return nil
}

if t.options.AutoRoute && t.options.StrictRoute {
func (t *NativeTun) Start() error {
if !t.options.AutoRoute {
return nil
}
luid := winipcfg.LUID(t.adapter.LUID())
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
if t.options.StrictRoute {
var engine uintptr
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
Expand Down Expand Up @@ -340,7 +345,6 @@ func (t *NativeTun) configure() error {
}
}
}

return nil
}

Expand Down

0 comments on commit d8ab193

Please sign in to comment.