Skip to content

Commit

Permalink
Improve nftables rules
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 11, 2024
1 parent e5f9651 commit 85fe25a
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 151 deletions.
41 changes: 27 additions & 14 deletions redirect_iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,28 @@ import (
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"

"golang.org/x/sys/unix"
)

func (r *autoRedirect) iptablesPathForFamily(family int) string {
if family == unix.AF_INET {
return r.iptablesPath
} else {
return r.ip6tablesPath
func (r *autoRedirect) setupIPTables() error {
if r.enableIPv4 {
err := r.setupIPTablesForFamily(r.iptablesPath)
if err != nil {
return err
}
}
if r.enableIPv6 {
err := r.setupIPTablesForFamily(r.ip6tablesPath)
if err != nil {
return err
}
}
return nil
}

func (r *autoRedirect) setupIPTables(family int) error {
func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error {
tableNameOutput := r.tableName + "-output"
tableNameForward := r.tableName + "-forward"
tableNamePreRouteing := r.tableName + "-prerouting"
iptablesPath := r.iptablesPathForFamily(family)
redirectPort := r.redirectPort()
// OUTPUT
err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput)
Expand Down Expand Up @@ -74,7 +79,7 @@ func (r *autoRedirect) setupIPTables(family int) error {
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
)
if family == unix.AF_INET {
if iptablesPath == r.iptablesPath {
routeAddress = r.tunOptions.Inet4RouteAddress
routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress
} else {
Expand Down Expand Up @@ -112,10 +117,10 @@ func (r *autoRedirect) setupIPTables(family int) error {
}
if !r.tunOptions.EXP_DisableDNSHijack {
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4() == (family == unix.AF_INET)
return it.Is4() == (iptablesPath == r.iptablesPath)
})
if !dnsServer.IsValid() {
if family == unix.AF_INET {
if iptablesPath == r.iptablesPath {
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
Expand Down Expand Up @@ -199,11 +204,19 @@ func (r *autoRedirect) setupIPTables(family int) error {
return nil
}

func (r *autoRedirect) cleanupIPTables(family int) {
func (r *autoRedirect) cleanupIPTables() {
if r.enableIPv4 {
r.cleanupIPTablesForFamily(r.iptablesPath)
}
if r.enableIPv6 {
r.cleanupIPTablesForFamily(r.ip6tablesPath)
}
}

func (r *autoRedirect) cleanupIPTablesForFamily(iptablesPath string) {
tableNameOutput := r.tableName + "-output"
tableNameForward := r.tableName + "-forward"
tableNamePreRouteing := r.tableName + "-prerouting"
iptablesPath := r.iptablesPathForFamily(family)
_ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput)
_ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput)
_ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput)
Expand Down
53 changes: 11 additions & 42 deletions redirect_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"

"golang.org/x/sys/unix"
)

type autoRedirect struct {
Expand Down Expand Up @@ -118,11 +116,19 @@ func (r *autoRedirect) Start() error {
}
r.redirectServer = server
}
return r.setupTables()
if r.useNFTables {
return r.setupNFTables()
} else {
return r.setupIPTables()
}
}

func (r *autoRedirect) Close() error {
r.cleanupTables()
if r.useNFTables {
r.cleanupNFTables()
} else {
r.cleanupIPTables()
}
return common.Close(
common.PtrOrNil(r.redirectServer),
)
Expand All @@ -134,7 +140,7 @@ func (r *autoRedirect) initializeNFTables() error {
return err
}
defer nft.CloseLasting()
_, err = nft.ListTablesOfFamily(unix.AF_INET)
_, err = nft.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
Expand All @@ -148,40 +154,3 @@ func (r *autoRedirect) redirectPort() uint16 {
}
return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port()
}

func (r *autoRedirect) setupTables() error {
var setupTables func(int) error
if r.useNFTables {
setupTables = r.setupNFTables
} else {
setupTables = r.setupIPTables
}
if r.enableIPv4 {
err := setupTables(unix.AF_INET)
if err != nil {
return err
}
}
if r.enableIPv6 {
err := setupTables(unix.AF_INET6)
if err != nil {
return err
}
}
return nil
}

func (r *autoRedirect) cleanupTables() {
var cleanupTables func(int)
if r.useNFTables {
cleanupTables = r.cleanupNFTables
} else {
cleanupTables = r.cleanupIPTables
}
if r.enableIPv4 {
cleanupTables(unix.AF_INET)
}
if r.enableIPv6 {
cleanupTables(unix.AF_INET6)
}
}
139 changes: 73 additions & 66 deletions redirect_nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,24 @@ import (
"github.com/sagernet/nftables/binaryutil"
"github.com/sagernet/nftables/expr"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"

"golang.org/x/sys/unix"
)

const (
nftablesChainOutput = "output"
nftablesChainForward = "forward"
nftablesChainPreRouting = "prerouting"
)

func nftablesFamily(family int) nftables.TableFamily {
switch family {
case unix.AF_INET:
return nftables.TableFamilyIPv4
case unix.AF_INET6:
return nftables.TableFamilyIPv6
default:
panic(F.ToString("unknown family ", family))
}
}

func (r *autoRedirect) setupNFTables(family int) error {
func (r *autoRedirect) setupNFTables() error {
nft, err := nftables.New()
if err != nil {
return err
}
defer nft.CloseLasting()

redirectPort := r.redirectPort()

table := nft.AddTable(&nftables.Table{
Name: r.tableName,
Family: nftablesFamily(family),
})

chainOutput := nft.AddChain(&nftables.Chain{
Name: nftablesChainOutput,
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeNAT,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainOutput,
Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, nftablesRuleRedirectToPorts(redirectPort)...),
Family: nftables.TableFamilyINet,
})

chainForward := nft.AddChain(&nftables.Chain{
Name: nftablesChainForward,
Name: "forward",
Table: table,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Expand All @@ -79,8 +46,22 @@ func (r *autoRedirect) setupNFTables(family int) error {
}),
})

redirectPort := r.redirectPort()
chainOutput := nft.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeNAT,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainOutput,
Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, nftablesRuleRedirectToPorts(redirectPort)...),
})

chainPreRouting := nft.AddChain(&nftables.Chain{
Name: nftablesChainPreRouting,
Name: "prerouting",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Expand All @@ -97,12 +78,13 @@ func (r *autoRedirect) setupNFTables(family int) error {
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
)
if table.Family == nftables.TableFamilyIPv4 {
routeAddress = r.tunOptions.Inet4RouteAddress
routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress
} else {
routeAddress = r.tunOptions.Inet6RouteAddress
routeExcludeAddress = r.tunOptions.Inet6RouteExcludeAddress
if r.enableIPv4 {
routeAddress = append(routeAddress, r.tunOptions.Inet4RouteAddress...)
routeExcludeAddress = append(routeExcludeAddress, r.tunOptions.Inet4RouteExcludeAddress...)
}
if r.enableIPv6 {
routeAddress = append(routeAddress, r.tunOptions.Inet6RouteAddress...)
routeExcludeAddress = append(routeExcludeAddress, r.tunOptions.Inet6RouteExcludeAddress...)
}
for _, address := range routeExcludeAddress {
nft.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -140,37 +122,66 @@ func (r *autoRedirect) setupNFTables(family int) error {
}

if !r.tunOptions.EXP_DisableDNSHijack {
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4() == (family == unix.AF_INET)
dnsServer4 := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4()
})
if !dnsServer.IsValid() {
if family == unix.AF_INET {
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
}
dnsServer6 := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is6()
})
if r.enableIPv4 && !dnsServer4.IsValid() {
dnsServer4 = r.tunOptions.Inet4Address[0].Addr().Next()
}
if r.enableIPv6 && !dnsServer6.IsValid() {
dnsServer6 = r.tunOptions.Inet6Address[0].Addr().Next()
}
if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
for _, name := range r.tunOptions.IncludeInterface {
if r.enableIPv4 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...)...),
})
}
if r.enableIPv6 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...)...),
})
}
}
for _, uidRange := range r.tunOptions.IncludeUID {
if r.enableIPv4 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...)...),
})
}
if r.enableIPv6 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...)...),
})
}
}
} else {
if r.enableIPv4 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...),
Exprs: append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...),
})
}
for _, uidRange := range r.tunOptions.IncludeUID {
if r.enableIPv6 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...),
Exprs: append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...),
})
}
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...),
})
}
}

Expand Down Expand Up @@ -219,18 +230,14 @@ func (r *autoRedirect) setupNFTables(family int) error {
return nft.Flush()
}

func (r *autoRedirect) cleanupNFTables(family int) {
func (r *autoRedirect) cleanupNFTables() {
conn, err := nftables.New()
if err != nil {
return
}
conn.FlushTable(&nftables.Table{
Name: r.tableName,
Family: nftablesFamily(family),
})
conn.DelTable(&nftables.Table{
Name: r.tableName,
Family: nftablesFamily(family),
Family: nftables.TableFamilyINet,
})
_ = conn.Flush()
_ = conn.CloseLasting()
Expand Down
Loading

0 comments on commit 85fe25a

Please sign in to comment.