diff --git a/pkg/netset/ipblock.go b/pkg/netset/ipblock.go index 1dcb173..164a336 100644 --- a/pkg/netset/ipblock.go +++ b/pkg/netset/ipblock.go @@ -327,6 +327,16 @@ func cidrToInterval(cidr string) (interval.Interval, error) { return interval.New(int64(startNum), int64(endNum)), nil } +// AsCidr returns the CIDR string of this IPBlock object, if it contains exactly one CIDR, +// otherwise it returns an error +func (b *IPBlock) AsCidr() (string, error) { + cidrList := b.ToCidrList() + if len(cidrList) != 1 { + return "", fmt.Errorf("ipblock contains %d cidrs", len(cidrList)) + } + return cidrList[0], nil +} + // ToCidrList returns a list of CIDR strings for this IPBlock object func (b *IPBlock) ToCidrList() []string { var cidrList []string diff --git a/pkg/netset/ipblock_test.go b/pkg/netset/ipblock_test.go index 83af0ce..4763fb8 100644 --- a/pkg/netset/ipblock_test.go +++ b/pkg/netset/ipblock_test.go @@ -100,6 +100,15 @@ func TestConversions(t *testing.T) { require.Equal(t, ipRange, toPrint[0]) require.Equal(t, "", ipb1.ToIPAddressString()) + + _, err = ipb1.AsCidr() + require.NotNil(t, err) + + cidr := "5.2.1.0/24" + ipb3, _ := netset.IPBlockFromCidr(cidr) + str, err := ipb3.AsCidr() + require.Nil(t, err) + require.Equal(t, str, cidr) } func TestDisjointIPBlocks(t *testing.T) {