diff --git a/VERSION.txt b/VERSION.txt index 9db5ea12f5227..5525f03fa61b6 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.48.0 +1.48.1 diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 6cbff0b85bf8c..5ca276b272f53 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -28,9 +28,7 @@ import ( "time" "github.com/google/uuid" - "tailscale.com/hostinfo" "tailscale.com/net/tshttpproxy" - "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/must" "tailscale.com/util/winutil" @@ -187,6 +185,8 @@ func (up *updater) confirm(ver string) bool { return true } +const synoinfoConfPath = "/etc/synoinfo.conf" + func (up *updater) updateSynology() error { if up.Version != "" { return errors.New("installing a specific version on Synology is not supported") @@ -194,7 +194,7 @@ func (up *updater) updateSynology() error { // Get the latest version and list of SPKs from pkgs.tailscale.com. osName := fmt.Sprintf("dsm%d", distro.DSMVersion()) - arch, err := synoArch(hostinfo.New()) + arch, err := synoArch(runtime.GOARCH, synoinfoConfPath) if err != nil { return err } @@ -245,51 +245,62 @@ func (up *updater) updateSynology() error { // synoArch returns the Synology CPU architecture matching one of the SPK // architectures served from pkgs.tailscale.com. -func synoArch(hinfo *tailcfg.Hostinfo) (string, error) { +func synoArch(goArch, synoinfoPath string) (string, error) { // Most Synology boxes just use a different arch name from GOARCH. arch := map[string]string{ "amd64": "x86_64", "386": "i686", "arm64": "armv8", - }[hinfo.GoArch] - // Here's the fun part, some older ARM boxes require you to use SPKs - // specifically for their CPU. - // - // See https://github.com/SynoCommunity/spksrc/wiki/Synology-and-SynoCommunity-Package-Architectures - // for a complete list. Here, we override GOARCH for those older boxes that - // support at least DSM6. - // - // This is an artisanal hand-crafted list based on the wiki page. Some - // values may be wrong, since we don't have all those devices to actually - // test with. - switch hinfo.DeviceModel { - case "DS213air", "DS213", "DS413j", - "DS112", "DS112+", "DS212", "DS212+", "RS212", "RS812", "DS212j", "DS112j", - "DS111", "DS211", "DS211+", "DS411slim", "DS411", "RS411", "DS211j", "DS411j": - arch = "88f6281" - case "NVR1218", "NVR216", "VS960HD", "VS360HD": - arch = "hi3535" - case "DS1517", "DS1817", "DS416", "DS2015xs", "DS715", "DS1515", "DS215+": - arch = "alpine" - case "DS216se", "DS115j", "DS114", "DS214se", "DS414slim", "RS214", "DS14", "EDS14", "DS213j": - arch = "armada370" - case "DS115", "DS215j": - arch = "armada375" - case "DS419slim", "DS218j", "RS217", "DS116", "DS216j", "DS216", "DS416slim", "RS816", "DS416j": - arch = "armada38x" - case "RS815", "DS214", "DS214+", "DS414", "RS814": - arch = "armadaxp" - case "DS414j": - arch = "comcerto2k" - case "DS216play": - arch = "monaco" - } + }[goArch] + if arch == "" { - return "", fmt.Errorf("cannot determine CPU architecture for Synology model %q (Go arch %q), please report a bug at https://github.com/tailscale/tailscale/issues/new/choose", hinfo.DeviceModel, hinfo.GoArch) + // Here's the fun part, some older ARM boxes require you to use SPKs + // specifically for their CPU. See + // https://github.com/SynoCommunity/spksrc/wiki/Synology-and-SynoCommunity-Package-Architectures + // for a complete list. + // + // Some CPUs will map to neither this list nor the goArch map above, and we + // don't have SPKs for them. + cpu, err := parseSynoinfo(synoinfoPath) + if err != nil { + return "", fmt.Errorf("failed to get CPU architecture: %w", err) + } + switch cpu { + case "88f6281", "88f6282", "hi3535", "alpine", "armada370", + "armada375", "armada38x", "armadaxp", "comcerto2k", "monaco": + arch = cpu + default: + return "", fmt.Errorf("unsupported Synology CPU architecture %q (Go arch %q), please report a bug at https://github.com/tailscale/tailscale/issues/new/choose", cpu, goArch) + } } return arch, nil } +func parseSynoinfo(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + // Look for a line like: + // unique="synology_88f6282_413j" + // Extract the CPU in the middle (88f6282 in the above example). + s := bufio.NewScanner(f) + for s.Scan() { + l := s.Text() + if !strings.HasPrefix(l, "unique=") { + continue + } + parts := strings.SplitN(l, "_", 3) + if len(parts) != 3 { + return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, l) + } + return parts[1], nil + } + return "", fmt.Errorf(`missing "unique=" field in %q`, path) +} + func (up *updater) updateDebLike() error { ver, err := requestedTailscaleVersion(up.Version, up.track) if err != nil { diff --git a/clientupdate/clientupdate_test.go b/clientupdate/clientupdate_test.go index ec96ea79d2d65..83aa6a07e2e18 100644 --- a/clientupdate/clientupdate_test.go +++ b/clientupdate/clientupdate_test.go @@ -8,8 +8,6 @@ import ( "os" "path/filepath" "testing" - - "tailscale.com/tailcfg" ) func TestUpdateDebianAptSourcesListBytes(t *testing.T) { @@ -446,29 +444,151 @@ tailscale installed size: func TestSynoArch(t *testing.T) { tests := []struct { - goarch string - model string + goarch string + synoinfoUnique string + want string + wantErr bool + }{ + {goarch: "amd64", synoinfoUnique: "synology_x86_224", want: "x86_64"}, + {goarch: "arm64", synoinfoUnique: "synology_armv8_124", want: "armv8"}, + {goarch: "386", synoinfoUnique: "synology_i686_415play", want: "i686"}, + {goarch: "arm", synoinfoUnique: "synology_88f6281_213air", want: "88f6281"}, + {goarch: "arm", synoinfoUnique: "synology_88f6282_413j", want: "88f6282"}, + {goarch: "arm", synoinfoUnique: "synology_hi3535_NVR1218", want: "hi3535"}, + {goarch: "arm", synoinfoUnique: "synology_alpine_1517", want: "alpine"}, + {goarch: "arm", synoinfoUnique: "synology_armada370_216se", want: "armada370"}, + {goarch: "arm", synoinfoUnique: "synology_armada375_115", want: "armada375"}, + {goarch: "arm", synoinfoUnique: "synology_armada38x_419slim", want: "armada38x"}, + {goarch: "arm", synoinfoUnique: "synology_armadaxp_RS815", want: "armadaxp"}, + {goarch: "arm", synoinfoUnique: "synology_comcerto2k_414j", want: "comcerto2k"}, + {goarch: "arm", synoinfoUnique: "synology_monaco_216play", want: "monaco"}, + {goarch: "ppc64", synoinfoUnique: "synology_qoriq_413", wantErr: true}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s-%s", tt.goarch, tt.synoinfoUnique), func(t *testing.T) { + synoinfoConfPath := filepath.Join(t.TempDir(), "synoinfo.conf") + if err := os.WriteFile( + synoinfoConfPath, + []byte(fmt.Sprintf("unique=%q\n", tt.synoinfoUnique)), + 0600, + ); err != nil { + t.Fatal(err) + } + got, err := synoArch(tt.goarch, synoinfoConfPath) + if err != nil { + if !tt.wantErr { + t.Fatalf("got unexpected error %v", err) + } + return + } + if tt.wantErr { + t.Fatalf("got %q, expected an error", got) + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestParseSynoinfo(t *testing.T) { + tests := []struct { + desc string + content string want string wantErr bool }{ - {goarch: "amd64", model: "DS224+", want: "x86_64"}, - {goarch: "arm64", model: "DS124", want: "armv8"}, - {goarch: "386", model: "DS415play", want: "i686"}, - {goarch: "arm", model: "DS213air", want: "88f6281"}, - {goarch: "arm", model: "NVR1218", want: "hi3535"}, - {goarch: "arm", model: "DS1517", want: "alpine"}, - {goarch: "arm", model: "DS216se", want: "armada370"}, - {goarch: "arm", model: "DS115", want: "armada375"}, - {goarch: "arm", model: "DS419slim", want: "armada38x"}, - {goarch: "arm", model: "RS815", want: "armadaxp"}, - {goarch: "arm", model: "DS414j", want: "comcerto2k"}, - {goarch: "arm", model: "DS216play", want: "monaco"}, - {goarch: "riscv64", model: "DS999", wantErr: true}, - } + { + desc: "double-quoted", + content: ` +company_title="Synology" +unique="synology_88f6281_213air" +`, + want: "88f6281", + }, + { + desc: "single-quoted", + content: ` +company_title="Synology" +unique='synology_88f6281_213air' +`, + want: "88f6281", + }, + { + desc: "unquoted", + content: ` +company_title="Synology" +unique=synology_88f6281_213air +`, + want: "88f6281", + }, + { + desc: "missing unique", + content: ` +company_title="Synology" +`, + wantErr: true, + }, + { + desc: "empty unique", + content: ` +company_title="Synology" +unique= +`, + wantErr: true, + }, + { + desc: "empty unique double-quoted", + content: ` +company_title="Synology" +unique="" +`, + wantErr: true, + }, + { + desc: "empty unique single-quoted", + content: ` +company_title="Synology" +unique='' +`, + wantErr: true, + }, + { + desc: "malformed unique", + content: ` +company_title="Synology" +unique="synology_88f6281" +`, + wantErr: true, + }, + { + desc: "empty file", + content: ``, + wantErr: true, + }, + { + desc: "empty lines and comments", + content: ` + +# In a file named synoinfo? Shocking! +company_title="Synology" + +# unique= is_a_field_that_follows +unique="synology_88f6281_213air" + +`, + want: "88f6281", + }, + } for _, tt := range tests { - t.Run(fmt.Sprintf("%s-%s", tt.goarch, tt.model), func(t *testing.T) { - got, err := synoArch(&tailcfg.Hostinfo{GoArch: tt.goarch, DeviceModel: tt.model}) + t.Run(tt.desc, func(t *testing.T) { + synoinfoConfPath := filepath.Join(t.TempDir(), "synoinfo.conf") + if err := os.WriteFile(synoinfoConfPath, []byte(tt.content), 0600); err != nil { + t.Fatal(err) + } + got, err := parseSynoinfo(synoinfoConfPath) if err != nil { if !tt.wantErr { t.Fatalf("got unexpected error %v", err) diff --git a/net/portmapper/upnp.go b/net/portmapper/upnp.go index 6a525c54ff297..34cae584067e1 100644 --- a/net/portmapper/upnp.go +++ b/net/portmapper/upnp.go @@ -106,11 +106,13 @@ type upnpClient interface { // It is not used for anything other than labelling. const tsPortMappingDesc = "tailscale-portmap" -// addAnyPortMapping abstracts over different UPnP client connections, calling the available -// AddAnyPortMapping call if available for WAN IP connection v2, otherwise defaulting to the old -// behavior of calling AddPortMapping with port = 0 to specify a wildcard port. -// It returns the new external port (which may not be identical to the external port specified), -// or an error. +// addAnyPortMapping abstracts over different UPnP client connections, calling +// the available AddAnyPortMapping call if available for WAN IP connection v2, +// otherwise picking either the previous port (if one is present) or a random +// port and trying to obtain a mapping using AddPortMapping. +// +// It returns the new external port (which may not be identical to the external +// port specified), or an error. // // TODO(bradfitz): also returned the actual lease duration obtained. and check it regularly. func addAnyPortMapping( @@ -121,6 +123,31 @@ func addAnyPortMapping( internalClient string, leaseDuration time.Duration, ) (newPort uint16, err error) { + // Some devices don't let clients add a port mapping for privileged + // ports (ports below 1024). Additionally, per section 2.3.18 of the + // UPnP spec, regarding the ExternalPort field: + // + // If this value is specified as a wildcard (i.e. 0), connection + // request on all external ports (that are not otherwise mapped) + // will be forwarded to InternalClient. In the wildcard case, the + // value(s) of InternalPort on InternalClient are ignored by the IGD + // for those connections that are forwarded to InternalClient. + // Obviously only one such entry can exist in the NAT at any time + // and conflicts are handled with a “first write wins” behavior. + // + // We obviously do not want to open all ports on the user's device to + // the internet, so we want to do this prior to calling either + // AddAnyPortMapping or AddPortMapping. + // + // Pick an external port that's greater than 1024 by getting a random + // number in [0, 65535 - 1024] and then adding 1024 to it, shifting the + // range to [1024, 65535]. + if externalPort < 1024 { + externalPort = uint16(rand.Intn(65535-1024) + 1024) + } + + // First off, try using AddAnyPortMapping; if there's a conflict, the + // router will pick another port and return it. if upnp, ok := upnp.(*internetgateway2.WANIPConnection2); ok { return upnp.AddAnyPortMapping( ctx, @@ -135,15 +162,8 @@ func addAnyPortMapping( ) } - // Some devices don't let clients add a port mapping for privileged - // ports (ports below 1024). - // - // Pick an external port that's greater than 1024 by getting a random - // number in [0, 65535 - 1024] and then adding 1024 to it, shifting the - // range to [1024, 65535]. - if externalPort < 1024 { - externalPort = uint16(rand.Intn(65535-1024) + 1024) - } + // Fall back to using AddPortMapping, which requests a mapping to/from + // a specific external port. err = upnp.AddPortMapping( ctx, "", diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 7633cc17ac202..9a53e8ee2d660 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -734,9 +734,12 @@ type NetInfo struct { // the control plane. DERPLatency map[string]float64 `json:",omitempty"` - // FirewallMode is the current firewall utility in use by router (iptables, nftables). - // FirewallMode ipt means iptables, nft means nftables. When it's empty user is not using - // our netfilter runners to manage firewall rules. + // FirewallMode encodes both which firewall mode was selected and why. + // It is Linux-specific (at least as of 2023-08-19) and is meant to help + // debug iptables-vs-nftables issues. The string is of the form + // "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means + // either not Linux or a configuration in which the host firewall rules + // are not managed by tailscaled. FirewallMode string `json:",omitempty"` // Update BasicallyEqual when adding fields. diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 4d46ea104d82b..a4d65857a5bd8 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -13,6 +13,7 @@ import ( "net" "net/netip" "reflect" + "strings" "github.com/google/nftables" "github.com/google/nftables/expr" @@ -26,12 +27,16 @@ const ( chainNamePostrouting = "ts-postrouting" ) +// chainTypeRegular is an nftables chain that does not apply to a hook. +const chainTypeRegular = "" + type chainInfo struct { table *nftables.Table name string chainType nftables.ChainType chainHook *nftables.ChainHook chainPriority *nftables.ChainPriority + chainPolicy *nftables.ChainPolicy } type nftable struct { @@ -40,6 +45,21 @@ type nftable struct { Nat *nftables.Table } +// nftablesRunner implements a netfilterRunner using the netlink based nftables +// library. As nftables allows for arbitrary tables and chains, there is a need +// to follow conventions in order to integrate well with a surrounding +// ecosystem. The rules installed by nftablesRunner have the following +// properties: +// - Install rules that intend to take precedence over rules installed by +// other software. Tailscale provides packet filtering for tailnet traffic +// inside the daemon based on the tailnet ACL rules. +// - As nftables "accept" is not final, rules from high priority tables (low +// numbers) will fall through to lower priority tables (high numbers). In +// order to effectively be 'final', we install "jump" rules into conventional +// tables and chains that will reach an accept verdict inside those tables. +// - The table and chain conventions followed here are those used by +// `iptables-nft` and `ufw`, so that those tools co-exist and do not +// negatively affect Tailscale function. type nftablesRunner struct { conn *nftables.Conn nft4 *nftable @@ -116,6 +136,11 @@ func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Ch return ret, nil } +// isTSChain retruns true if the chain name starts with ts +func isTSChain(name string) bool { + return strings.HasPrefix(name, "ts-") +} + // createChainIfNotExist creates a chain with the given name in the given table // if it does not exist. func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { @@ -123,8 +148,11 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { return fmt.Errorf("get chain: %w", err) } else if err == nil { - // Chain already exists - if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority { + // The chain already exists. If it is a TS chain, check the + // type/hook/priority, but for "conventional chains" assume they're what + // we expect (in case iptables-nft/ufw make minor behavior changes in + // the future). + if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) { return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) } return nil @@ -136,6 +164,7 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { Type: cinfo.chainType, Hooknum: cinfo.chainHook, Priority: cinfo.chainPriority, + Policy: cinfo.chainPolicy, }) if err := c.Flush(); err != nil { @@ -228,6 +257,10 @@ ruleLoop: } for i, e := range r.Exprs { + // Skip counter expressions, as they will not match. + if _, ok := e.(*expr.Counter); ok { + continue + } if !reflect.DeepEqual(e, rule.Exprs[i]) { continue ruleLoop } @@ -388,27 +421,49 @@ func (n *nftablesRunner) getNATTables() []*nftable { // AddChains creates custom Tailscale chains in netfilter via nftables // if the ts-chain doesn't already exist. func (n *nftablesRunner) AddChains() error { + polAccept := nftables.ChainPolicyAccept for _, table := range n.getTables() { - filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter") + // Create the filter table if it doesn't exist, this table name is the same + // as the name used by iptables-nft and ufw. We install rules into the + // same conventional table so that `accept` verdicts from our jump + // chains are conclusive. + filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") if err != nil { return fmt.Errorf("create table: %w", err) } table.Filter = filter - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil { + // Adding the "conventional chains" that are used by iptables-nft and ufw. + if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { + return fmt.Errorf("create forward chain: %w", err) + } + if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { + return fmt.Errorf("create input chain: %w", err) + } + // Adding the tailscale chains that contain our rules. + if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create forward chain: %w", err) } - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil { + if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create input chain: %w", err) } } for _, table := range n.getNATTables() { - nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat") + // Create the nat table if it doesn't exist, this table name is the same + // as the name used by iptables-nft and ufw. We install rules into the + // same conventional table so that `accept` verdicts from our jump + // chains are conclusive. + nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return fmt.Errorf("create table: %w", err) } table.Nat = nat - if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil { + // Adding the "conventional chains" that are used by iptables-nft and ufw. + if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { + return fmt.Errorf("create postrouting chain: %w", err) + } + // Adding the tailscale chain that contains our rules. + if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } } @@ -445,19 +500,16 @@ func (n *nftablesRunner) DelChains() error { if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(table.Filter) } if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(n.nft4.Nat) if n.v6NATAvailable { if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(n.nft6.Nat) } if err := n.conn.Flush(); err != nil { @@ -467,15 +519,128 @@ func (n *nftablesRunner) DelChains() error { return nil } -// AddHooks is defined to satisfy the interface. NfTables does not require -// AddHooks, since we don't have any default tables or chains in nftables. +// createHookRule creates a rule to jump from a hooked chain to a regular chain. +func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule { + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: toChainName, + }, + } + + rule := &nftables.Rule{ + Table: table, + Chain: fromChain, + Exprs: exprs, + } + + return rule +} + +// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain. +func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + _ = conn.InsertRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add rule: %w", err) + } + + return nil +} + +// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" +// in tables and jump from those chains to tailscale chains. func (n *nftablesRunner) AddHooks() error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = addHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } return nil } -// DelHooks is defined to satisfy the interface. NfTables does not require -// DelHooks, since we don't have any default tables or chains in nftables. +// delHookRule deletes a rule that jumps from a hooked chain to a regular chain. +func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + existingRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("Failed to find hook rule: %w", err) + } + + if existingRule == nil { + return nil + } + + _ = conn.DelRule(existingRule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del hook rule: %w", err) + } + return nil +} + +// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. func (n *nftablesRunner) DelHooks(logf logger.Logf) error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = delHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + return nil } @@ -953,25 +1118,62 @@ func (n *nftablesRunner) DelSNATRule() error { return nil } +// cleanupChain removes a jump rule from hookChainName to tsChainName, and then +// the entire chain tsChainName. Errors are logged, but attempts to remove both +// the jump rule and chain continue even if one errors. +func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { + // remove the jump first, before removing the jump destination. + defaultChain, err := getChainFromTable(conn, table, hookChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + logf("cleanup: did not find default chain: %s", err) + } + if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + // delete hook in convention chain + _ = delHookRule(conn, table, defaultChain, tsChainName) + } + + tsChain, err := getChainFromTable(conn, table, tsChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { + logf("cleanup: did not find ts-chain: %s", err) + } + + if tsChain != nil { + // flush and delete ts-chain + conn.FlushChain(tsChain) + conn.DelChain(tsChain) + err = conn.Flush() + logf("cleanup: delete and flush chain %s: %s", tsChainName, err) + } +} + // NfTablesCleanUp removes all Tailscale added nftables rules. // Any errors that occur are logged to the provided logf. func NfTablesCleanUp(logf logger.Logf) { conn, err := nftables.New() if err != nil { - logf("ERROR: nftables connection: %w", err) + logf("cleanup: nftables connection: %s", err) } tables, err := conn.ListTables() // both v4 and v6 if err != nil { - logf("ERROR: list tables: %w", err) + logf("cleanup: list tables: %s", err) } for _, table := range tables { + // These table names were used briefly in 1.48.0. if table.Name == "ts-filter" || table.Name == "ts-nat" { conn.DelTable(table) if err := conn.Flush(); err != nil { - logf("ERROR: flush table %s: %w", table.Name, err) + logf("cleanup: flush delete table %s: %s", table.Name, err) } } + + if table.Name == "filter" { + cleanupChain(logf, conn, table, "INPUT", chainNameInput) + cleanupChain(logf, conn, table, "FORWARD", chainNameForward) + } + if table.Name == "nat" { + cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting) + } } } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index ab4543b2dca8d..ad068957ee9a3 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -101,6 +101,48 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { return conn } +func TestInsertHookRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0 \; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add chain ip ts-filter-test ts-jumpto + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x0e\x00\x03\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"), + // nft add rule ip ts-filter-test ts-input-test counter jump ts-jumptp + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x70\x00\x04\x80\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x40\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x02\x80\x1c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfd\x0e\x00\x02\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + + fromchain := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + tochain := testConn.AddChain(&nftables.Chain{ + Name: "ts-jumpto", + Table: table, + }) + + err := addHookRule(testConn, table, fromchain, tochain.Name) + if err != nil { + t.Fatal(err) + } + +} + func TestInsertLoopbackRule(t *testing.T) { proto := nftables.TableFamilyIPv4 want := [][]byte{ @@ -461,8 +503,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) { t.Fatalf("list chains failed: %v", err) } - if len(chainsV4) != 3 { - t.Fatalf("len(chainsV4) = %d, want 3", len(chainsV4)) + if len(chainsV4) != 6 { + t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4)) } chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6) @@ -470,8 +512,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) { t.Fatalf("list chains failed: %v", err) } - if len(chainsV6) != 3 { - t.Fatalf("len(chainsV6) = %d, want 3", len(chainsV6)) + if len(chainsV6) != 6 { + t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6)) } runner.DelChains() @@ -788,3 +830,87 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) } } + +func TestNFTAddAndDelHookRule(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return + } + + conn := newSysConn(t) + runner := newFakeNftablesRunner(t, conn) + runner.AddChains() + defer runner.DelChains() + runner.AddHooks() + + forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD") + if err != nil { + t.Fatalf("failed to get forwardChain: %v", err) + } + + forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(forwardChainRules) != 1 { + t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules)) + } + + inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT") + if err != nil { + t.Fatalf("failed to get inputChain: %v", err) + } + + inputChainRules, err := conn.GetRules(inputChain.Table, inputChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(inputChainRules) != 1 { + t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules)) + } + + postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING") + if err != nil { + t.Fatalf("failed to get postroutingChain: %v", err) + } + + postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(postroutingChainRules) != 1 { + t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) + } + + runner.DelHooks(t.Logf) + + forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(forwardChainRules) != 0 { + t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules)) + } + + inputChainRules, err = conn.GetRules(inputChain.Table, inputChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(inputChainRules) != 0 { + t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules)) + } + + postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(postroutingChainRules) != 0 { + t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) + } +} diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 710e9cfe01893..8a7273bd225c4 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -85,41 +85,32 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod iptAva, nftAva := true, true iptRuleCount, err := det.iptDetect() if err != nil { - logf("router: detect iptables rule: %v", err) + logf("detect iptables rule: %v", err) iptAva = false } nftRuleCount, err := det.nftDetect() if err != nil { - logf("router: detect nftables rule: %v", err) + logf("detect nftables rule: %v", err) nftAva = false } - logf("router: nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) + logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) switch { - case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables": - // TODO(KevinLiang10): Updates to a flag - logf("router: envknob TS_DEBUG_FIREWALL_MODE=nftables set") - hostinfo.SetFirewallMode("nft-forced") - return linuxfw.FirewallModeNfTables - case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": - logf("router: envknob TS_DEBUG_FIREWALL_MODE=iptables set") - hostinfo.SetFirewallMode("ipt-forced") - return linuxfw.FirewallModeIPTables case nftRuleCount > 0 && iptRuleCount == 0: - logf("router: nftables is currently in use") + logf("nftables is currently in use") hostinfo.SetFirewallMode("nft-inuse") return linuxfw.FirewallModeNfTables case iptRuleCount > 0 && nftRuleCount == 0: - logf("router: iptables is currently in use") + logf("iptables is currently in use") hostinfo.SetFirewallMode("ipt-inuse") return linuxfw.FirewallModeIPTables case nftAva: // if both iptables and nftables are available but // neither/both are currently used, use nftables. - logf("router: nftables is available") + logf("nftables is available") hostinfo.SetFirewallMode("nft") return linuxfw.FirewallModeNfTables case iptAva: - logf("router: iptables is available") + logf("iptables is available") hostinfo.SetFirewallMode("ipt") return linuxfw.FirewallModeIPTables default: @@ -136,18 +127,44 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod // As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { tableDetector := &linuxFWDetector{} - mode := chooseFireWallMode(logf, tableDetector) + var mode linuxfw.FirewallMode + + // We now use iptables as default and have "auto" and "nftables" as + // options for people to test further. + switch { + case distro.Get() == distro.Gokrazy: + // Reduce startup logging on gokrazy. There's no way to do iptables on + // gokrazy anyway. + logf("GoKrazy should use nftables.") + hostinfo.SetFirewallMode("nft-gokrazy") + mode = linuxfw.FirewallModeNfTables + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables": + logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set") + hostinfo.SetFirewallMode("nft-forced") + mode = linuxfw.FirewallModeNfTables + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto": + mode = chooseFireWallMode(logf, tableDetector) + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": + logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set") + hostinfo.SetFirewallMode("ipt-forced") + mode = linuxfw.FirewallModeIPTables + default: + logf("default choosing iptables") + hostinfo.SetFirewallMode("ipt-default") + mode = linuxfw.FirewallModeIPTables + } + var nfr netfilterRunner var err error switch mode { case linuxfw.FirewallModeIPTables: - logf("router: using iptables") + logf("using iptables") nfr, err = linuxfw.NewIPTablesRunner(logf) if err != nil { return nil, err } case linuxfw.FirewallModeNfTables: - logf("router: using nftables") + logf("using nftables") nfr, err = linuxfw.NewNfTablesRunner(logf) if err != nil { return nil, err