From a99ebea3db71baa20956938b16d9726a17039ec5 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Mon, 11 Mar 2024 15:13:55 +0200 Subject: [PATCH] consistent use of pointers; simplify/optimize switchSrcDstPortsOnTCP Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 10 +++++----- pkg/connection/statefulness.go | 18 ++++++------------ pkg/hypercube/hypercubeset.go | 31 ++++++++++++++++++++----------- pkg/interval/intervalset.go | 23 +++++++---------------- pkg/interval/intervalset_test.go | 30 +++++++++++++++--------------- pkg/ipblock/ipblock.go | 22 +++++++++++----------- 6 files changed, 64 insertions(+), 70 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index db67dd1..94a8c51 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -251,7 +251,7 @@ func protocolStringFromCode(protocolCode int64) netp.ProtocolString { } func getDimensionString(dimValue *interval.CanonicalSet, dim Dimension) string { - if dimValue.Equal(*entireDimension(dim)) { + if dimValue.Equal(entireDimension(dim)) { // avoid adding dimension str on full dimension values return "" } @@ -342,7 +342,7 @@ func getCubeAsTCPorUDPItems(cube []*interval.CanonicalSet, isTCP bool) []netp.Pr tcpItemsTemp := []netp.Protocol{} // consider src ports srcPorts := cube[srcPort] - if srcPorts.Equal(*entireDimension(srcPort)) { + if srcPorts.Equal(entireDimension(srcPort)) { tcpItemsTemp = append(tcpItemsTemp, netp.TCPUDP{IsTCP: isTCP}) } else { // iterate the intervals in the interval-set @@ -359,7 +359,7 @@ func getCubeAsTCPorUDPItems(cube []*interval.CanonicalSet, isTCP bool) []netp.Pr } // consider dst ports dstPorts := cube[dstPort] - if dstPorts.Equal(*entireDimension(dstPort)) { + if dstPorts.Equal(entireDimension(dstPort)) { return tcpItemsTemp } tcpItemsFinal := []netp.Protocol{} @@ -381,8 +381,8 @@ func getCubeAsTCPorUDPItems(cube []*interval.CanonicalSet, isTCP bool) []netp.Pr func getCubeAsICMPItems(cube []*interval.CanonicalSet) []netp.Protocol { icmpTypes := cube[icmpType] icmpCodes := cube[icmpCode] - if icmpCodes.Equal(*entireDimension(icmpCode)) { - if icmpTypes.Equal(*entireDimension(icmpType)) { + if icmpCodes.Equal(entireDimension(icmpCode)) { + if icmpTypes.Equal(entireDimension(icmpType)) { return []netp.Protocol{netp.ICMP{}} } res := []netp.Protocol{} diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index 0153bfd..3805d8e 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -4,7 +4,6 @@ package connection import ( "github.com/np-guard/models/pkg/hypercube" - "github.com/np-guard/models/pkg/interval" "github.com/np-guard/models/pkg/netp" ) @@ -72,19 +71,14 @@ func (conn *Set) switchSrcDstPortsOnTCP() *Set { return conn.Copy() } res := None() - cubes := conn.connectionProperties.GetCubesList() - for _, cube := range cubes { + for _, cube := range conn.connectionProperties.GetCubesList() { // assuming cube[protocol] contains TCP only - srcPorts := cube[srcPort] - dstPorts := cube[dstPort] - // if the entire domain is enabled by both src and dst no need to switch - if !srcPorts.Equal(*entireDimension(srcPort)) || !dstPorts.Equal(*entireDimension(dstPort)) { - newCube := interval.CopyCube(cube) - newCube[srcPort], newCube[dstPort] = newCube[dstPort], newCube[srcPort] - res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(newCube)) - } else { - res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) + // no need to switch if src equals dst + if !cube[srcPort].Equal(cube[dstPort]) { + cube = hypercube.CopyCube(cube) + cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort] } + res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) } return res } diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index 4224a20..d1a0139 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -63,12 +63,12 @@ func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { for k, v := range c.layers { remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonElem := k.Intersect(*otherKey) + commonElem := k.Intersect(otherKey) if commonElem.IsEmpty() { continue } - remainingFromOther[otherKey] = remainingFromOther[otherKey].Subtract(*commonElem) - remainingFromSelf = remainingFromSelf.Subtract(*commonElem) + remainingFromOther[otherKey] = remainingFromOther[otherKey].Subtract(commonElem) + remainingFromSelf = remainingFromSelf.Subtract(commonElem) newSubElem := NewCanonicalSet(0) if c.dimensions != 1 { newSubElem = v.Union(otherVal) @@ -104,7 +104,7 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { for otherKey, otherVal := range other.layers { - commonELem := k.Intersect(*otherKey) + commonELem := k.Intersect(otherKey) if commonELem.IsEmpty() { continue } @@ -133,11 +133,11 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { for k, v := range c.layers { remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonElem := k.Intersect(*otherKey) + commonElem := k.Intersect(otherKey) if commonElem.IsEmpty() { continue } - remainingFromSelf = remainingFromSelf.Subtract(*commonElem) + remainingFromSelf = remainingFromSelf.Subtract(commonElem) if c.dimensions == 1 { continue } @@ -159,7 +159,7 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { func (c *CanonicalSet) getIntervalSetUnion() *interval.CanonicalSet { res := interval.NewCanonicalIntervalSet() for k := range c.layers { - res = res.Union(*k) + res = res.Union(k) } return res } @@ -175,14 +175,14 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { } cInterval := c.getIntervalSetUnion() otherInterval := other.getIntervalSetUnion() - return cInterval.ContainedIn(*otherInterval), nil + return cInterval.ContainedIn(otherInterval), nil } isSubsetCount := 0 for currentLayer, v := range c.layers { for otherKey, otherVal := range other.layers { - commonKey := currentLayer.Intersect(*otherKey) - remaining := currentLayer.Subtract(*commonKey) + commonKey := currentLayer.Intersect(otherKey) + remaining := currentLayer.Subtract(commonKey) if !commonKey.IsEmpty() { subContainment, err := v.ContainedIn(otherVal) if !subContainment || err != nil { @@ -266,7 +266,7 @@ func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) m newVal := p.hc newKey := p.is[0] for i := 1; i < len(p.is); i += 1 { - newKey = newKey.Union(*p.is[i]) + newKey = newKey.Union(p.is[i]) } newLayers[newKey] = newVal } @@ -299,3 +299,12 @@ func FromCubeShort(values ...int64) *CanonicalSet { } return FromCube(cube) } + +// CopyCube returns a new slice of intervals copied from input cube +func CopyCube(cube []*interval.CanonicalSet) []*interval.CanonicalSet { + newCube := make([]*interval.CanonicalSet, len(cube)) + for i, intervalSet := range cube { + newCube[i] = intervalSet.Copy() + } + return newCube +} diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index d516c34..df69bcc 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -34,7 +34,7 @@ func (c *CanonicalSet) CalculateSize() int64 { } // Equal returns true if the CanonicalSet equals the input CanonicalSet -func (c *CanonicalSet) Equal(other CanonicalSet) bool { +func (c *CanonicalSet) Equal(other *CanonicalSet) bool { if len(c.IntervalSet) != len(other.IntervalSet) { return false } @@ -82,7 +82,7 @@ func (c *CanonicalSet) String() string { } // Union returns the union of the two sets -func (c *CanonicalSet) Union(other CanonicalSet) *CanonicalSet { +func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { res := c.Copy() for _, interval := range other.IntervalSet { res.AddInterval(interval) @@ -97,11 +97,11 @@ func (c *CanonicalSet) Copy() *CanonicalSet { func (c *CanonicalSet) Contains(n int64) bool { i := CreateSetFromInterval(n, n) - return i.ContainedIn(*c) + return i.ContainedIn(c) } -// ContainedIn returns true of the current CanonicalIntervalSet is contained in the input CanonicalIntervalSet -func (c *CanonicalSet) ContainedIn(other CanonicalSet) bool { +// ContainedIn returns true of the current interval.CanonicalSet is contained in the input interval.CanonicalSet +func (c *CanonicalSet) ContainedIn(other *CanonicalSet) bool { larger := other.IntervalSet for _, target := range c.IntervalSet { left := sort.Search(len(larger), func(i int) bool { @@ -117,7 +117,7 @@ func (c *CanonicalSet) ContainedIn(other CanonicalSet) bool { } // Intersect returns the intersection of the current set with the input set -func (c *CanonicalSet) Intersect(other CanonicalSet) *CanonicalSet { +func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { res := NewCanonicalIntervalSet() for _, interval := range c.IntervalSet { for _, otherInterval := range other.IntervalSet { @@ -140,7 +140,7 @@ func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { } // Subtract updates current CanonicalSet with subtraction result of input CanonicalSet -func (c *CanonicalSet) Subtract(other CanonicalSet) *CanonicalSet { +func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { res := slices.Clone(c.IntervalSet) for _, hole := range other.IntervalSet { newIntervalSet := []Interval{} @@ -190,12 +190,3 @@ func (c *CanonicalSet) Elements() []int64 { func CreateSetFromInterval(start, end int64) *CanonicalSet { return &CanonicalSet{IntervalSet: []Interval{{Start: start, End: end}}} } - -// copyCube returns a new slice of intervals copied from input cube -func CopyCube(cube []*CanonicalSet) []*CanonicalSet { - newCube := make([]*CanonicalSet, len(cube)) - for i, intervalSet := range cube { - newCube[i] = intervalSet.Copy() - } - return newCube -} diff --git a/pkg/interval/intervalset_test.go b/pkg/interval/intervalset_test.go index affe3bb..76876dd 100644 --- a/pkg/interval/intervalset_test.go +++ b/pkg/interval/intervalset_test.go @@ -22,7 +22,7 @@ func TestIntervalSet(t *testing.T) { is1.AddInterval(interval.Interval{0, 1}) is1.AddInterval(interval.Interval{3, 3}) is1.AddInterval(interval.Interval{70, 80}) - is1 = is1.Subtract(*interval.CreateSetFromInterval(7, 9)) + is1 = is1.Subtract(interval.CreateSetFromInterval(7, 9)) require.True(t, is1.Contains(5)) require.False(t, is1.Contains(8)) @@ -31,28 +31,28 @@ func TestIntervalSet(t *testing.T) { is2.AddInterval(interval.Interval{6, 8}) require.Equal(t, "6-8", is2.String()) require.False(t, is2.IsSingleNumber()) - require.False(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) - require.False(t, is2.Equal(*is1)) - require.False(t, is1.Equal(*is2)) + require.False(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) + require.False(t, is2.Equal(is1)) + require.False(t, is1.Equal(is2)) require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is1 = is1.Subtract(*is2) - require.False(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) + is1 = is1.Subtract(is2) + require.False(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) require.False(t, is1.Overlaps(is2)) require.False(t, is2.Overlaps(is1)) - is1 = is1.Union(*is2).Union(*interval.CreateSetFromInterval(7, 9)) - require.True(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) + is1 = is1.Union(is2).Union(interval.CreateSetFromInterval(7, 9)) + require.True(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is3 := is1.Intersect(*is2) - require.True(t, is3.Equal(*is2)) - require.True(t, is2.ContainedIn(*is3)) + is3 := is1.Intersect(is2) + require.True(t, is3.Equal(is2)) + require.True(t, is2.ContainedIn(is3)) require.True(t, interval.CreateSetFromInterval(1, 1).IsSingleNumber()) } @@ -62,7 +62,7 @@ func TestIntervalSetSubtract(t *testing.T) { s.AddInterval(interval.Interval{Start: 400, End: 700}) d := *interval.CreateSetFromInterval(50, 100) d.AddInterval(interval.Interval{Start: 400, End: 700}) - actual := s.Subtract(d) + actual := s.Subtract(&d) expected := interval.CreateSetFromInterval(1, 49) require.Equal(t, expected.String(), actual.String()) } diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index d0e4553..560e44d 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -35,7 +35,7 @@ const ( // IPBlock captures a set of IP ranges type IPBlock struct { - ipRange interval.CanonicalSet + ipRange *interval.CanonicalSet } // ToIPRanges returns a string of the ip ranges in the current IPBlock object @@ -67,7 +67,7 @@ func (b *IPBlock) ContainedIn(c *IPBlock) bool { // Intersect returns a new IPBlock from intersection of this IPBlock with input IPBlock func (b *IPBlock) Intersect(c *IPBlock) *IPBlock { return &IPBlock{ - ipRange: *b.ipRange.Intersect(c.ipRange), + ipRange: b.ipRange.Intersect(c.ipRange), } } @@ -79,14 +79,14 @@ func (b *IPBlock) Equal(c *IPBlock) bool { // Subtract returns a new IPBlock from subtraction of input IPBlock from this IPBlock func (b *IPBlock) Subtract(c *IPBlock) *IPBlock { return &IPBlock{ - ipRange: *b.ipRange.Subtract(c.ipRange), + ipRange: b.ipRange.Subtract(c.ipRange), } } // Union returns a new IPBlock from union of input IPBlock with this IPBlock func (b *IPBlock) Union(c *IPBlock) *IPBlock { return &IPBlock{ - ipRange: *b.ipRange.Union(c.ipRange), + ipRange: b.ipRange.Union(c.ipRange), } } @@ -102,7 +102,7 @@ func rangeIPstr(start, end string) string { // Copy returns a new copy of IPBlock object func (b *IPBlock) Copy() *IPBlock { return &IPBlock{ - ipRange: *b.ipRange.Copy(), + ipRange: b.ipRange.Copy(), } } @@ -115,7 +115,7 @@ func (b *IPBlock) Split() []*IPBlock { res := make([]*IPBlock, len(b.ipRange.IntervalSet)) for index, set := range b.ipRange.Split() { res[index] = &IPBlock{ - ipRange: *set, + ipRange: set, } } return res @@ -157,7 +157,7 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { toAdd := []*IPBlock{} for idx, ipb := range ipbList { - if !ipb.ipRange.Overlaps(&ipbNew.ipRange) { + if !ipb.ipRange.Overlaps(ipbNew.ipRange) { continue } intersection := ipb.Intersect(ipbNew) @@ -182,7 +182,7 @@ func NewIPBlockFromCidr(cidr string) (*IPBlock, error) { return nil, err } return &IPBlock{ - ipRange: *interval.CreateSetFromInterval(start, end), + ipRange: interval.CreateSetFromInterval(start, end), }, nil } @@ -227,7 +227,7 @@ func NewIPBlockFromCidrList(cidrsList []string) (*IPBlock, error) { } ipRange = ipRange.Union(block.ipRange) } - return &IPBlock{ipRange: *ipRange}, nil + return &IPBlock{ipRange: ipRange}, nil } // NewIPBlockFromIPAddress returns an IPBlock object from input IP address string @@ -237,7 +237,7 @@ func NewIPBlockFromIPAddress(ipAddress string) (*IPBlock, error) { return nil, err } return &IPBlock{ - ipRange: *interval.CreateSetFromInterval(ipNum, ipNum), + ipRange: interval.CreateSetFromInterval(ipNum, ipNum), }, nil } @@ -343,7 +343,7 @@ func IPBlockFromIPRangeStr(ipRangeStr string) (*IPBlock, error) { return nil, errors.Join(err0, err1) } res := &IPBlock{ - ipRange: *interval.CreateSetFromInterval(startIPNum, endIPNum), + ipRange: interval.CreateSetFromInterval(startIPNum, endIPNum), } return res, nil }