Skip to content

Commit

Permalink
Make interval fields private, handle empty interval (#28)
Browse files Browse the repository at this point in the history
* make interval fields private, handle empty interval
* add .ShortString()
---------

Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
elazarg authored Mar 24, 2024
1 parent fc2b1ae commit 7db084b
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 60 deletions.
6 changes: 3 additions & 3 deletions pkg/connection/connectionset.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc
if !srcPorts.Equal(entireDimension(srcPort)) {
// iterate the interval in the interval-set
for _, interval := range srcPorts.Intervals() {
tcpRes := spec.TcpUdp{Protocol: protocol, MinSourcePort: int(interval.Start), MaxSourcePort: int(interval.End)}
tcpRes := spec.TcpUdp{Protocol: protocol, MinSourcePort: int(interval.Start()), MaxSourcePort: int(interval.End())}
tcpItemsTemp = append(tcpItemsTemp, tcpRes)
}
} else {
Expand All @@ -305,8 +305,8 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc
Protocol: protocol,
MinSourcePort: tcpItemTemp.MinSourcePort,
MaxSourcePort: tcpItemTemp.MaxSourcePort,
MinDestinationPort: int(interval.Start),
MaxDestinationPort: int(interval.End),
MinDestinationPort: int(interval.Start()),
MaxDestinationPort: int(interval.End()),
}
tcpItemsFinal = append(tcpItemsFinal, tcpRes)
}
Expand Down
88 changes: 66 additions & 22 deletions pkg/interval/interval.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,107 @@ package interval

import "fmt"

// Interval is an integer interval from Start to End
// Interval is an integer interval from start to end inclusive
type Interval struct {
Start int64
End int64
start int64
end int64
}

func New(start, end int64) Interval {
return Interval{Start: start, End: end}
if end < start {
return Interval{start: 0, end: -1}
}
return Interval{start: start, end: end}
}

func (i Interval) Start() int64 {
return i.start
}

func (i Interval) End() int64 {
return i.end
}

// String returns a String representation of Interval object
func (i Interval) String() string {
return fmt.Sprintf("[%v-%v]", i.Start, i.End)
if i.IsEmpty() {
return "[]"
}
return fmt.Sprintf("[%v-%v]", i.start, i.end)
}

// ShortString returns a compacted String representation of Interval object:
// "v" instead of "v-v", without braces
func (i Interval) ShortString() string {
if i.IsEmpty() {
return ""
}
if i.start == i.end {
return fmt.Sprintf("%v", i.start)
}
return fmt.Sprintf("%v-%v", i.start, i.end)
}

// IsEmpty returns true if the interval is empty, false otherwise.
// An interval is considered empty if its start is greater than its end.
func (i Interval) IsEmpty() bool {
return i.end < i.start
}

// Equal returns true if current Interval obj is equal to the input Interval
func (i Interval) Equal(x Interval) bool {
return i.Start == x.Start && i.End == x.End
return i.start == x.start && i.end == x.end
}

func (i Interval) Size() int64 {
return i.End - i.Start + 1
return i.end - i.start + 1
}

func (i Interval) overlaps(other Interval) bool {
return other.End >= i.Start && other.Start <= i.End
func (i Interval) overlap(other Interval) bool {
if i.IsEmpty() {
return false
}
return other.end >= i.start && other.start <= i.end
}

func (i Interval) isSubset(other Interval) bool {
return other.Start <= i.Start && other.End >= i.End
if i.IsEmpty() {
return true
}
return other.start <= i.start && other.end >= i.end
}

// returns a list with up to 2 intervals
func (i Interval) subtract(other Interval) []Interval {
if !i.overlaps(other) {
if !i.overlap(other) {
return []Interval{i}
}
if i.isSubset(other) {
return []Interval{}
}
if i.Start < other.Start && i.End > other.End {
if i.start < other.start && i.end > other.end {
// self is split into two ranges by other
return []Interval{{Start: i.Start, End: other.Start - 1}, {Start: other.End + 1, End: i.End}}
return []Interval{{start: i.start, end: other.start - 1}, {start: other.end + 1, end: i.end}}
}
if i.Start < other.Start {
return []Interval{{Start: i.Start, End: min(i.End, other.Start-1)}}
if i.start < other.start {
return []Interval{{start: i.start, end: min(i.end, other.start-1)}}
}
return []Interval{{Start: max(i.Start, other.End+1), End: i.End}}
return []Interval{{start: max(i.start, other.end+1), end: i.end}}
}

func (i Interval) intersection(other Interval) []Interval {
maxStart := max(i.Start, other.Start)
minEnd := min(i.End, other.End)
if minEnd < maxStart {
return []Interval{}
func (i Interval) intersect(other Interval) Interval {
return New(
max(i.start, other.start),
min(i.end, other.end),
)
}

func (i Interval) Elements() []int64 {
res := []int64{}
for v := i.start; v <= i.end; v++ {
res = append(res, v)
}
return []Interval{{Start: maxStart, End: minEnd}}
return res
}

func (i Interval) ToSet() *CanonicalSet {
Expand Down
41 changes: 18 additions & 23 deletions pkg/interval/intervalset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package interval

import (
"fmt"
"log"
"slices"
"sort"
Expand Down Expand Up @@ -32,7 +31,7 @@ func (c *CanonicalSet) Min() int64 {
if len(c.intervalSet) == 0 {
log.Panic("cannot take min from empty interval set")
}
return c.intervalSet[0].Start
return c.intervalSet[0].Start()
}

// IsEmpty returns true if the CanonicalSet is empty
Expand Down Expand Up @@ -66,18 +65,21 @@ func (c *CanonicalSet) Equal(other *CanonicalSet) bool {

// AddInterval adds a new interval range to the set
func (c *CanonicalSet) AddInterval(v Interval) {
if v.IsEmpty() {
return
}
set := c.intervalSet
left := sort.Search(len(set), func(i int) bool {
return set[i].End >= v.Start-1
return set[i].End() >= v.Start()-1
})
if left < len(set) && set[left].Start <= v.End {
v.Start = min(v.Start, set[left].Start)
if left < len(set) && set[left].Start() <= v.End() {
v = New(min(v.Start(), set[left].Start()), v.End())
}
right := sort.Search(len(set), func(j int) bool {
return set[j].Start > v.End+1
return set[j].Start() > v.End()+1
})
if right > 0 && set[right-1].End >= v.Start {
v.End = max(v.End, set[right-1].End)
if right > 0 && set[right-1].End() >= v.Start() {
v = New(v.Start(), max(v.End(), set[right-1].End()))
}
c.intervalSet = slices.Replace(c.intervalSet, left, right, v)
}
Expand All @@ -98,12 +100,7 @@ func (c *CanonicalSet) String() string {
}
res := ""
for _, interval := range c.intervalSet {
if interval.Start != interval.End {
res += fmt.Sprintf("%v-%v", interval.Start, interval.End)
} else {
res += fmt.Sprintf("%v", interval.Start)
}
res += ","
res += interval.ShortString() + ","
}
return res[:len(res)-1]
}
Expand Down Expand Up @@ -137,9 +134,9 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) bool {
larger := other.intervalSet
for _, target := range c.intervalSet {
left := sort.Search(len(larger), func(i int) bool {
return larger[i].End >= target.End
return larger[i].End() >= target.End()
})
if left == len(larger) || larger[left].Start > target.Start {
if left == len(larger) || larger[left].Start() > target.Start() {
return false
}
// Optimization
Expand All @@ -154,11 +151,9 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet {
return c.Copy()
}
res := NewCanonicalSet()
for _, interval := range c.intervalSet {
for _, otherInterval := range other.intervalSet {
for _, span := range interval.intersection(otherInterval) {
res.AddInterval(span)
}
for _, left := range c.intervalSet {
for _, right := range other.intervalSet {
res.AddInterval(left.intersect(right))
}
}
return res
Expand All @@ -171,7 +166,7 @@ func (c *CanonicalSet) Overlap(other *CanonicalSet) bool {
}
for _, selfInterval := range c.intervalSet {
for _, otherInterval := range other.intervalSet {
if selfInterval.overlaps(otherInterval) {
if selfInterval.overlap(otherInterval) {
return true
}
}
Expand Down Expand Up @@ -205,7 +200,7 @@ func (c *CanonicalSet) Elements() []int64 {
res := make([]int64, c.CalculateSize())
i := 0
for _, interval := range c.intervalSet {
for v := interval.Start; v <= interval.End; v++ {
for v := interval.Start(); v <= interval.End(); v++ {
res[i] = v
i++
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/interval/intervalset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ import (
)

func TestInterval(t *testing.T) {
it1 := interval.Interval{3, 7}
it1 := interval.New(3, 7)

require.Equal(t, "[3-7]", it1.String())
}

func TestIntervalSet(t *testing.T) {
is1 := interval.NewCanonicalSet()
is1.AddInterval(interval.Interval{5, 10})
is1.AddInterval(interval.Interval{0, 1})
is1.AddInterval(interval.Interval{3, 3})
is1.AddInterval(interval.Interval{70, 80})
is1.AddInterval(interval.New(5, 10))
is1.AddInterval(interval.New(0, 1))
is1.AddInterval(interval.New(3, 3))
is1.AddInterval(interval.New(70, 80))
is1 = is1.Subtract(interval.New(7, 9).ToSet())
require.True(t, is1.Contains(5))
require.False(t, is1.Contains(8))

is2 := interval.NewCanonicalSet()
require.Equal(t, "Empty", is2.String())
is2.AddInterval(interval.Interval{6, 8})
is2.AddInterval(interval.New(6, 8))
require.Equal(t, "6-8", is2.String())
require.False(t, is2.IsSingleNumber())
require.False(t, is2.ContainedIn(is1))
Expand Down Expand Up @@ -59,9 +59,9 @@ func TestIntervalSet(t *testing.T) {

func TestIntervalSetSubtract(t *testing.T) {
s := interval.New(1, 100).ToSet()
s.AddInterval(interval.Interval{Start: 400, End: 700})
s.AddInterval(interval.New(400, 700))
d := *interval.New(50, 100).ToSet()
d.AddInterval(interval.Interval{Start: 400, End: 700})
d.AddInterval(interval.New(400, 700))
actual := s.Subtract(&d)
expected := interval.New(1, 49).ToSet()
require.Equal(t, expected.String(), actual.String())
Expand Down
8 changes: 4 additions & 4 deletions pkg/ipblock/ipblock.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (b *IPBlock) ToIPRanges() string {

// toIPRange returns a string of the ip range of a single interval
func toIPRange(i interval.Interval) string {
startIP := intToIP4(i.Start)
endIP := intToIP4(i.End)
startIP := intToIP4(i.Start())
endIP := intToIP4(i.End())
return rangeIPstr(startIP, endIP)
}

Expand Down Expand Up @@ -314,8 +314,8 @@ func (b *IPBlock) ToIPAddressString() string {
}

func intervalToCidrList(ipRange interval.Interval) []string {
start := ipRange.Start
end := ipRange.End
start := ipRange.Start()
end := ipRange.End()
res := []string{}
for end >= start {
maxSize := maxIPv4Bits
Expand Down

0 comments on commit 7db084b

Please sign in to comment.