From fb203bfa08fdd2c6eb1f62038efe4a4fcb4bc6fe Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Mon, 11 Mar 2024 11:06:25 +0200 Subject: [PATCH] add ExceptCidr; direct handling of IP Signed-off-by: Elazar Gershuni --- pkg/ipblock/ipblock.go | 77 +++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index d495ea4..b7ea688 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -123,12 +123,9 @@ func (b *IPBlock) Split() []*IPBlock { // intToIP4 returns a string of an ip address from an input integer ip value func intToIP4(ipInt int64) string { - // need to do two bit shifting and “0xff” masking - b0 := strconv.FormatInt((ipInt>>ipShift0)&ipByte, ipBase) - b1 := strconv.FormatInt((ipInt>>ipShift1)&ipByte, ipBase) - b2 := strconv.FormatInt((ipInt>>ipShift2)&ipByte, ipBase) - b3 := strconv.FormatInt((ipInt & ipByte), ipBase) - return b0 + "." + b1 + "." + b2 + "." + b3 + var d [4]byte + binary.BigEndian.PutUint32(d[:], uint32(ipInt)) + return string(net.IPv4(d[0], d[1], d[2], d[3]).String()) } // DisjointIPBlocks returns an IPBlock of disjoint ip ranges from 2 input IPBlock objects @@ -151,8 +148,7 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { } if len(res) == 0 { - newAll := GetCidrAll() - res = append(res, newAll) + res = append(res, GetCidrAll()) } return res } @@ -181,15 +177,28 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { // NewIPBlockFromCidr returns a new IPBlock object from input CIDR string func NewIPBlockFromCidr(cidr string) (*IPBlock, error) { - span, err := cidrToInterval(cidr) + start, end, err := cidrToIPRange(cidr) if err != nil { return nil, err } return &IPBlock{ - ipRange: *interval.CreateSetFromInterval(span.Start, span.End), + ipRange: *interval.CreateSetFromInterval(start, end), }, nil } +// ExceptCidrs returns a new IPBlock with all cidr ranges removed +func (b *IPBlock) ExceptCidrs(cidrExceptions ...string) (*IPBlock, error) { + res := b.Copy() + for i := range cidrExceptions { + hole, err := NewIPBlockFromCidr(cidrExceptions[i]) + if err != nil { + return nil, err + } + res = res.Subtract(hole) + } + return res, nil +} + // PairCIDRsToIPBlocks returns two IPBlock objects from two input CIDR strings func PairCIDRsToIPBlocks(cidr1, cidr2 string) (ipb1, ipb2 *IPBlock, err error) { ipb1, err1 := NewIPBlockFromCidr(cidr1) @@ -221,20 +230,15 @@ func NewIPBlockFromCidrList(cidrsList []string) (*IPBlock, error) { return &IPBlock{ipRange: *ipRange}, nil } -func ipv4AddressToCidr(ipAddress string) (string, error) { - if strings.Contains(ipAddress, "/") { - return "", fmt.Errorf("%v is not an IP address", ipAddress) - } - return ipAddress + "/32", nil -} - // NewIPBlockFromIPAddress returns an IPBlock object from input IP address string func NewIPBlockFromIPAddress(ipAddress string) (*IPBlock, error) { - cidr, err := ipv4AddressToCidr(ipAddress) + ipNum, err := parseIP(ipAddress) if err != nil { return nil, err } - return NewIPBlockFromCidr(cidr) + return &IPBlock{ + ipRange: *interval.CreateSetFromInterval(ipNum, ipNum), + }, nil } func cidrToIPRange(cidr string) (start, end int64, err error) { @@ -255,14 +259,6 @@ func cidrToIPRange(cidr string) (start, end int64, err error) { return } -func cidrToInterval(cidr string) (*interval.Interval, error) { - start, end, err := cidrToIPRange(cidr) - if err != nil { - return nil, err - } - return &interval.Interval{Start: start, End: end}, nil -} - // ToCidrList returns a list of CIDR strings for this IPBlock object func (b *IPBlock) ToCidrList() []string { cidrList := []string{} @@ -327,27 +323,24 @@ func intervalToCidrList(ipRange interval.Interval) []string { return res } +func parseIP(ip string) (int64, error) { + startIP := net.ParseIP(ip) + if startIP == nil { + return 0, fmt.Errorf("%v is not a valid ipv4", ip) + } + return int64(binary.BigEndian.Uint32(startIP.To4())), nil +} + // IPBlockFromIPRangeStr returns IPBlock object from input IP range string (example: "169.255.0.0-172.15.255.255") func IPBlockFromIPRangeStr(ipRangeStr string) (*IPBlock, error) { ipAddresses := strings.Split(ipRangeStr, dash) if len(ipAddresses) != 2 { return nil, errors.New("unexpected ipRange str") } - var startIP, endIP *IPBlock - var err error - if startIP, err = NewIPBlockFromIPAddress(ipAddresses[0]); err != nil { - return nil, err - } - if endIP, err = NewIPBlockFromIPAddress(ipAddresses[1]); err != nil { - return nil, err - } - startIPNum, err := startIP.ipRange.Min() - if err != nil { - return nil, err - } - endIPNum, err := endIP.ipRange.Min() - if err != nil { - return nil, err + startIPNum, err0 := parseIP(ipAddresses[0]) + endIPNum, err1 := parseIP(ipAddresses[1]) + if err0 != nil || err1 != nil { + return nil, errors.Join(err0, err1) } res := &IPBlock{ ipRange: *interval.CreateSetFromInterval(startIPNum, endIPNum),