Skip to content

Commit

Permalink
add local dst rule
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishek9686 committed Dec 7, 2024
1 parent f0e09b5 commit cc747d3
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 23 deletions.
5 changes: 1 addition & 4 deletions firewall/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,11 @@ func ProcessAclRules(server string, fwUpdate *models.FwUpdate) {
}
aclRules := fwUpdate.AclRules
ruleTable := fwCrtl.FetchRuleTable(server, aclTable)
fmt.Printf("======> ACL RULES: %+v\n, Curr Rule table: %+v\n", fwUpdate.AclRules, ruleTable)
fmt.Printf("======> ACL RULES: %+v \n", fwUpdate.AclRules)
if len(ruleTable) == 0 && len(aclRules) > 0 {
fwCrtl.AddAclRules(server, aclRules)
ruleTable := fwCrtl.FetchRuleTable(server, aclTable)
fmt.Printf("======> AFTER ACL RULES: Curr Rule table: %+v\n", ruleTable)
return
}
fmt.Println("## CHECKING New RULES==>")
// add new acl rules
for _, aclRule := range aclRules {
if _, ok := ruleTable[aclRule.ID]; !ok {
Expand Down
191 changes: 172 additions & 19 deletions firewall/nftables_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,98 @@ var (
table: defaultIpTable,
chain: iptableINChain,
},
{
nfRule: &nftables.Rule{
Table: filterTable,
Chain: &nftables.Chain{Name: iptableFWDChain},
Exprs: []expr.Any{
// Match on input interface (-i netmaker)
&expr.Meta{
Key: expr.MetaKeyIIFNAME, // Input interface name
Register: 1, // Store in register 1
},
&expr.Cmp{
Op: expr.CmpOpEq, // Equals operation
Register: 1, // Compare register 1
Data: []byte(ncutils.GetInterfaceName() + "\x00"), // Interface name "netmaker" (null-terminated string)
},
// Match on conntrack state (-m conntrack --ctstate RELATED,ESTABLISHED)
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1, // Use register 1 from Ct expression
DestRegister: 1, // Output to same register
Len: 4, // State length
Mask: []byte{0x06, 0x00, 0x00, 0x00}, // Mask for RELATED (2) and ESTABLISHED (4)
Xor: []byte{0x00, 0x00, 0x00, 0x00}, // No XOR
},
&expr.Cmp{
Op: expr.CmpOpNeq, // Check if the bitwise result is not zero
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(genRuleKey("-i", ncutils.GetInterfaceName(), "-m", "conntrack",
"--ctstate", "ESTABLISHED,RELATED", "-m", "comment",
"--comment", netmakerSignature, "-j", "ACCEPT")), // Add comment
},
rule: []string{"-i", ncutils.GetInterfaceName(), "-m", "conntrack",
"--ctstate", "ESTABLISHED,RELATED", "-m", "comment",
"--comment", netmakerSignature, "-j", "ACCEPT"},
table: defaultIpTable,
chain: iptableFWDChain,
},
{
nfRule: &nftables.Rule{
Table: filterTable,
Chain: &nftables.Chain{Name: iptableFWDChain},
Exprs: []expr.Any{
// Match on input interface (-i netmaker)
&expr.Meta{
Key: expr.MetaKeyOIFNAME, // Input interface name
Register: 1, // Store in register 1
},
&expr.Cmp{
Op: expr.CmpOpEq, // Equals operation
Register: 1, // Compare register 1
Data: []byte(ncutils.GetInterfaceName() + "\x00"), // Interface name "netmaker" (null-terminated string)
},
// Match on conntrack state (-m conntrack --ctstate RELATED,ESTABLISHED)
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1, // Use register 1 from Ct expression
DestRegister: 1, // Output to same register
Len: 4, // State length
Mask: []byte{0x06, 0x00, 0x00, 0x00}, // Mask for RELATED (2) and ESTABLISHED (4)
Xor: []byte{0x00, 0x00, 0x00, 0x00}, // No XOR
},
&expr.Cmp{
Op: expr.CmpOpNeq, // Check if the bitwise result is not zero
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(genRuleKey("-o", ncutils.GetInterfaceName(), "-m", "conntrack",
"--ctstate", "ESTABLISHED,RELATED", "-m", "comment",
"--comment", netmakerSignature, "-j", "ACCEPT")), // Add comment
},
rule: []string{"-o", ncutils.GetInterfaceName(), "-m", "conntrack",
"--ctstate", "ESTABLISHED,RELATED", "-m", "comment",
"--comment", netmakerSignature, "-j", "ACCEPT"},
table: defaultIpTable,
chain: iptableFWDChain,
},
{

nfRule: &nftables.Rule{
Expand Down Expand Up @@ -989,28 +1081,32 @@ func (n *nftablesManager) InsertIngressRoutingRules(server string, ingressInfo m
func (n *nftablesManager) GetSrcIpsExpr(ips []net.IPNet, isIpv4 bool) []expr.Any {
var e []expr.Any
if isIpv4 {
e = append(e, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12, // Source IP offset in IPv4 header
Len: 4, // IPv4 address length
})

for _, ip := range ips {
// Match first source IP
e = append(e,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12, // Source IP offset in IPv4 header
Len: 4, // IPv4 address length
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: ip.Mask, // Match for 100.64.0.0/24
Xor: net.IPv4zero.To4(),
},
/*
// Match source IPs
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Op: expr.CmpOpEq,
Data: ip.IP.To4(),
Data: []byte{
100, 64, 0, 1, // 100.64.0.1
100, 64, 255, 254, // 100.64.255.254
},
},
)
*/
e = append(e, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip.IP.To4(),
})
}

} else {
Expand Down Expand Up @@ -1068,6 +1164,7 @@ func (n *nftablesManager) AddAclRules(server string, aclRules map[string]models.
ruleSpec = append(ruleSpec, "--dport",
strings.Join(aclRule.AllowedPorts, ","))
}
ruleSpec = append(ruleSpec, "-m", "addrtype", "--dst-type", "LOCAL")
ruleSpec = append(ruleSpec, "-j", "ACCEPT")
ruleSpec = appendNetmakerCommentToRule(ruleSpec)
n.deleteRule(defaultIpTable, aclInputRulesChain, genRuleKey(ruleSpec...))
Expand All @@ -1079,6 +1176,9 @@ func (n *nftablesManager) AddAclRules(server string, aclRules map[string]models.
if len(aclRule.AllowedPorts) > 0 {
e = append(e, n.getExprForPort(aclRule.AllowedPorts)...)
}
// Match destination type LOCAL
e = append(e, n.getLocalExpr()...)

e = append(e, // Accept the packet
&expr.Verdict{
Kind: expr.VerdictAccept, // ACCEPT verdict
Expand Down Expand Up @@ -1118,6 +1218,7 @@ func (n *nftablesManager) AddAclRules(server string, aclRules map[string]models.
ruleSpec = append(ruleSpec, "--dport",
strings.Join(aclRule.AllowedPorts, ","))
}
ruleSpec = append(ruleSpec, "-m", "addrtype", "--dst-type", "LOCAL")
ruleSpec = append(ruleSpec, "-j", "ACCEPT")
ruleSpec = appendNetmakerCommentToRule(ruleSpec)
n.deleteRule(defaultIpTable, aclInputRulesChain, genRuleKey(ruleSpec...))
Expand All @@ -1129,6 +1230,9 @@ func (n *nftablesManager) AddAclRules(server string, aclRules map[string]models.
if len(aclRule.AllowedPorts) > 0 {
e = append(e, n.getExprForPort(aclRule.AllowedPorts)...)
}
// Match destination type LOCAL
e = append(e, n.getLocalExpr()...)

e = append(e, // Accept the packet
&expr.Verdict{
Kind: expr.VerdictAccept, // ACCEPT verdict
Expand All @@ -1154,7 +1258,6 @@ func (n *nftablesManager) AddAclRules(server string, aclRules map[string]models.
}

if len(rules) > 0 {
fmt.Printf("====> IN ADDACLRULES: %+v\n", rules)
rCfg := rulesCfg{
rulesMap: map[string][]ruleInfo{
aclRule.ID: rules,
Expand Down Expand Up @@ -1194,6 +1297,7 @@ func (n *nftablesManager) UpsertAclRule(server string, aclRule models.AclRule) {
ruleSpec = append(ruleSpec, "--dport",
strings.Join(aclRule.AllowedPorts, ","))
}
ruleSpec = append(ruleSpec, "-m", "addrtype", "--dst-type", "LOCAL")
ruleSpec = append(ruleSpec, "-j", "ACCEPT")
ruleSpec = appendNetmakerCommentToRule(ruleSpec)
n.deleteRule(defaultIpTable, aclInputRulesChain, genRuleKey(ruleSpec...))
Expand All @@ -1205,6 +1309,9 @@ func (n *nftablesManager) UpsertAclRule(server string, aclRule models.AclRule) {
if len(aclRule.AllowedPorts) > 0 {
e = append(e, n.getExprForPort(aclRule.AllowedPorts)...)
}
// Match destination type LOCAL
e = append(e, n.getLocalExpr()...)

e = append(e, // Accept the packet
&expr.Verdict{
Kind: expr.VerdictAccept, // ACCEPT verdict
Expand Down Expand Up @@ -1244,6 +1351,7 @@ func (n *nftablesManager) UpsertAclRule(server string, aclRule models.AclRule) {
ruleSpec = append(ruleSpec, "--dport",
strings.Join(aclRule.AllowedPorts, ","))
}
ruleSpec = append(ruleSpec, "-m", "addrtype", "--dst-type", "LOCAL")
ruleSpec = append(ruleSpec, "-j", "ACCEPT")
ruleSpec = appendNetmakerCommentToRule(ruleSpec)
n.deleteRule(defaultIpTable, aclInputRulesChain, genRuleKey(ruleSpec...))
Expand All @@ -1255,6 +1363,9 @@ func (n *nftablesManager) UpsertAclRule(server string, aclRule models.AclRule) {
if len(aclRule.AllowedPorts) > 0 {
e = append(e, n.getExprForPort(aclRule.AllowedPorts)...)
}
// Match destination type LOCAL
e = append(e, n.getLocalExpr()...)

e = append(e, // Accept the packet
&expr.Verdict{
Kind: expr.VerdictAccept, // ACCEPT verdict
Expand All @@ -1280,7 +1391,6 @@ func (n *nftablesManager) UpsertAclRule(server string, aclRule models.AclRule) {
}

if len(rules) > 0 {
fmt.Printf("====> IN ADDACLRULES: %+v\n", rules)
rCfg := rulesCfg{
rulesMap: map[string][]ruleInfo{
aclRule.ID: rules,
Expand Down Expand Up @@ -1424,3 +1534,46 @@ func rulesEqual(rule1, rule2 *nftables.Rule) bool {

return false
}

// AddLocalRule is a wrapper to match packets with LOCAL destination type (IPv4 and IPv6).
func (n *nftablesManager) getLocalExpr() (e []expr.Any) {
return
localIPs, err := GetLocalIPs()
if err != nil {
return
}

for _, localIP := range localIPs {
var base expr.PayloadBase
var offsetDst, lenIP uint32
if localIP.To4() != nil {
// IPv4-specific parameters
base = expr.PayloadBaseNetworkHeader
offsetDst = 16 // Destination IP in IPv4 header
lenIP = 4 // IPv4 address length
} else {
// IPv6-specific parameters
base = expr.PayloadBaseNetworkHeader
offsetDst = 24 // Destination IP in IPv6 header
lenIP = 16 // IPv6 address length
}

e = append(e, []expr.Any{

// Match destination IP (local IP)
&expr.Payload{
DestRegister: 1,
Base: base,
Offset: offsetDst,
Len: lenIP,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: localIP,
},
}...)
}

return nil
}
24 changes: 24 additions & 0 deletions firewall/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package firewall

import (
"net"
"net/netip"
)

Expand All @@ -17,3 +18,26 @@ func isAddrIpv4(addr string) bool {
}
return isIpv4
}

// GetLocalIPs retrieves all local IPs (IPv4 and IPv6) on the machine.
func GetLocalIPs() ([]net.IP, error) {
var localIPs []net.IP
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}

for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err == nil {
localIPs = append(localIPs, ip)
}
}
}
return localIPs, nil
}

0 comments on commit cc747d3

Please sign in to comment.