Skip to content

Commit

Permalink
Check routes when checking for network changes
Browse files Browse the repository at this point in the history
  • Loading branch information
samuong committed Jul 21, 2024
1 parent fbfe6a9 commit 31fb72c
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 24 deletions.
65 changes: 58 additions & 7 deletions netmonitor.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2024 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -25,11 +25,13 @@ type netMonitor interface {

type netMonitorImpl struct {
addrs map[string]struct{}
routes map[string]net.IP
getAddrs func() ([]net.Addr, error)
dial func(network, addr string) (net.Conn, error)
}

func newNetMonitor() *netMonitorImpl {
return &netMonitorImpl{getAddrs: net.InterfaceAddrs}
return &netMonitorImpl{getAddrs: net.InterfaceAddrs, dial: net.Dial}
}

func (nm *netMonitorImpl) addrsChanged() bool {
Expand All @@ -39,13 +41,32 @@ func (nm *netMonitorImpl) addrsChanged() bool {
return false
}
set := addrSliceToSet(addrs)
if setsAreEqual(set, nm.addrs) {

// Probe for routes to a set of remote addresses. These addresses are
// the same as those used by myIpAddressEx.
// TODO: Cache the results so they don't need to be recalculated in
// myIpAddress and myIpAddressEx.
remotes := []string{
"8.8.8.8", "2001:4860:4860::8888", // public addresses
"10.0.0.0", "172.16.0.0", "192.168.0.0", "FC00::", // private addresses
}
routes := map[string]net.IP{}
routesChanged := false
for _, remote := range remotes {
local := nm.probeRoute(remote, false)
routes[remote] = local
was, ok := nm.routes[remote]
if nm.routes == nil || !ok || !was.Equal(local) {
routesChanged = true
}
}

if setsAreEqual(set, nm.addrs) && !routesChanged {
return false
} else {
log.Printf("Network changes detected: %v", addrs)
nm.addrs = set
return true
}
nm.addrs = set
nm.routes = routes
return true
}

func addrSliceToSet(slice []net.Addr) map[string]struct{} {
Expand All @@ -67,3 +88,33 @@ func setsAreEqual(a, b map[string]struct{}) bool {
}
return true
}

// probeRoute creates a UDP "connection" to the remote address, and returns the
// local interface address. This does involve a system call, but does not
// generate any network traffic since UDP is a connectionless protocol.
func (nm *netMonitorImpl) probeRoute(host string, ipv4only bool) net.IP {
var network string
if ipv4only {
network = "udp4"
} else {
network = "udp"
}
conn, err := nm.dial(network, net.JoinHostPort(host, "80"))
if err != nil {
return nil
}
defer conn.Close()
local, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
// Since we called dial with network set to "udp4" or "udp", we
// expect this to be a *net.UDPAddr. If this fails, it's a bug
// in Alpaca, and hopefully users will report it. But it's not
// worth panicking over so we won't end the request here.
log.Printf("unexpected: probeRoute host=%q ipv4only=%t: %v", host, ipv4only, err)
return nil
}
if ip := local.IP; ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return nil
}
return local.IP
}
146 changes: 129 additions & 17 deletions netmonitor_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 The Alpaca Authors
// Copyright 2019, 2024 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,13 @@ package main

import (
"errors"
"fmt"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockAddr string
Expand All @@ -40,28 +43,137 @@ func toAddrs(ss ...string) []net.Addr {
return addrs
}

type mockConn struct {
localAddr net.Addr
}

var _ net.Conn = mockConn{}

func (c mockConn) Close() error {
return nil
}

func (c mockConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c mockConn) Read(b []byte) (n int, err error) {
panic("unreachable")
}

func (c mockConn) RemoteAddr() net.Addr {
panic("unreachable")
}

func (c mockConn) SetDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) SetReadDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) SetWriteDeadline(t time.Time) error {
panic("unreachable")
}

func (c mockConn) Write(b []byte) (n int, err error) {
panic("unreachable")
}

type mockNet struct {
connectedTo string
}

func (n *mockNet) interfaceAddrs() ([]net.Addr, error) {
var addrs []net.Addr
switch n.connectedTo {
case "vpn", "wifi":
addrs = append(addrs, toAddrs("192.168.1.2/24", "fe80::fedc:ba98:7654:3210/64")...)
fallthrough
case "offline":
addrs = append(addrs, toAddrs("127.0.0.1/8", "::1/128")...)
default:
panic("interfaceAddrs connectedTo=" + n.connectedTo)
}
return addrs, nil
}

func (n *mockNet) dial(network, address string) (net.Conn, error) {
if network != "udp" && network != "udp4" {
panic("dial network=" + network)
}
host, _, err := net.SplitHostPort(address)
if err != nil {
panic("dial: " + err.Error())
}
ip := net.ParseIP(host)
if ip == nil {
panic("dial host=" + host)
}
if ipv4 := ip.To4(); ipv4 == nil {
return nil, &net.OpError{
Op: "dial",
Net: network,
Source: nil,
Addr: mockAddr(address),
Err: errors.New("connect: no route to host"),
}
}
switch n.connectedTo {
case "vpn":
return mockConn{localAddr: &net.UDPAddr{IP: net.IPv4(10, 0, 0, 3)}}, nil
case "wifi":
return mockConn{localAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 1, 2)}}, nil
case "offline":
return nil, &net.OpError{
Op: "dial",
Net: network,
Source: nil,
Addr: mockAddr(address),
Err: errors.New("connect: network is unreachable"),
}
default:
panic("dial connectedTo=" + n.connectedTo)
}
}

func TestNetworkMonitor(t *testing.T) {
var next []net.Addr
nm := &netMonitorImpl{getAddrs: func() ([]net.Addr, error) { return next, nil }}
// Start with just loopback interfaces
next = toAddrs("127.0.0.1/8", "::1/128")
var network mockNet
nm := &netMonitorImpl{getAddrs: network.interfaceAddrs, dial: network.dial}
network.connectedTo = "offline"
assert.True(t, nm.addrsChanged())
// Connect to network, and get local IPv4 and IPv6 addresses
next = toAddrs("127.0.0.1/8", "192.168.1.6/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3a/64")
network.connectedTo = "wifi"
assert.True(t, nm.addrsChanged())
// Stay connected, nothing changed
next = toAddrs("127.0.0.1/8", "192.168.1.6/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3a/64")
assert.False(t, nm.addrsChanged())
// DHCP lease expires, get new addresses
next = toAddrs("127.0.0.1/8", "192.168.1.7/24", "::1/128", "fe80::dfd9:fe1d:56d1:1f3b/64")
network.connectedTo = "vpn"
fmt.Println(`network.connectedTo = "vpn"`)
assert.True(t, nm.addrsChanged())
// Disconnect, and go back to having just loopback addresses
next = toAddrs("127.0.0.1/8", "::1/128")
fmt.Println(`network.connectedTo = "offline"`)
network.connectedTo = "offline"
assert.True(t, nm.addrsChanged())
}

func TestFailToGetAddrs(t *testing.T) {
alwaysFail := func() ([]net.Addr, error) { return nil, errors.New("failed") }
nm := &netMonitorImpl{getAddrs: alwaysFail}
assert.False(t, nm.addrsChanged())
func TestDumpAddrs(t *testing.T) {
ifaces, err := net.Interfaces()
require.NoError(t, err)
t.Log("---------- INTERFACES AND ADDRESSES: ----------")
for _, iface := range ifaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
t.Logf("%s -> %q", iface.Name, addrs)
}
remotes := []string{
"8.8.8.8", "2001:4860:4860::8888", // public addresses
"10.0.0.0", "172.16.0.0", "192.168.0.0", "FC00::", // private addresses
}
t.Log("---------- ROUTES: ----------")
for _, addr := range remotes {
conn, err := net.Dial("udp", net.JoinHostPort(addr, "80"))
if err != nil {
t.Logf("%q => error: %v\n", addr, err)
continue
}
t.Logf("%q => %q\n", addr, conn.LocalAddr().String())
}
}

0 comments on commit 31fb72c

Please sign in to comment.