Skip to content

Commit

Permalink
Improve interval.go and intervalset.go
Browse files Browse the repository at this point in the history
interval:
* Separate tests from interval_test.go
* Improve documentation.
* Export and set-like functions that are well defined.
* Rename interval.Subtract to interval.SubtractSplit, and add tests.
* Handle empty cases first.
* Preallocate Elements.

intervalset:
* Guard Size() from overflow, and use intervalset.CalculateSize().
* Handle empty cases first.
* Remove String() method, since it is not obvious; clients should implement.

Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
elazarg committed Jul 8, 2024
1 parent 82b2ce3 commit 8d8e435
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 45 deletions.
36 changes: 23 additions & 13 deletions pkg/interval/interval.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ package interval

import "fmt"

// Interval is an integer interval from start to end inclusive
// Interval is an integer interval from start to end inclusive.
// An empty interval is represented by [-1, 0].
type Interval struct {
start int64
end int64
}

// New creates a new Interval object with the given start and end values.
// If end < start, the interval is considered empty, and is returned as [-1, 0].
func New(start, end int64) Interval {
if end < start {
return Interval{start: 0, end: -1}
Expand All @@ -29,7 +32,7 @@ func (i Interval) End() int64 {
return i.end
}

// String returns a String representation of Interval object
// String returns a String representation of Interval object: [start-end]
func (i Interval) String() string {
if i.IsEmpty() {
return "[]"
Expand All @@ -38,7 +41,7 @@ func (i Interval) String() string {
}

// ShortString returns a compacted String representation of Interval object:
// "v" instead of "v-v", without braces
// Without braces, and "v" instead of "v-v"
func (i Interval) ShortString() string {
if i.IsEmpty() {
return ""
Expand All @@ -64,26 +67,32 @@ func (i Interval) Size() int64 {
return i.end - i.start + 1
}

func (i Interval) overlap(other Interval) bool {
func (i Interval) Overlap(other Interval) bool {
if i.IsEmpty() {
return false
}
return other.end >= i.start && other.start <= i.end
}

func (i Interval) isSubset(other Interval) bool {
func (i Interval) IsSubset(other Interval) bool {
if i.IsEmpty() {
return true
}
return other.start <= i.start && other.end >= i.end
}

// returns a list with up to 2 intervals
func (i Interval) subtract(other Interval) []Interval {
if !i.overlap(other) {
// SubtractSplit returns a list with up to 2 intervals
func (i Interval) SubtractSplit(other Interval) []Interval {
if i.IsEmpty() {
return []Interval{}
}
if other.IsEmpty() {
return []Interval{i}
}
if !i.Overlap(other) {
return []Interval{i}
}
if i.isSubset(other) {
if i.IsSubset(other) {
return []Interval{}
}
if i.start < other.start && i.end > other.end {
Expand All @@ -96,17 +105,18 @@ func (i Interval) subtract(other Interval) []Interval {
return []Interval{{start: max(i.start, other.end+1), end: i.end}}
}

func (i Interval) intersect(other Interval) Interval {
func (i Interval) Intersect(other Interval) Interval {
return New(
max(i.start, other.start),
min(i.end, other.end),
)
}

func (i Interval) Elements() []int64 {
var res []int64
for v := i.start; v <= i.end; v++ {
res = append(res, v)
size := i.Size()
res := make([]int64, size)
for v := int64(0); v < size; v++ {
res[v] = i.start + v
}
return res
}
Expand Down
106 changes: 106 additions & 0 deletions pkg/interval/interval_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
Copyright 2023- IBM Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package interval_test

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/np-guard/models/pkg/interval"
)

func span(start, end int64) interval.Interval {
return interval.New(start, end)
}

func empty() interval.Interval {
return interval.New(0, -1)
}

// To avoid recursion, only use helper functions that are declared earlier in the file.

func requireEqual(t *testing.T, actual, expected interval.Interval) {
t.Helper()
require.Equal(t, actual.Start(), expected.Start())
require.Equal(t, actual.End(), expected.End())
}

func requireIntersection(t *testing.T, i1, i2, expected interval.Interval) {
t.Helper()
require.Equal(t, !expected.IsEmpty(), i1.Overlap(i2))
require.Equal(t, !expected.IsEmpty(), i2.Overlap(i1))
requireEqual(t, i1.Intersect(i2), expected)
requireEqual(t, i2.Intersect(i1), expected)
}

func requireUnrelated(t *testing.T, i1, i2 interval.Interval) {
t.Helper()
require.False(t, i1.IsSubset(i2))
require.False(t, i2.IsSubset(i1))
require.True(t, i1.Intersect(i2).Size() < min(i1.Size(), i2.Size()))
}

func requireSubset(t *testing.T, small, large interval.Interval) {
t.Helper()
require.True(t, small.IsSubset(large))
requireIntersection(t, small, large, small)
require.Equal(t, small.Equal(large), large.IsSubset(small))
}

func TestInterval_Elements(t *testing.T) {
it1 := span(3, 7)

require.Equal(t, int64(5), it1.Size())
require.Equal(t, []int64{3, 4, 5, 6, 7}, it1.Elements())
}

func TestInterval_Empty(t *testing.T) {
requireEqual(t, empty(), empty())
requireEqual(t, span(5, 4), empty())

require.Equal(t, int64(0), span(-1, -2).Size())
require.Equal(t, []int64{}, span(-1, -2).Elements())
}

func TestInterval_Intersect(t *testing.T) {
s := span(4, 6)
requireIntersection(t, s, s, span(4, 6))

// requireIntersection checks both directions; no need to add tests for all combinations
requireIntersection(t, empty(), empty(), empty())
requireIntersection(t, span(4, 6), span(3, 7), span(4, 6))
requireIntersection(t, span(4, 6), span(3, 5), span(4, 5))
requireIntersection(t, span(4, 6), empty(), empty())
requireIntersection(t, span(4, 6), span(7, 8), empty())
}

func TestInterval_IsSubset(t *testing.T) {
requireSubset(t, empty(), empty())
requireSubset(t, empty(), span(1, 2))
requireSubset(t, span(1, 2), span(1, 2))
requireSubset(t, span(1, 2), span(0, 3))

requireUnrelated(t, span(1, 2), span(2, 3))
requireUnrelated(t, span(1, 2), span(3, 4))
requireUnrelated(t, span(1, 2), span(5, 6))
}

func TestInterval_SubtractSplit(t *testing.T) {
require.Equal(t, []interval.Interval{}, empty().SubtractSplit(empty()))
require.Equal(t, []interval.Interval{}, empty().SubtractSplit(span(1, 3)))
require.Equal(t, []interval.Interval{}, span(1, 3).SubtractSplit(span(1, 3)))
require.Equal(t, []interval.Interval{}, span(1, 3).SubtractSplit(span(0, 3)))
require.Equal(t, []interval.Interval{}, span(1, 3).SubtractSplit(span(1, 4)))
require.Equal(t, []interval.Interval{}, span(1, 3).SubtractSplit(span(0, 4)))

require.Equal(t, []interval.Interval{span(1, 1), span(3, 3)}, span(1, 3).SubtractSplit(span(2, 2)))
require.Equal(t, []interval.Interval{span(1, 2), span(5, 6)}, span(1, 6).SubtractSplit(span(3, 4)))

require.Equal(t, []interval.Interval{span(3, 4)}, span(3, 4).SubtractSplit(span(1, 2)))
require.Equal(t, []interval.Interval{span(3, 4)}, span(3, 4).SubtractSplit(span(5, 6)))
}
38 changes: 15 additions & 23 deletions pkg/interval/intervalset.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package interval

import (
"log"
"math"
"slices"
"sort"
)
Expand Down Expand Up @@ -55,6 +56,14 @@ func (c *CanonicalSet) CalculateSize() int64 {
return res
}

func (c *CanonicalSet) Size() int {
res := c.CalculateSize()
if res > math.MaxInt {
log.Panic("size of CanonicalSet exceeds int")
}
return int(res)
}

// Equal returns true if the CanonicalSet equals the input CanonicalSet
func (c *CanonicalSet) Equal(other *CanonicalSet) bool {
if c == other {
Expand Down Expand Up @@ -94,25 +103,16 @@ func (c *CanonicalSet) AddInterval(v Interval) {

// AddHole updates the current CanonicalSet object by removing the input Interval from the set
func (c *CanonicalSet) AddHole(hole Interval) {
if hole.IsEmpty() {
return
}
var newIntervalSet []Interval
for _, interval := range c.intervalSet {
newIntervalSet = append(newIntervalSet, interval.subtract(hole)...)
newIntervalSet = append(newIntervalSet, interval.SubtractSplit(hole)...)
}
c.intervalSet = newIntervalSet
}

// String returns a string representation of the current CanonicalSet object
func (c *CanonicalSet) String() string {
if c.IsEmpty() {
return "Empty"
}
res := ""
for _, interval := range c.intervalSet {
res += interval.ShortString() + ","
}
return res[:len(res)-1]
}

// Union returns the union of the two sets
func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet {
res := c.Copy()
Expand All @@ -134,14 +134,6 @@ func (c *CanonicalSet) Contains(n int64) bool {
return New(n, n).ToSet().IsSubset(c)
}

func (c *CanonicalSet) Size() int {
res := 0
for _, v := range c.intervalSet {
res += int(v.Size())
}
return res
}

// IsSubset returns true of the current CanonicalSet is contained in the input CanonicalSet
func (c *CanonicalSet) IsSubset(other *CanonicalSet) bool {
if c == other {
Expand Down Expand Up @@ -169,7 +161,7 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet {
res := NewCanonicalSet()
for _, left := range c.intervalSet {
for _, right := range other.intervalSet {
res.AddInterval(left.intersect(right))
res.AddInterval(left.Intersect(right))
}
}
return res
Expand All @@ -182,7 +174,7 @@ func (c *CanonicalSet) Overlap(other *CanonicalSet) bool {
}
for _, selfInterval := range c.intervalSet {
for _, otherInterval := range other.intervalSet {
if selfInterval.overlap(otherInterval) {
if selfInterval.Overlap(otherInterval) {
return true
}
}
Expand Down
13 changes: 4 additions & 9 deletions pkg/interval/intervalset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,21 @@ import (
"github.com/np-guard/models/pkg/interval"
)

func TestInterval(t *testing.T) {
it1 := interval.New(3, 7)

require.Equal(t, "[3-7]", it1.String())
}

func TestIntervalSet(t *testing.T) {
is1 := interval.NewCanonicalSet()
is1.AddInterval(interval.New(5, 10))
is1.AddInterval(interval.New(0, 1))
is1.AddInterval(interval.New(3, 3))
is1.AddInterval(interval.New(70, 80))
is1.AddInterval(interval.New(0, -1))
is1 = is1.Subtract(interval.New(7, 9).ToSet())
require.True(t, is1.Contains(5))
require.False(t, is1.Contains(8))

is2 := interval.NewCanonicalSet()
require.Equal(t, "Empty", is2.String())
require.True(t, is2.IsEmpty())
is2.AddInterval(interval.New(6, 8))
require.Equal(t, "6-8", is2.String())
require.Equal(t, []int64{6, 7, 8}, is2.Elements())
require.False(t, is2.IsSingleNumber())
require.False(t, is2.IsSubset(is1))
require.False(t, is1.IsSubset(is2))
Expand Down Expand Up @@ -68,5 +63,5 @@ func TestIntervalSetSubtract(t *testing.T) {
d.AddInterval(interval.New(400, 700))
actual := s.Subtract(&d)
expected := interval.New(1, 49).ToSet()
require.Equal(t, expected.String(), actual.String())
require.Equal(t, expected.Elements(), actual.Elements())
}

0 comments on commit 8d8e435

Please sign in to comment.