diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index e0b90fa..4224a20 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -55,39 +55,39 @@ func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) remainingFromOther := map[*interval.CanonicalSet]*interval.CanonicalSet{} - for k := range other.layers { - remainingFromOther[k] = k.Copy() + for otherKey := range other.layers { + remainingFromOther[otherKey] = otherKey.Copy() } + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonElem := k.Copy() - commonElem.Intersect(*otherKey) + commonElem := k.Intersect(*otherKey) if commonElem.IsEmpty() { continue } - remainingFromOther[otherKey].Subtract(*commonElem) - remainingFromSelf.Subtract(*commonElem) - if c.dimensions == 1 { - res.layers[commonElem] = NewCanonicalSet(0) - continue + remainingFromOther[otherKey] = remainingFromOther[otherKey].Subtract(*commonElem) + remainingFromSelf = remainingFromSelf.Subtract(*commonElem) + newSubElem := NewCanonicalSet(0) + if c.dimensions != 1 { + newSubElem = v.Union(otherVal) } - newSubElem := v.Union(otherVal) - res.layers[commonElem] = newSubElem + layers[commonElem] = newSubElem } if !remainingFromSelf.IsEmpty() { - res.layers[remainingFromSelf] = v.Copy() + layers[remainingFromSelf] = v.Copy() } } for k, v := range remainingFromOther { if !v.IsEmpty() { - res.layers[v] = other.layers[k].Copy() + layers[v] = other.layers[k].Copy() } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } // IsEmpty returns true if c is empty @@ -100,26 +100,28 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) + + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { for otherKey, otherVal := range other.layers { - commonELem := k.Copy() - commonELem.Intersect(*otherKey) + commonELem := k.Intersect(*otherKey) if commonELem.IsEmpty() { continue } if c.dimensions == 1 { - res.layers[commonELem] = NewCanonicalSet(0) + layers[commonELem] = NewCanonicalSet(0) continue } newSubElem := v.Intersect(otherVal) if !newSubElem.IsEmpty() { - res.layers[commonELem] = newSubElem + layers[commonELem] = newSubElem } } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } // Subtract returns a new CanonicalSet object that results from subtraction other from c @@ -127,36 +129,37 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonELem := k.Copy() - commonELem.Intersect(*otherKey) - if commonELem.IsEmpty() { + commonElem := k.Intersect(*otherKey) + if commonElem.IsEmpty() { continue } - remainingFromSelf.Subtract(*commonELem) + remainingFromSelf = remainingFromSelf.Subtract(*commonElem) if c.dimensions == 1 { continue } newSubElem := v.Subtract(otherVal) if !newSubElem.IsEmpty() { - res.layers[commonELem] = newSubElem + layers[commonElem] = newSubElem } } if !remainingFromSelf.IsEmpty() { - res.layers[remainingFromSelf] = v.Copy() + layers[remainingFromSelf] = v.Copy() } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } func (c *CanonicalSet) getIntervalSetUnion() *interval.CanonicalSet { res := interval.NewCanonicalIntervalSet() for k := range c.layers { - res.Union(*k) + res = res.Union(*k) } return res } @@ -176,13 +179,10 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { } isSubsetCount := 0 - for k, v := range c.layers { - currentLayer := k.Copy() + for currentLayer, v := range c.layers { for otherKey, otherVal := range other.layers { - commonKey := currentLayer.Copy() - commonKey.Intersect(*otherKey) - remaining := currentLayer.Copy() - remaining.Subtract(*commonKey) + commonKey := currentLayer.Intersect(*otherKey) + remaining := currentLayer.Subtract(*commonKey) if !commonKey.IsEmpty() { subContainment, err := v.ContainedIn(otherVal) if !subContainment || err != nil { @@ -248,13 +248,13 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { return res } -func (c *CanonicalSet) applyElementsUnionPerLayer() { +func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet { type pair struct { hc *CanonicalSet // hypercube set object is []*interval.CanonicalSet // interval-set list } equivClasses := map[string]*pair{} - for k, v := range c.layers { + for k, v := range layers { if _, ok := equivClasses[v.String()]; ok { equivClasses[v.String()].is = append(equivClasses[v.String()].is, k) } else { @@ -266,11 +266,11 @@ func (c *CanonicalSet) applyElementsUnionPerLayer() { newVal := p.hc newKey := p.is[0] for i := 1; i < len(p.is); i += 1 { - newKey.Union(*p.is[i]) + newKey = newKey.Union(*p.is[i]) } newLayers[newKey] = newVal } - c.layers = newLayers + return newLayers } // FromCube returns a new CanonicalSet created from a single input cube diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index 4c6ac5b..454d58f 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -3,6 +3,7 @@ package hypercube_test import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -63,8 +64,7 @@ func TestBasic(t *testing.T) { a = addCube(a, 1, 2) a = addCube(a, 5, 6) a = addCube(a, 3, 4) - b := hypercube.NewCanonicalSet(1) - b = addCube(b, 1, 6) + b := hypercube.FromCubeShort(1, 6) require.True(t, a.Equal(b)) } @@ -204,27 +204,35 @@ func TestBasicAddHole(t *testing.T) { require.True(t, e.Equal(hypercube.FromCubeShort(1, 10))) } -func TestAddHoleBasic2(t *testing.T) { - a := hypercube.FromCubeShort(1, 100, 200, 300) - b := a.Copy() - c := a.Copy() - a = a.Subtract(hypercube.FromCubeShort(50, 60, 220, 300)) +func TestAddHoleBasic20(t *testing.T) { + a := hypercube.FromCubeShort(1, 100, 200, 300).Subtract(hypercube.FromCubeShort(50, 60, 220, 300)) resA := hypercube.FromCubeShort(61, 100, 200, 300) resA = addCube(resA, 50, 60, 200, 219) resA = addCube(resA, 1, 49, 200, 300) - require.True(t, a.Equal(resA)) + require.True(t, a.Equal(resA), fmt.Sprintf("%v != %v", a, resA)) +} - b = b.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) +func TestAddHoleBasic21(t *testing.T) { + b := hypercube.FromCubeShort(1, 100, 200, 300).Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) resB := hypercube.FromCubeShort(50, 100, 251, 300) resB = addCube(resB, 1, 49, 200, 300) - require.True(t, b.Equal(resB)) - - c = addCube(c, 400, 700, 200, 300) - c = c.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) - resC := hypercube.FromCubeShort(50, 100, 251, 300) - resC = addCube(resC, 1, 49, 200, 300) - resC = addCube(resC, 400, 700, 251, 300) - require.True(t, c.Equal(resC)) + require.True(t, b.Equal(resB), fmt.Sprintf("%v != %v", b, resB)) +} + +func TestAddHoleBasic22(t *testing.T) { + a := hypercube.FromCubeShort(1, 2, 1, 2) + require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(hypercube.FromCubeShort(1, 1, 1, 1)).String()) + require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(hypercube.FromCubeShort(1, 1, 2, 2)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(hypercube.FromCubeShort(2, 2, 1, 1)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(hypercube.FromCubeShort(2, 2, 2, 2)).String()) +} + +func TestAddHoleBasic23(t *testing.T) { + a := hypercube.FromCubeShort(1, 100, 200, 300) + a = addCube(a, 400, 700, 200, 300) + require.Equal(t, "[(1-100,400-700),(200-300)]", a.String()) + a = a.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) + require.Equal(t, "[(1-49),(200-300)]; [(50-100,400-700),(251-300)]", a.String()) } func TestAddHole(t *testing.T) { @@ -247,6 +255,7 @@ func TestAddHole2(t *testing.T) { d = addCube(d, 301, 400, 20, 300) require.True(t, c.Equal(d)) } + func TestAddHole3(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) c = c.Subtract(hypercube.FromCubeShort(1, 100, 200, 300)) diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index b58222e..d516c34 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -64,15 +64,6 @@ func (c *CanonicalSet) AddInterval(v Interval) { c.IntervalSet = slices.Replace(c.IntervalSet, left, right, v) } -// AddHole updates the current CanonicalSet object by removing the input Interval from the set -func (c *CanonicalSet) AddHole(hole Interval) { - newIntervalSet := []Interval{} - for _, interval := range c.IntervalSet { - newIntervalSet = append(newIntervalSet, interval.subtract(hole)...) - } - c.IntervalSet = newIntervalSet -} - // String returns a string representation of the current CanonicalSet object func (c *CanonicalSet) String() string { if c.IsEmpty() { @@ -90,11 +81,13 @@ func (c *CanonicalSet) String() string { return res[:len(res)-1] } -// Union updates the CanonicalSet object with the union result of the input CanonicalSet -func (c *CanonicalSet) Union(other CanonicalSet) { +// Union returns the union of the two sets +func (c *CanonicalSet) Union(other CanonicalSet) *CanonicalSet { + res := c.Copy() for _, interval := range other.IntervalSet { - c.AddInterval(interval) + res.AddInterval(interval) } + return res } // Copy returns a new copy of the CanonicalSet object @@ -123,15 +116,15 @@ func (c *CanonicalSet) ContainedIn(other CanonicalSet) bool { return true } -// Intersect updates current CanonicalSet with intersection result of input CanonicalSet -func (c *CanonicalSet) Intersect(other CanonicalSet) { - newIntervalSet := []Interval{} +// Intersect returns the intersection of the current set with the input set +func (c *CanonicalSet) Intersect(other CanonicalSet) *CanonicalSet { + res := NewCanonicalIntervalSet() for _, interval := range c.IntervalSet { for _, otherInterval := range other.IntervalSet { - newIntervalSet = append(newIntervalSet, interval.intersection(otherInterval)...) + res.IntervalSet = append(res.IntervalSet, interval.intersection(otherInterval)...) } } - c.IntervalSet = newIntervalSet + return res } // Overlaps returns true if current CanonicalSet overlaps with input CanonicalSet @@ -147,9 +140,17 @@ func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { } // Subtract updates current CanonicalSet with subtraction result of input CanonicalSet -func (c *CanonicalSet) Subtract(other CanonicalSet) { - for _, i := range other.IntervalSet { - c.AddHole(i) +func (c *CanonicalSet) Subtract(other CanonicalSet) *CanonicalSet { + res := slices.Clone(c.IntervalSet) + for _, hole := range other.IntervalSet { + newIntervalSet := []Interval{} + for _, interval := range res { + newIntervalSet = append(newIntervalSet, interval.subtract(hole)...) + } + res = newIntervalSet + } + return &CanonicalSet{ + IntervalSet: res, } } diff --git a/pkg/interval/intervalset_test.go b/pkg/interval/intervalset_test.go index c1de9cc..affe3bb 100644 --- a/pkg/interval/intervalset_test.go +++ b/pkg/interval/intervalset_test.go @@ -22,7 +22,7 @@ func TestIntervalSet(t *testing.T) { is1.AddInterval(interval.Interval{0, 1}) is1.AddInterval(interval.Interval{3, 3}) is1.AddInterval(interval.Interval{70, 80}) - is1.AddHole(interval.Interval{7, 9}) + is1 = is1.Subtract(*interval.CreateSetFromInterval(7, 9)) require.True(t, is1.Contains(5)) require.False(t, is1.Contains(8)) @@ -38,23 +38,31 @@ func TestIntervalSet(t *testing.T) { require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is1.Subtract(*is2) + is1 = is1.Subtract(*is2) require.False(t, is2.ContainedIn(*is1)) require.False(t, is1.ContainedIn(*is2)) require.False(t, is1.Overlaps(is2)) require.False(t, is2.Overlaps(is1)) - is1.Union(*is2) - is1.Union(*interval.CreateSetFromInterval(7, 9)) + is1 = is1.Union(*is2).Union(*interval.CreateSetFromInterval(7, 9)) require.True(t, is2.ContainedIn(*is1)) require.False(t, is1.ContainedIn(*is2)) require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is3 := is1.Copy() - is3.Intersect(*is2) + is3 := is1.Intersect(*is2) require.True(t, is3.Equal(*is2)) require.True(t, is2.ContainedIn(*is3)) require.True(t, interval.CreateSetFromInterval(1, 1).IsSingleNumber()) } + +func TestIntervalSetSubtract(t *testing.T) { + s := interval.CreateSetFromInterval(1, 100) + s.AddInterval(interval.Interval{Start: 400, End: 700}) + d := *interval.CreateSetFromInterval(50, 100) + d.AddInterval(interval.Interval{Start: 400, End: 700}) + actual := s.Subtract(d) + expected := interval.CreateSetFromInterval(1, 49) + require.Equal(t, expected.String(), actual.String()) +}