diff --git a/pkg/ipblocks/ipblocks.go b/pkg/ipblocks/ipblocks.go index 79b22a6..16f2525 100644 --- a/pkg/ipblocks/ipblocks.go +++ b/pkg/ipblocks/ipblocks.go @@ -14,6 +14,10 @@ import ( ) const ( + // CidrAll represents the CIDR for all addresses "0.0.0.0/0" + CidrAll = "0.0.0.0/0" + + // internal const below ipByte = 0xff ipShift0 = 24 ipShift1 = 16 @@ -21,31 +25,31 @@ const ( ipBase = 10 ipMask = 0xffffffff maxIPv4Bits = 32 - CidrAll = "0.0.0.0/0" cidrSeparator = "/" bitSize64 = 64 commaSeparator = ", " + dash = "-" ) -// IPBlock captures a set of ip ranges +// IPBlock captures a set of IP ranges type IPBlock struct { ipRange intervals.CanonicalIntervalSet } // ToIPRanges returns a string of the ip ranges in the current IPBlock object func (b *IPBlock) ToIPRanges() string { - return strings.Join(b.ToIPRangesList(), commaSeparator) + return strings.Join(b.toIPRangesList(), commaSeparator) } -// ToIPRange returns a string of the ip range of a single interval +// toIPRange returns a string of the ip range of a single interval func toIPRange(i intervals.Interval) string { - startIP := InttoIP4(i.Start) - endIP := InttoIP4(i.End) + startIP := inttoIP4(i.Start) + endIP := inttoIP4(i.End) return rangeIPstr(startIP, endIP) } -// ToIPRangesList: returns a list of the ip-ranges strings in the current IPBlock object -func (b *IPBlock) ToIPRangesList() []string { +// toIPRangesList: returns a list of the ip-ranges strings in the current IPBlock object +func (b *IPBlock) toIPRangesList() []string { IPRanges := make([]string, len(b.ipRange.IntervalSet)) for index := range b.ipRange.IntervalSet { IPRanges[index] = toIPRange(b.ipRange.IntervalSet[index]) @@ -53,16 +57,12 @@ func (b *IPBlock) ToIPRangesList() []string { return IPRanges } -// IsIPAddress returns true if IPBlock object is a range of exactly one ip address from input -func (b *IPBlock) IsIPAddress(ipAddress string) bool { - ipRanges := b.ToIPRanges() - return ipRanges == rangeIPstr(ipAddress, ipAddress) -} - +// ContainedIn returns true if the input IPBlock is contained in this IPBlock func (b *IPBlock) ContainedIn(c *IPBlock) bool { return b.ipRange.ContainedIn(c.ipRange) } +// Intersection returns a new IPBlock from intersection of this IPBlock with input IPBlock func (b *IPBlock) Intersection(c *IPBlock) *IPBlock { res := &IPBlock{} res.ipRange = b.ipRange.Copy() @@ -70,10 +70,12 @@ func (b *IPBlock) Intersection(c *IPBlock) *IPBlock { return res } +// Equal returns true if this IPBlock equals the input IPBlock func (b *IPBlock) Equal(c *IPBlock) bool { return b.ipRange.Equal(c.ipRange) } +// Subtract returns a new IPBlock from subtraction of input IPBlock from this IPBlock func (b *IPBlock) Subtract(c *IPBlock) *IPBlock { res := &IPBlock{} res.ipRange = b.ipRange.Copy() @@ -81,6 +83,7 @@ func (b *IPBlock) Subtract(c *IPBlock) *IPBlock { return res } +// Union returns a new IPBlock from union of input IPBlock with this IPBlock func (b *IPBlock) Union(c *IPBlock) *IPBlock { res := &IPBlock{} res.ipRange = b.ipRange.Copy() @@ -88,6 +91,7 @@ func (b *IPBlock) Union(c *IPBlock) *IPBlock { return res } +// Empty returns true if this IPBlock is empty func (b *IPBlock) Empty() bool { return b.ipRange.IsEmpty() } @@ -122,8 +126,8 @@ func (b *IPBlock) Split() []*IPBlock { return res } -// InttoIP4 returns a string of an ip address from an input integer ip value -func InttoIP4(ipInt int64) string { +// inttoIP4 returns a string of an ip address from an input integer ip value +func inttoIP4(ipInt int64) string { // need to do two bit shifting and “0xff” masking b0 := strconv.FormatInt((ipInt>>ipShift0)&ipByte, ipBase) b1 := strconv.FormatInt((ipInt>>ipShift1)&ipByte, ipBase) @@ -153,7 +157,7 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { res := blocksWithNoOverlaps if len(res) == 0 { - newAll, _ := NewIPBlock("0.0.0.0/0", []string{}) + newAll := GetCidrAll() res = append(res, newAll) } return res @@ -182,32 +186,40 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { return ipbList } -func NewIPBlockFromCidr(cidr string) *IPBlock { - res, err := NewIPBlock(cidr, []string{}) - if err != nil { - return nil +// NewIPBlockFromCidr returns a new IPBlock object from input CIDR string +func NewIPBlockFromCidr(cidr string) (*IPBlock, error) { + return NewIPBlock(cidr, []string{}) +} + +// PairCIDRsToIPBlocks returns two IPBlock objects from two input CIDR strings +func PairCIDRsToIPBlocks(cidr1, cidr2 string) (ipb1, ipb2 *IPBlock, err error) { + ipb1, err1 := NewIPBlockFromCidr(cidr1) + ipb2, err2 := NewIPBlockFromCidr(cidr2) + if err1 != nil || err2 != nil { + return nil, nil, errors.Join(err1, err2) } - return res + return ipb1, ipb2, nil } -func NewIPBlockFromCidrOrAddress(s string) *IPBlock { - var res *IPBlock +// NewIPBlockFromCidr returns a new IPBlock object from input string of CIDR or IP address +func NewIPBlockFromCidrOrAddress(s string) (*IPBlock, error) { if strings.Contains(s, cidrSeparator) { - res = NewIPBlockFromCidr(s) - } else { - res, _ = NewIPBlockFromIPAddress(s) + return NewIPBlockFromCidr(s) } - return res + return NewIPBlockFromIPAddress(s) } // NewIPBlockFromCidrList returns IPBlock object from multiple CIDRs given as list of strings -func NewIPBlockFromCidrList(cidrsList []string) *IPBlock { +func NewIPBlockFromCidrList(cidrsList []string) (*IPBlock, error) { res := &IPBlock{ipRange: intervals.CanonicalIntervalSet{}} for _, cidr := range cidrsList { - block := NewIPBlockFromCidr(cidr) + block, err := NewIPBlockFromCidr(cidr) + if err != nil { + return nil, err + } res = res.Union(block) } - return res + return res, nil } // NewIPBlock returns an IPBlock object from input cidr str an exceptions cidr str @@ -228,13 +240,13 @@ func NewIPBlock(cidr string, exceptions []string) (*IPBlock, error) { return &res, nil } -func IPv4AddressToCidr(ipAddress string) string { +func ipv4AddressToCidr(ipAddress string) string { return ipAddress + "/32" } -// NewIPBlockFromIPAddress returns an IPBlock object from input ip address str +// NewIPBlockFromIPAddress returns an IPBlock object from input IP address string func NewIPBlockFromIPAddress(ipAddress string) (*IPBlock, error) { - return NewIPBlock(IPv4AddressToCidr(ipAddress), []string{}) + return NewIPBlock(ipv4AddressToCidr(ipAddress), []string{}) } func cidrToIPRange(cidr string) (start, end int64, err error) { @@ -263,10 +275,11 @@ func cidrToInterval(cidr string) (*intervals.Interval, error) { return &intervals.Interval{Start: start, End: end}, nil } +// ToCidrList returns a list of CIDR strings for this IPBlock object func (b *IPBlock) ToCidrList() []string { cidrList := []string{} for _, interval := range b.ipRange.IntervalSet { - cidrList = append(cidrList, IntervalToCidrList(interval.Start, interval.End)...) + cidrList = append(cidrList, intervalToCidrList(interval.Start, interval.End)...) } return cidrList } @@ -280,7 +293,7 @@ func (b *IPBlock) ToCidrListString() string { func (b *IPBlock) ListToPrint() []string { cidrsIPRangesList := []string{} for _, interval := range b.ipRange.IntervalSet { - cidr := IntervalToCidrList(interval.Start, interval.End) + cidr := intervalToCidrList(interval.Start, interval.End) if len(cidr) == 1 { cidrsIPRangesList = append(cidrsIPRangesList, cidr[0]) } else { @@ -290,14 +303,15 @@ func (b *IPBlock) ListToPrint() []string { return cidrsIPRangesList } -func (b *IPBlock) ToIPAddress() string { +// ToIPAdressString returns the IP Address string for this IPBlock +func (b *IPBlock) ToIPAddressString() string { if b.ipRange.IsSingleNumber() { - return InttoIP4(b.ipRange.IntervalSet[0].Start) + return inttoIP4(b.ipRange.IntervalSet[0].Start) } return "" } -func IntervalToCidrList(ipStart, ipEnd int64) []string { +func intervalToCidrList(ipStart, ipEnd int64) []string { start := ipStart end := ipEnd res := []string{} @@ -317,15 +331,16 @@ func IntervalToCidrList(ipStart, ipEnd int64) []string { if maxSize < int(maxDiff) { maxSize = int(maxDiff) } - ip := InttoIP4(start) + ip := inttoIP4(start) res = append(res, fmt.Sprintf("%s/%d", ip, maxSize)) start += int64(math.Pow(2, maxIPv4Bits-float64(maxSize))) } return res } +// IPBlockFromIPRangeStr returns IPBlock object from input IP range string (example: "169.255.0.0-172.15.255.255") func IPBlockFromIPRangeStr(ipRagneStr string) (*IPBlock, error) { - ipAddresses := strings.Split(ipRagneStr, "-") + ipAddresses := strings.Split(ipRagneStr, dash) if len(ipAddresses) != 2 { return nil, errors.New("unexpected ipRange str") } @@ -345,23 +360,10 @@ func IPBlockFromIPRangeStr(ipRagneStr string) (*IPBlock, error) { return res, nil } +// GetCidrAll returns IPBlock object of the entire range 0.0.0.0/0 func GetCidrAll() *IPBlock { - return NewIPBlockFromCidr(CidrAll) -} - -func IsAddressInSubnet(address, subnetCidr string) (bool, error) { - var addressIPblock, subnetIPBlock *IPBlock - var err error - if addressIPblock, err = NewIPBlockFromIPAddress(address); err != nil { - return false, err - } - subnetIPBlock = NewIPBlockFromCidr(subnetCidr) - return addressIPblock.ContainedIn(subnetIPBlock), nil -} - -func CIDRtoIPrange(cidr string) string { - ipb := NewIPBlockFromCidr(cidr) - return ipb.ToIPRanges() + res, _ := NewIPBlockFromCidr(CidrAll) + return res } // PrefixLength returns the cidr's prefix length, assuming the ipBlock is exactly one cidr. diff --git a/pkg/ipblocks/ipblocks_test.go b/pkg/ipblocks/ipblocks_test.go index af4145f..1481599 100644 --- a/pkg/ipblocks/ipblocks_test.go +++ b/pkg/ipblocks/ipblocks_test.go @@ -9,19 +9,17 @@ import ( ) func TestOps(t *testing.T) { - ipb1 := ipblocks.NewIPBlockFromCidrOrAddress("1.2.3.0/24") + ipb1, err := ipblocks.NewIPBlockFromCidrOrAddress("1.2.3.0/24") + require.Nil(t, err) require.NotNil(t, ipb1) - ipb2 := ipblocks.NewIPBlockFromCidrOrAddress("1.2.3.4") + ipb2, err := ipblocks.NewIPBlockFromCidrOrAddress("1.2.3.4") + require.Nil(t, err) require.NotNil(t, ipb2) - require.True(t, ipb2.IsIPAddress("1.2.3.4")) require.True(t, ipb2.ContainedIn(ipb1)) require.False(t, ipb1.ContainedIn(ipb2)) minus := ipb1.Subtract(ipb2) - minusRanges := minus.ToIPRangesList() - require.Len(t, minusRanges, 2) - require.Equal(t, "1.2.3.0-1.2.3.3", minusRanges[0]) - require.Equal(t, "1.2.3.5-1.2.3.255", minusRanges[1]) + require.Equal(t, "1.2.3.0-1.2.3.3, 1.2.3.5-1.2.3.255", minus.ToIPRanges()) minus2, err := ipblocks.NewIPBlock(ipb1.ToCidrListString(), []string{ipb2.ToCidrListString()}) require.Nil(t, err) @@ -46,50 +44,99 @@ func TestConversions(t *testing.T) { cidrs := ipb1.ToCidrList() require.Len(t, cidrs, 26) - ipb2 := ipblocks.NewIPBlockFromCidrList(cidrs) + ipb2, err := ipblocks.NewIPBlockFromCidrList(cidrs) + require.Nil(t, err) require.Equal(t, ipb1.ToCidrListString(), ipb2.ToCidrListString()) toprint := ipb1.ListToPrint() require.Len(t, toprint, 1) require.Equal(t, iprange, toprint[0]) - require.Equal(t, "", ipb1.ToIPAddress()) + require.Equal(t, "", ipb1.ToIPAddressString()) } func TestDisjointIPBlocks(t *testing.T) { allIPs := ipblocks.GetCidrAll() - ipb := ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/32", "172.0.0.0/8"}) + ipb, err := ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/32", "172.0.0.0/8"}) + require.Nil(t, err) disjointBlocks := ipblocks.DisjointIPBlocks([]*ipblocks.IPBlock{allIPs}, []*ipblocks.IPBlock{ipb}) require.Len(t, disjointBlocks, 5) - require.Equal(t, "1.2.3.4", disjointBlocks[0].ToIPAddress()) // list is sorted by ip-block size -} + require.Equal(t, "1.2.3.4", disjointBlocks[0].ToIPAddressString()) // list is sorted by ip-block size -func TestIsAddressInSubnet(t *testing.T) { - res, err := ipblocks.IsAddressInSubnet("1.2.3.4", "1.0.0.0/8") + ipb2, err := ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.0/30"}) require.Nil(t, err) - require.True(t, res) + ipb3, err := ipblocks.IPBlockFromIPRangeStr("1.2.2.255-1.2.3.1") + require.Nil(t, err) + disjointBlocks = ipblocks.DisjointIPBlocks([]*ipblocks.IPBlock{ipb2}, []*ipblocks.IPBlock{ipb3}) + require.Len(t, disjointBlocks, 3) + require.Equal(t, "1.2.2.255", disjointBlocks[0].ToIPAddressString()) + require.Equal(t, "1.2.3.2/31", disjointBlocks[1].ToCidrListString()) + require.Equal(t, "1.2.3.0/31", disjointBlocks[2].ToCidrListString()) +} - res, err = ipblocks.IsAddressInSubnet("1.2.3.4", "1.0.0.0/16") +func TestPairCIDRsToIPBlocks(t *testing.T) { + first, second, err := ipblocks.PairCIDRsToIPBlocks("5.6.7.8/24", "4.9.2.1/32") require.Nil(t, err) - require.False(t, res) + require.Equal(t, "5.6.7.0/24", first.ListToPrint()[0]) + require.Equal(t, "4.9.2.1/32", second.ListToPrint()[0]) - _, err = ipblocks.IsAddressInSubnet("1.2.3.4/30", "1.0.0.0/16") - require.NotNil(t, err) + intersect := first.Intersection(second) + require.Equal(t, "", intersect.ToIPAddressString()) + require.Empty(t, intersect.ListToPrint()) + require.Empty(t, intersect.ToCidrListString()) } func TestPrefixLength(t *testing.T) { - ipb := ipblocks.NewIPBlockFromCidrOrAddress("42.5.2.8/20") + ipb, err := ipblocks.NewIPBlockFromCidrOrAddress("42.5.2.8/20") + require.Nil(t, err) prefLen, err := ipb.PrefixLength() require.Nil(t, err) require.Equal(t, int64(20), prefLen) - ipb = ipblocks.NewIPBlockFromCidrOrAddress("42.5.2.8") + ipb, err = ipblocks.NewIPBlockFromCidrOrAddress("42.5.2.8") + require.Nil(t, err) prefLen, err = ipb.PrefixLength() require.Nil(t, err) require.Equal(t, int64(32), prefLen) - ipb = ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/32", "172.0.0.0/8"}) + ipb, err = ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/32", "172.0.0.0/8"}) + require.Nil(t, err) _, err = ipb.PrefixLength() require.NotNil(t, err) } + +func TestBadPath(t *testing.T) { + _, err := ipblocks.NewIPBlock("not-a-cidr", nil) + require.NotNil(t, err) + + _, err = ipblocks.NewIPBlock("2.5.7.9/24", []string{"5.6.7.8/20", "not-a-cidr"}) + require.NotNil(t, err) + + _, err = ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/20", "not-a-cidr"}) + require.NotNil(t, err) + + _, err = ipblocks.NewIPBlockFromCidrList([]string{"1.2.3.4/20", "1.2.3.4/40"}) + require.NotNil(t, err) + + _, err = ipblocks.IPBlockFromIPRangeStr("1.2.3.4") + require.NotNil(t, err) + + _, err = ipblocks.IPBlockFromIPRangeStr("prefix-1.2.3.4") + require.NotNil(t, err) + + _, err = ipblocks.IPBlockFromIPRangeStr("1.2.3.290-1.2.3.4") + require.NotNil(t, err) + + _, err = ipblocks.IPBlockFromIPRangeStr("1.2.3.4-suffix") + require.NotNil(t, err) + + _, err = ipblocks.IPBlockFromIPRangeStr("1.2.3.4-2.5.6.7/20") + require.NotNil(t, err) + + _, _, err = ipblocks.PairCIDRsToIPBlocks("1.2.3.4/40", "1.2.3.5/24") + require.NotNil(t, err) + + _, _, err = ipblocks.PairCIDRsToIPBlocks("1.2.3.4/20", "not-a-cidr") + require.NotNil(t, err) +}