diff --git a/pkg/daemon/gateway.go b/pkg/daemon/gateway.go index a64146a8bf9..142bf61e679 100644 --- a/pkg/daemon/gateway.go +++ b/pkg/daemon/gateway.go @@ -101,7 +101,7 @@ func (c *Controller) getSubnetsNeedNAT(protocol string) ([]string, error) { for _, subnet := range subnets { if c.isSubnetNeedNat(subnet, protocol) { cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol) - if err == nil { + if err == nil && cidrBlock != "" { subnetsNeedNat = append(subnetsNeedNat, cidrBlock) } } @@ -141,7 +141,7 @@ func (c *Controller) getSubnetsDistributedGateway(protocol string) ([]string, er subnet.Spec.GatewayType == kubeovnv1.GWDistributedType && (subnet.Spec.Protocol == kubeovnv1.ProtocolDual || subnet.Spec.Protocol == protocol) { cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol) - if err == nil { + if err == nil && cidrBlock != "" { result = append(result, cidrBlock) } } @@ -172,7 +172,7 @@ func (c *Controller) getDefaultVpcSubnetsCIDR(protocol string) ([]string, map[st for _, subnet := range subnets { if subnet.Spec.Vpc == c.config.ClusterRouter && (subnet.Spec.Vlan == "" || subnet.Spec.LogicalGateway) && subnet.Spec.CIDRBlock != "" { cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol) - if err == nil { + if err == nil && cidrBlock != "" { ret = append(ret, cidrBlock) subnetMap[subnet.Name] = cidrBlock } @@ -204,22 +204,17 @@ func (c *Controller) getOtherNodes(protocol string) ([]string, error) { } func getCidrByProtocol(cidr, protocol string) (string, error) { - var cidrStr string if err := util.CheckCidrs(cidr); err != nil { return "", err } - if util.CheckProtocol(cidr) == kubeovnv1.ProtocolDual { - cidrBlocks := strings.Split(cidr, ",") - if protocol == kubeovnv1.ProtocolIPv4 { - cidrStr = cidrBlocks[0] - } else if protocol == kubeovnv1.ProtocolIPv6 { - cidrStr = cidrBlocks[1] + for _, cidr := range strings.Split(cidr, ",") { + if util.CheckProtocol(cidr) == protocol { + return cidr, nil } - } else { - cidrStr = cidr } - return cidrStr, nil + + return "", nil } func (c *Controller) getEgressNatIPByNode(nodeName string) (map[string]string, error) { diff --git a/pkg/daemon/gateway_linux.go b/pkg/daemon/gateway_linux.go index ddefc4b8fd3..a36286cb86d 100644 --- a/pkg/daemon/gateway_linux.go +++ b/pkg/daemon/gateway_linux.go @@ -205,7 +205,7 @@ func (c *Controller) reconcileNatOutGoingPolicyIPset(protocol string) { return } - subnetCidrs := make([]string, 0) + subnetCidrs := make([]string, 0, len(subnets)) natPolicyRuleIDs := strset.New() for _, subnet := range subnets { cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol) @@ -213,7 +213,9 @@ func (c *Controller) reconcileNatOutGoingPolicyIPset(protocol string) { klog.Errorf("failed to get subnet %s CIDR block by protocol: %v", subnet.Name, err) continue } - subnetCidrs = append(subnetCidrs, cidrBlock) + if cidrBlock != "" { + subnetCidrs = append(subnetCidrs, cidrBlock) + } for _, rule := range subnet.Status.NatOutgoingPolicyRules { if rule.RuleID == "" { klog.Errorf("unexpected empty ID for NAT outgoing rule %q of subnet %s", rule.NatOutgoingPolicyRule, subnet.Name) @@ -1003,6 +1005,9 @@ func (c *Controller) generateNatOutgoingPolicyChainRules(protocol string) ([]uti klog.Errorf("failed to get subnet %s cidr block with protocol: %v", subnet.Name, err) continue } + if cidrBlock == "" { + continue + } ovnNatPolicySubnetChainName := OvnNatOutGoingPolicySubnet + util.GetTruncatedUID(string(subnet.GetUID())) natPolicySubnetIptables = append(natPolicySubnetIptables, util.IPTableRule{Table: NAT, Chain: OvnNatOutGoingPolicy, Rule: strings.Fields(fmt.Sprintf(`-s %s -m comment --comment natPolicySubnet-%s -j %s`, cidrBlock, subnet.Name, ovnNatPolicySubnetChainName))}) @@ -1616,7 +1621,7 @@ func (c *Controller) getSubnetsNeedPR(protocol string) (map[policyRouteMeta]stri } if meta.gateway != "" { cidrBlock, err := getCidrByProtocol(subnet.Spec.CIDRBlock, protocol) - if err == nil { + if err == nil && cidrBlock != "" { subnetsNeedPR[meta] = cidrBlock } } diff --git a/pkg/daemon/gateway_test.go b/pkg/daemon/gateway_test.go new file mode 100644 index 00000000000..ce44d43650c --- /dev/null +++ b/pkg/daemon/gateway_test.go @@ -0,0 +1,61 @@ +package daemon + +import ( + "testing" + + "github.com/stretchr/testify/require" + + kubeovnv1 "github.com/kubeovn/kube-ovn/pkg/apis/kubeovn/v1" +) + +func TestGetCidrByProtocol(t *testing.T) { + cases := []struct { + name string + cidr string + protocol string + wantErr bool + expetced string + }{{ + name: "ipv4 only", + cidr: "1.1.1.0/24", + protocol: kubeovnv1.ProtocolIPv4, + expetced: "1.1.1.0/24", + }, { + name: "ipv6 only", + cidr: "2001:db8::/120", + protocol: kubeovnv1.ProtocolIPv6, + expetced: "2001:db8::/120", + }, { + name: "get ipv4 from ipv6", + cidr: "2001:db8::/120", + protocol: kubeovnv1.ProtocolIPv4, + }, { + name: "get ipv4 from dual stack", + cidr: "1.1.1.0/24,2001:db8::/120", + protocol: kubeovnv1.ProtocolIPv4, + expetced: "1.1.1.0/24", + }, { + name: "get ipv6 from ipv4", + cidr: "1.1.1.0/24", + protocol: kubeovnv1.ProtocolIPv6, + }, { + name: "get ipv6 from dual stack", + cidr: "1.1.1.0/24,2001:db8::/120", + protocol: kubeovnv1.ProtocolIPv6, + expetced: "2001:db8::/120", + }, { + name: "invalid cidr", + cidr: "foo bar", + protocol: kubeovnv1.ProtocolIPv4, + wantErr: true, + }} + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := getCidrByProtocol(c.cidr, c.protocol) + if (err != nil) != c.wantErr { + t.Errorf("getCidrByProtocol(%q, %q) error = %v, wantErr = %v", c.cidr, c.protocol, err, c.wantErr) + } + require.Equal(t, c.expetced, got) + }) + } +}