Skip to content

Commit

Permalink
fix getting subnet cidr by protocol (#4844)
Browse files Browse the repository at this point in the history
Signed-off-by: zhangzujian <[email protected]>
  • Loading branch information
zhangzujian authored Dec 18, 2024
1 parent 4875f23 commit 64acd01
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 16 deletions.
21 changes: 8 additions & 13 deletions pkg/daemon/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (c *Controller) getSubnetsNeedNAT(protocol string) ([]string, error) {
for _, subnet := range subnets {
if c.isSubnetNeedNat(subnet, protocol) {
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
if err == nil {
if err == nil && cidrBlock != "" {
subnetsNeedNat = append(subnetsNeedNat, cidrBlock)
}
}
Expand Down Expand Up @@ -141,7 +141,7 @@ func (c *Controller) getSubnetsDistributedGateway(protocol string) ([]string, er
subnet.Spec.GatewayType == kubeovnv1.GWDistributedType &&
(subnet.Spec.Protocol == kubeovnv1.ProtocolDual || subnet.Spec.Protocol == protocol) {
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
if err == nil {
if err == nil && cidrBlock != "" {
result = append(result, cidrBlock)
}
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (c *Controller) getDefaultVpcSubnetsCIDR(protocol string) ([]string, map[st
for _, subnet := range subnets {
if subnet.Spec.Vpc == c.config.ClusterRouter && (subnet.Spec.Vlan == "" || subnet.Spec.LogicalGateway) && subnet.Spec.CIDRBlock != "" {
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
if err == nil {
if err == nil && cidrBlock != "" {
ret = append(ret, cidrBlock)
subnetMap[subnet.Name] = cidrBlock
}
Expand Down Expand Up @@ -204,22 +204,17 @@ func (c *Controller) getOtherNodes(protocol string) ([]string, error) {
}

func getCidrByProtocol(cidr, protocol string) (string, error) {
var cidrStr string
if err := util.CheckCidrs(cidr); err != nil {
return "", err
}

if util.CheckProtocol(cidr) == kubeovnv1.ProtocolDual {
cidrBlocks := strings.Split(cidr, ",")
if protocol == kubeovnv1.ProtocolIPv4 {
cidrStr = cidrBlocks[0]
} else if protocol == kubeovnv1.ProtocolIPv6 {
cidrStr = cidrBlocks[1]
for _, cidr := range strings.Split(cidr, ",") {
if util.CheckProtocol(cidr) == protocol {
return cidr, nil
}
} else {
cidrStr = cidr
}
return cidrStr, nil

return "", nil
}

func (c *Controller) getEgressNatIPByNode(nodeName string) (map[string]string, error) {
Expand Down
11 changes: 8 additions & 3 deletions pkg/daemon/gateway_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,17 @@ func (c *Controller) reconcileNatOutGoingPolicyIPset(protocol string) {
return
}

subnetCidrs := make([]string, 0)
subnetCidrs := make([]string, 0, len(subnets))
natPolicyRuleIDs := strset.New()
for _, subnet := range subnets {
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
if err != nil {
klog.Errorf("failed to get subnet %s CIDR block by protocol: %v", subnet.Name, err)
continue
}
subnetCidrs = append(subnetCidrs, cidrBlock)
if cidrBlock != "" {
subnetCidrs = append(subnetCidrs, cidrBlock)
}
for _, rule := range subnet.Status.NatOutgoingPolicyRules {
if rule.RuleID == "" {
klog.Errorf("unexpected empty ID for NAT outgoing rule %q of subnet %s", rule.NatOutgoingPolicyRule, subnet.Name)
Expand Down Expand Up @@ -1003,6 +1005,9 @@ func (c *Controller) generateNatOutgoingPolicyChainRules(protocol string) ([]uti
klog.Errorf("failed to get subnet %s cidr block with protocol: %v", subnet.Name, err)
continue
}
if cidrBlock == "" {
continue
}

ovnNatPolicySubnetChainName := OvnNatOutGoingPolicySubnet + util.GetTruncatedUID(string(subnet.GetUID()))
natPolicySubnetIptables = append(natPolicySubnetIptables, util.IPTableRule{Table: NAT, Chain: OvnNatOutGoingPolicy, Rule: strings.Fields(fmt.Sprintf(`-s %s -m comment --comment natPolicySubnet-%s -j %s`, cidrBlock, subnet.Name, ovnNatPolicySubnetChainName))})
Expand Down Expand Up @@ -1616,7 +1621,7 @@ func (c *Controller) getSubnetsNeedPR(protocol string) (map[policyRouteMeta]stri
}
if meta.gateway != "" {
cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol)
if err == nil {
if err == nil && cidrBlock != "" {
subnetsNeedPR[meta] = cidrBlock
}
}
Expand Down
61 changes: 61 additions & 0 deletions pkg/daemon/gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package daemon

import (
"testing"

"github.com/stretchr/testify/require"

kubeovnv1 "github.com/kubeovn/kube-ovn/pkg/apis/kubeovn/v1"
)

func TestGetCidrByProtocol(t *testing.T) {
cases := []struct {
name string
cidr string
protocol string
wantErr bool
expetced string
}{{
name: "ipv4 only",
cidr: "1.1.1.0/24",
protocol: kubeovnv1.ProtocolIPv4,
expetced: "1.1.1.0/24",
}, {
name: "ipv6 only",
cidr: "2001:db8::/120",
protocol: kubeovnv1.ProtocolIPv6,
expetced: "2001:db8::/120",
}, {
name: "get ipv4 from ipv6",
cidr: "2001:db8::/120",
protocol: kubeovnv1.ProtocolIPv4,
}, {
name: "get ipv4 from dual stack",
cidr: "1.1.1.0/24,2001:db8::/120",
protocol: kubeovnv1.ProtocolIPv4,
expetced: "1.1.1.0/24",
}, {
name: "get ipv6 from ipv4",
cidr: "1.1.1.0/24",
protocol: kubeovnv1.ProtocolIPv6,
}, {
name: "get ipv6 from dual stack",
cidr: "1.1.1.0/24,2001:db8::/120",
protocol: kubeovnv1.ProtocolIPv6,
expetced: "2001:db8::/120",
}, {
name: "invalid cidr",
cidr: "foo bar",
protocol: kubeovnv1.ProtocolIPv4,
wantErr: true,
}}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := getCidrByProtocol(c.cidr, c.protocol)
if (err != nil) != c.wantErr {
t.Errorf("getCidrByProtocol(%q, %q) error = %v, wantErr = %v", c.cidr, c.protocol, err, c.wantErr)
}
require.Equal(t, c.expetced, got)
})
}
}

0 comments on commit 64acd01

Please sign in to comment.