From 1b431bc675473b2c3bb1853ad19f70744de04dbf Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Tue, 12 Mar 2024 17:33:14 +0200 Subject: [PATCH 01/15] move connectionset from analyzer and protocols from synthesizer Signed-off-by: Elazar Gershuni --- Makefile | 2 +- pkg/connection/connectionset.go | 442 +++++++++++++++++++++++++++ pkg/connection/connectionset_test.go | 58 ++++ pkg/connection/statefulness.go | 84 +++++ pkg/connection/statefulness_test.go | 148 +++++++++ pkg/hypercube/hypercubeset.go | 118 +++---- pkg/hypercube/hypercubeset_test.go | 373 ++++++++++------------ pkg/interval/interval.go | 20 +- pkg/interval/intervalset.go | 151 +++++---- pkg/interval/intervalset_test.go | 43 ++- pkg/ipblock/ipblock.go | 205 ++++++------- pkg/ipblock/ipblock_test.go | 15 +- pkg/netp/common.go | 20 ++ pkg/netp/icmp.go | 128 ++++++++ pkg/netp/tcpudp.go | 35 +++ 15 files changed, 1376 insertions(+), 466 deletions(-) create mode 100644 pkg/connection/connectionset.go create mode 100644 pkg/connection/connectionset_test.go create mode 100644 pkg/connection/statefulness.go create mode 100644 pkg/connection/statefulness_test.go create mode 100644 pkg/netp/common.go create mode 100644 pkg/netp/icmp.go create mode 100644 pkg/netp/tcpudp.go diff --git a/Makefile b/Makefile index ced09a0..7f6abe2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -REPOSITORY := github.com/np-guard/common +REPOSITORY := github.com/np-guard/models mod: go.mod @echo -- $@ -- diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go new file mode 100644 index 0000000..c83d72e --- /dev/null +++ b/pkg/connection/connectionset.go @@ -0,0 +1,442 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection + +import ( + "log" + "sort" + "strings" + + "github.com/np-guard/models/pkg/hypercube" + "github.com/np-guard/models/pkg/interval" + "github.com/np-guard/models/pkg/netp" +) + +const ( + TCPCode = 0 + UDPCode = 1 + ICMPCode = 2 + MinICMPtype int64 = 0 + MaxICMPtype int64 = netp.InformationReply + MinICMPcode int64 = 0 + MaxICMPcode int64 = 5 + minProtocol int64 = 0 + maxProtocol int64 = 2 + MinPort = 1 + MaxPort = netp.MaxPort +) + +const ( + AllConnections = "All Connections" + NoConnections = "No Connections" +) + +type Dimension int + +const ( + protocol Dimension = 0 + srcPort Dimension = 1 + dstPort Dimension = 2 + icmpType Dimension = 3 + icmpCode Dimension = 4 + numDimensions = 5 +) + +const propertySeparator string = " " + +// dimensionsList is the ordered list of dimensions in the Set object +// this should be the only place where the order is hard-coded +var dimensionsList = []Dimension{ + protocol, + srcPort, + dstPort, + icmpType, + icmpCode, +} + +func entireDimension(dim Dimension) *interval.CanonicalSet { + switch dim { + case protocol: + return interval.CreateSetFromInterval(minProtocol, maxProtocol) + case srcPort: + return interval.CreateSetFromInterval(MinPort, MaxPort) + case dstPort: + return interval.CreateSetFromInterval(MinPort, MaxPort) + case icmpType: + return interval.CreateSetFromInterval(MinICMPtype, MaxICMPtype) + case icmpCode: + return interval.CreateSetFromInterval(MinICMPcode, MaxICMPcode) + } + return nil +} + +func getDimensionDomainsList() []*interval.CanonicalSet { + res := make([]*interval.CanonicalSet, len(dimensionsList)) + for i := range dimensionsList { + res[i] = entireDimension(dimensionsList[i]) + } + return res +} + +type Set struct { + AllowAll bool + connectionProperties *hypercube.CanonicalSet + IsStateful StatefulState +} + +func newSet(all bool) *Set { + return &Set{AllowAll: all, connectionProperties: hypercube.NewCanonicalSet(numDimensions)} +} + +func All() *Set { + return newSet(true) +} + +func None() *Set { + return newSet(false) +} + +func (conn *Set) Copy() *Set { + return &Set{ + AllowAll: conn.AllowAll, + connectionProperties: conn.connectionProperties.Copy(), + IsStateful: conn.IsStateful, + } +} + +func (conn *Set) Intersect(other *Set) *Set { + if other.AllowAll { + return conn.Copy() + } + if conn.AllowAll { + return other.Copy() + } + return &Set{AllowAll: false, connectionProperties: conn.connectionProperties.Intersect(other.connectionProperties)} +} + +func (conn *Set) IsEmpty() bool { + if conn.AllowAll { + return false + } + return conn.connectionProperties.IsEmpty() +} + +func (conn *Set) Union(other *Set) *Set { + if conn.AllowAll || other.AllowAll { + return All() + } + if other.IsEmpty() { + return conn.Copy() + } + if conn.IsEmpty() { + return other.Copy() + } + res := &Set{ + AllowAll: false, + connectionProperties: conn.connectionProperties.Union(other.connectionProperties), + } + res.canonicalize() + return res +} + +func getAllPropertiesObject() *hypercube.CanonicalSet { + return hypercube.FromCube(getDimensionDomainsList()) +} + +// Subtract +// ToDo: Subtract seems to ignore IsStateful (see https://github.com/np-guard/vpc-network-config-analyzer/issues/199): +// 1. is the delta connection stateful +// 2. connectionProperties is identical but conn stateful while other is not +// the 2nd item can be computed here, with enhancement to relevant structure +// the 1st can not since we do not know where exactly the statefulness came from +func (conn *Set) Subtract(other *Set) *Set { + if conn.IsEmpty() || other.AllowAll { + return None() + } + if other.IsEmpty() { + return conn.Copy() + } + var connProperties *hypercube.CanonicalSet + if conn.AllowAll { + connProperties = getAllPropertiesObject() + } else { + connProperties = conn.connectionProperties + } + return &Set{AllowAll: false, connectionProperties: connProperties.Subtract(other.connectionProperties)} +} + +// ContainedIn returns true if conn is subset of other +func (conn *Set) ContainedIn(other *Set) bool { + if other.AllowAll { + return true + } + if conn.AllowAll { + return false + } + res, err := conn.connectionProperties.ContainedIn(other.connectionProperties) + if err != nil { + log.Fatalf("invalid connection set. %e", err) + } + return res +} + +func ProtocolStringToCode(protocol netp.ProtocolString) int64 { + switch protocol { + case netp.ProtocolStringTCP: + return TCPCode + case netp.ProtocolStringUDP: + return UDPCode + case netp.ProtocolStringICMP: + return ICMPCode + } + log.Fatalf("Impossible protocol code %v", protocol) + return 0 +} + +func (conn *Set) addConnection(protocol netp.ProtocolString, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { + code := ProtocolStringToCode(protocol) + cube := hypercube.FromCubeShort(code, code, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) + conn.connectionProperties = conn.connectionProperties.Union(cube) + conn.canonicalize() +} + +func (conn *Set) canonicalize() { + if !conn.AllowAll && conn.connectionProperties.Equal(getAllPropertiesObject()) { + conn.AllowAll = true + conn.connectionProperties = hypercube.NewCanonicalSet(numDimensions) + } +} + +func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *Set { + conn := None() + conn.addConnection(protocol, + srcMinP, srcMaxP, dstMinP, dstMaxP, + MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + return conn +} + +func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { + conn := None() + conn.addConnection(netp.ProtocolStringICMP, + MinPort, MaxPort, MinPort, MaxPort, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) + return conn +} + +func (conn *Set) Equal(other *Set) bool { + if conn.AllowAll != other.AllowAll { + return false + } + if conn.AllowAll { + return true + } + return conn.connectionProperties.Equal(other.connectionProperties) +} + +func protocolStringFromCode(protocolCode int64) netp.ProtocolString { + switch protocolCode { + case TCPCode: + return netp.ProtocolStringTCP + case UDPCode: + return netp.ProtocolStringUDP + case ICMPCode: + return netp.ProtocolStringICMP + } + log.Fatalf("impossible protocol code %v", protocolCode) + return "" +} + +func getDimensionString(dimValue *interval.CanonicalSet, dim Dimension) string { + if dimValue.Equal(entireDimension(dim)) { + // avoid adding dimension str on full dimension values + return "" + } + switch dim { + case protocol: + pList := []string{} + for code := minProtocol; code <= maxProtocol; code++ { + if dimValue.Contains(code) { + pList = append(pList, string(protocolStringFromCode(code))) + } + } + return "protocol: " + strings.Join(pList, ",") + case srcPort: + return "src-ports: " + dimValue.String() + case dstPort: + return "dst-ports: " + dimValue.String() + case icmpType: + return "icmp-type: " + dimValue.String() + case icmpCode: + return "icmp-code: " + dimValue.String() + } + return "" +} + +func filterEmptyPropertiesStr(inputList []string) []string { + res := []string{} + for _, propertyStr := range inputList { + if propertyStr != "" { + res = append(res, propertyStr) + } + } + return res +} + +func getICMPbasedCubeStr(protocolsValues, icmpTypeValues, icmpCodeValues *interval.CanonicalSet) string { + strList := []string{ + getDimensionString(protocolsValues, protocol), + getDimensionString(icmpTypeValues, icmpType), + getDimensionString(icmpCodeValues, icmpCode), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getPortBasedCubeStr(protocolsValues, srcPortsValues, dstPortsValues *interval.CanonicalSet) string { + strList := []string{ + getDimensionString(protocolsValues, protocol), + getDimensionString(srcPortsValues, srcPort), + getDimensionString(dstPortsValues, dstPort), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getMixedProtocolsCubeStr(protocols *interval.CanonicalSet) string { + // TODO: make sure other dimension values are full + return getDimensionString(protocols, protocol) +} + +func getConnsCubeStr(cube []*interval.CanonicalSet) string { + protocols := cube[protocol] + if (protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) && !protocols.Contains(ICMPCode) { + return getPortBasedCubeStr(protocols, cube[srcPort], cube[dstPort]) + } + if protocols.Contains(ICMPCode) && !(protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) { + return getICMPbasedCubeStr(protocols, cube[icmpType], cube[icmpCode]) + } + return getMixedProtocolsCubeStr(protocols) +} + +// String returns a string representation of a Set object +func (conn *Set) String() string { + if conn.AllowAll { + return AllConnections + } else if conn.IsEmpty() { + return NoConnections + } + resStrings := []string{} + // get cubes and cube str per each cube + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + resStrings = append(resStrings, getConnsCubeStr(cube)) + } + + sort.Strings(resStrings) + return strings.Join(resStrings, "; ") +} + +func getCubeAsTCPorUDPItems(cube []*interval.CanonicalSet, isTCP bool) []netp.Protocol { + tcpItemsTemp := []netp.Protocol{} + // consider src ports + srcPorts := cube[srcPort] + if srcPorts.Equal(entireDimension(srcPort)) { + tcpItemsTemp = append(tcpItemsTemp, netp.TCPUDP{IsTCP: isTCP}) + } else { + // iterate the intervals in the interval-set + for _, portRange := range srcPorts.Intervals() { + tcpRes := netp.TCPUDP{ + IsTCP: isTCP, + PortRangePair: netp.PortRangePair{ + SrcPort: portRange, + DstPort: interval.Interval{Start: netp.MinPort, End: netp.MaxPort}, + }, + } + tcpItemsTemp = append(tcpItemsTemp, tcpRes) + } + } + // consider dst ports + dstPorts := cube[dstPort] + if dstPorts.Equal(entireDimension(dstPort)) { + return tcpItemsTemp + } + tcpItemsFinal := []netp.Protocol{} + for _, portRange := range dstPorts.Intervals() { + for _, tcpItemTemp := range tcpItemsTemp { + item, _ := tcpItemTemp.(netp.TCPUDP) + tcpItemsFinal = append(tcpItemsFinal, netp.TCPUDP{ + IsTCP: isTCP, + PortRangePair: netp.PortRangePair{ + SrcPort: item.PortRangePair.SrcPort, + DstPort: portRange, + }, + }) + } + } + return tcpItemsFinal +} + +func getCubeAsICMPItems(cube []*interval.CanonicalSet) []netp.Protocol { + icmpTypes := cube[icmpType] + icmpCodes := cube[icmpCode] + if icmpCodes.Equal(entireDimension(icmpCode)) { + if icmpTypes.Equal(entireDimension(icmpType)) { + return []netp.Protocol{netp.ICMP{}} + } + res := []netp.Protocol{} + for _, t := range icmpTypes.Elements() { + icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: int(t)}) + if err != nil { + log.Panic(err) + } + res = append(res, icmp) + } + return res + } + + // iterate both codes and types + res := []netp.Protocol{} + for _, t := range icmpTypes.Elements() { + codes := icmpCodes.Elements() + for i := range codes { + // TODO: merge when all codes for certain type exist + c := int(codes[i]) + icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: int(t), Code: &c}) + if err != nil { + log.Panic(err) + } + res = append(res, icmp) + } + } + return res +} + +type Details []netp.Protocol + +func ConnToJSONRep(c *Set) Details { + if c == nil { + return nil // one of the connections in connectionDiff can be empty + } + if c.AllowAll { + return []netp.Protocol{netp.AnyProtocol{}} + } + var res []netp.Protocol + + cubes := c.connectionProperties.GetCubesList() + for _, cube := range cubes { + protocols := cube[protocol] + if protocols.Contains(TCPCode) { + res = append(res, getCubeAsTCPorUDPItems(cube, true)...) + } + if protocols.Contains(UDPCode) { + res = append(res, getCubeAsTCPorUDPItems(cube, false)...) + } + if protocols.Contains(ICMPCode) { + res = append(res, getCubeAsICMPItems(cube)...) + } + } + + return res +} diff --git a/pkg/connection/connectionset_test.go b/pkg/connection/connectionset_test.go new file mode 100644 index 0000000..a949dc6 --- /dev/null +++ b/pkg/connection/connectionset_test.go @@ -0,0 +1,58 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/np-guard/models/pkg/connection" + "github.com/np-guard/models/pkg/netp" +) + +const ICMPValue = netp.DestinationUnreachable + +func TestAllConnections(t *testing.T) { + c := connection.All() + require.Equal(t, "All Connections", c.String()) +} + +func TestNoConnections(t *testing.T) { + c := connection.None() + require.Equal(t, "No Connections", c.String()) +} + +func TestBasicSetICMP(t *testing.T) { + c := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + require.Equal(t, "protocol: ICMP icmp-type: 3 icmp-code: 5", c.String()) +} + +func TestBasicSetTCP(t *testing.T) { + e := connection.TCPorUDPConnection(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + require.Equal(t, "protocol: TCP", e.String()) + + c := connection.All().Subtract(e) + require.Equal(t, "protocol: UDP,ICMP", c.String()) + + c = c.Union(e) + require.Equal(t, "All Connections", c.String()) +} + +func TestBasicSet2(t *testing.T) { + except1 := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + + except2 := connection.TCPorUDPConnection(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + + d := connection.All().Subtract(except1).Subtract(except2) + require.Equal(t, ""+ + "protocol: ICMP icmp-type: 0-2,4-16; "+ + "protocol: ICMP icmp-type: 3 icmp-code: 0-4; "+ + "protocol: UDP", d.String()) +} + +func TestBasicSet3(t *testing.T) { + c := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + d := connection.All().Subtract(c).Union(connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5)) + require.Equal(t, "All Connections", d.String()) +} diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go new file mode 100644 index 0000000..3805d8e --- /dev/null +++ b/pkg/connection/statefulness.go @@ -0,0 +1,84 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection + +import ( + "github.com/np-guard/models/pkg/hypercube" + "github.com/np-guard/models/pkg/netp" +) + +// default is StatefulUnknown +type StatefulState int + +const ( + // StatefulUnknown is the default value for a Set object, + StatefulUnknown StatefulState = 0 + // StatefulTrue represents a connection object for which any allowed TCP (on all allowed src/dst ports) + // has an allowed response connection + StatefulTrue StatefulState = 1 + // StatefulFalse represents a connection object for which there exists some allowed TCP + // (on any allowed subset from the allowed src/dst ports) that does not have an allowed response connection + StatefulFalse StatefulState = 2 +) + +// EnhancedString returns a connection string with possibly added asterisk for stateless connection +func (conn *Set) EnhancedString() string { + if conn.IsStateful == StatefulFalse { + return conn.String() + " *" + } + return conn.String() +} + +func newTCPSet() *Set { + return TCPorUDPConnection(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) +} + +// ConnectionWithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. +// `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` +// which is not stateful (such that the response direction for this subset is not enabled). +// This function also returns a connection object with the exact subset of the stateful part (within TCP) +// from the entire connection `conn`, and with the original connections on other protocols. +func (conn *Set) ConnectionWithStatefulness(secondDirectionConn *Set) *Set { + connTCP := conn.Intersect(newTCPSet()) + if connTCP.IsEmpty() { + conn.IsStateful = StatefulTrue + return conn + } + statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConn.Intersect(newTCPSet())) + conn.IsStateful = connTCP.IsStateful + return conn.Subtract(connTCP).Union(statefulCombinedConnTCP) +} + +// connTCPWithStatefulness assumes that both `conn` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `conn`, and returns the subset from `conn` which is stateful. +func (conn *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { + // flip src/dst ports before intersection + statefulCombinedConn := conn.Intersect(secondDirectionConn.switchSrcDstPortsOnTCP()) + if conn.Equal(statefulCombinedConn) { + conn.IsStateful = StatefulTrue + } else { + conn.IsStateful = StatefulFalse + } + return statefulCombinedConn +} + +// switchSrcDstPortsOnTCP returns a new Set object, built from the input Set object. +// It assumes the input connection object is only within TCP protocol. +// For TCP the src and dst ports on relevant cubes are being switched. +func (conn *Set) switchSrcDstPortsOnTCP() *Set { + if conn.AllowAll || conn.IsEmpty() { + return conn.Copy() + } + res := None() + for _, cube := range conn.connectionProperties.GetCubesList() { + // assuming cube[protocol] contains TCP only + // no need to switch if src equals dst + if !cube[srcPort].Equal(cube[dstPort]) { + cube = hypercube.CopyCube(cube) + cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort] + } + res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) + } + return res +} diff --git a/pkg/connection/statefulness_test.go b/pkg/connection/statefulness_test.go new file mode 100644 index 0000000..3c5b111 --- /dev/null +++ b/pkg/connection/statefulness_test.go @@ -0,0 +1,148 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/np-guard/models/pkg/connection" + "github.com/np-guard/models/pkg/netp" +) + +func newTCPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(netp.ProtocolStringTCP, srcMinP, srcMaxP, dstMinP, dstMaxP) +} + +func newUDPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(netp.ProtocolStringUDP, srcMinP, srcMaxP, dstMinP, dstMaxP) +} + +func newICMPconn(t *testing.T) *connection.Set { + t.Helper() + return connection.ICMPConnection( + connection.MinICMPtype, connection.MaxICMPtype, + connection.MinICMPcode, connection.MaxICMPcode) +} + +func newTCPUDPSet(t *testing.T, p netp.ProtocolString) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(p, + connection.MinPort, connection.MaxPort, + connection.MinPort, connection.MaxPort) +} + +type statefulnessTest struct { + name string + srcToDst *connection.Set + dstToSrc *connection.Set + // expectedIsStateful represents the expected IsStateful computed value for srcToDst, + // which should be either StatefulTrue or StatefulFalse, given the input dstToSrc connection. + // the computation applies only to the TCP protocol within those connections. + expectedIsStateful connection.StatefulState + // expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection, + // the stateless part for TCP is srcToDst.Subtract(statefulConn) + expectedStatefulConn *connection.Set +} + +func (tt statefulnessTest) runTest(t *testing.T) { + t.Helper() + statefulConn := tt.srcToDst.ConnectionWithStatefulness(tt.dstToSrc) + require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) + require.True(t, tt.expectedStatefulConn.Equal(statefulConn)) +} + +func TestAll(t *testing.T) { + var testCasesStatefulness = []statefulnessTest{ + { + name: "tcp_all_ports_on_both_directions", + srcToDst: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + dstToSrc: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + }, + { + name: "first_all_cons_second_tcp_with_ports", + srcToDst: connection.All(), // all connections + dstToSrc: newTCPConn(t, 80, 80, connection.MinPort, connection.MaxPort), // TCP , src-ports: 80, dst-ports: all + + // there is a subset of the tcp connection which is not stateful + expectedIsStateful: connection.StatefulFalse, + + // TCP src-ports: all, dst-port: 80 , union: all non-TCP conns + expectedStatefulConn: connection.All().Subtract(newTCPUDPSet(t, netp.ProtocolStringTCP)).Union( + newTCPConn(t, connection.MinPort, connection.MaxPort, 80, 80)), + }, + { + name: "first_all_conns_second_no_tcp", + srcToDst: connection.All(), // all connections + dstToSrc: newICMPconn(t), // ICMP + expectedIsStateful: connection.StatefulFalse, + // UDP, ICMP (all TCP is considered stateless here) + expectedStatefulConn: connection.All().Subtract(newTCPUDPSet(t, netp.ProtocolStringTCP)), + }, + { + name: "tcp_with_ports_both_directions_exact_match", + srcToDst: newTCPConn(t, 80, 80, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_partial_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: connection.None(), + }, + { + name: "udp_and_tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443).Union(newUDPConn(t, 80, 100, 443, 443)), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "no_tcp_in_first_direction", + srcToDst: newUDPConn(t, 70, 100, 443, 443), + dstToSrc: newTCPConn(t, 70, 80, 80, 80).Union(newUDPConn(t, 70, 80, 80, 80)), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newUDPConn(t, 70, 100, 443, 443), + }, + { + name: "empty_conn_in_first_direction", + srcToDst: connection.None(), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newTCPUDPSet(t, netp.ProtocolStringUDP)), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: connection.None(), + }, + { + name: "only_udp_icmp_in_first_direction_and_empty_second_direction", + srcToDst: newTCPUDPSet(t, netp.ProtocolStringUDP).Union(newICMPconn(t)), + dstToSrc: connection.None(), + // stateful analysis does not apply to udp/icmp, thus considered in the result as "stateful" + // (to avoid marking it as stateless in the output) + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPUDPSet(t, netp.ProtocolStringUDP).Union(newICMPconn(t)), + }, + } + t.Parallel() + // explainTests is the list of tests to run + for testIdx := range testCasesStatefulness { + tt := testCasesStatefulness[testIdx] + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.runTest(t) + }) + } +} diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index cac9943..d1a0139 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package hypercube import ( @@ -53,40 +55,39 @@ func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) remainingFromOther := map[*interval.CanonicalSet]*interval.CanonicalSet{} - for k := range other.layers { - kCopy := k.Copy() - remainingFromOther[k] = &kCopy + for otherKey := range other.layers { + remainingFromOther[otherKey] = otherKey.Copy() } + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { - remainingFromSelf := copyIntervalSet(k) + remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonElem := copyIntervalSet(k) - commonElem.Intersect(*otherKey) + commonElem := k.Intersect(otherKey) if commonElem.IsEmpty() { continue } - remainingFromOther[otherKey].Subtract(*commonElem) - remainingFromSelf.Subtract(*commonElem) - if c.dimensions == 1 { - res.layers[commonElem] = NewCanonicalSet(0) - continue + remainingFromOther[otherKey] = remainingFromOther[otherKey].Subtract(commonElem) + remainingFromSelf = remainingFromSelf.Subtract(commonElem) + newSubElem := NewCanonicalSet(0) + if c.dimensions != 1 { + newSubElem = v.Union(otherVal) } - newSubElem := v.Union(otherVal) - res.layers[commonElem] = newSubElem + layers[commonElem] = newSubElem } if !remainingFromSelf.IsEmpty() { - res.layers[remainingFromSelf] = v.Copy() + layers[remainingFromSelf] = v.Copy() } } for k, v := range remainingFromOther { if !v.IsEmpty() { - res.layers[v] = other.layers[k].Copy() + layers[v] = other.layers[k].Copy() } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } // IsEmpty returns true if c is empty @@ -99,26 +100,28 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) + + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { for otherKey, otherVal := range other.layers { - commonELem := copyIntervalSet(k) - commonELem.Intersect(*otherKey) + commonELem := k.Intersect(otherKey) if commonELem.IsEmpty() { continue } if c.dimensions == 1 { - res.layers[commonELem] = NewCanonicalSet(0) + layers[commonELem] = NewCanonicalSet(0) continue } newSubElem := v.Intersect(otherVal) if !newSubElem.IsEmpty() { - res.layers[commonELem] = newSubElem + layers[commonELem] = newSubElem } } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } // Subtract returns a new CanonicalSet object that results from subtraction other from c @@ -126,41 +129,42 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { if c.dimensions != other.dimensions { return nil } - res := NewCanonicalSet(c.dimensions) + layers := map[*interval.CanonicalSet]*CanonicalSet{} for k, v := range c.layers { - remainingFromSelf := copyIntervalSet(k) + remainingFromSelf := k.Copy() for otherKey, otherVal := range other.layers { - commonELem := copyIntervalSet(k) - commonELem.Intersect(*otherKey) - if commonELem.IsEmpty() { + commonElem := k.Intersect(otherKey) + if commonElem.IsEmpty() { continue } - remainingFromSelf.Subtract(*commonELem) + remainingFromSelf = remainingFromSelf.Subtract(commonElem) if c.dimensions == 1 { continue } newSubElem := v.Subtract(otherVal) if !newSubElem.IsEmpty() { - res.layers[commonELem] = newSubElem + layers[commonElem] = newSubElem } } if !remainingFromSelf.IsEmpty() { - res.layers[remainingFromSelf] = v.Copy() + layers[remainingFromSelf] = v.Copy() } } - res.applyElementsUnionPerLayer() - return res + return &CanonicalSet{ + layers: getElementsUnionPerLayer(layers), + dimensions: c.dimensions, + } } func (c *CanonicalSet) getIntervalSetUnion() *interval.CanonicalSet { res := interval.NewCanonicalIntervalSet() for k := range c.layers { - res.Union(*k) + res = res.Union(k) } return res } -// ContainedIn returns true ic other contained in c +// ContainedIn returns true if c is subset of other func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { if c.dimensions != other.dimensions { return false, errors.New("ContainedIn mismatch between num of dimensions for input args") @@ -171,17 +175,14 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { } cInterval := c.getIntervalSetUnion() otherInterval := other.getIntervalSetUnion() - return cInterval.ContainedIn(*otherInterval), nil + return cInterval.ContainedIn(otherInterval), nil } isSubsetCount := 0 - for k, v := range c.layers { - currentLayer := copyIntervalSet(k) + for currentLayer, v := range c.layers { for otherKey, otherVal := range other.layers { - commonKey := copyIntervalSet(currentLayer) - commonKey.Intersect(*otherKey) - remaining := copyIntervalSet(currentLayer) - remaining.Subtract(*commonKey) + commonKey := currentLayer.Intersect(otherKey) + remaining := currentLayer.Subtract(commonKey) if !commonKey.IsEmpty() { subContainment, err := v.ContainedIn(otherVal) if !subContainment || err != nil { @@ -203,8 +204,7 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { func (c *CanonicalSet) Copy() *CanonicalSet { res := NewCanonicalSet(c.dimensions) for k, v := range c.layers { - newKey := k.Copy() - res.layers[&newKey] = v.Copy() + res.layers[k.Copy()] = v.Copy() } return res } @@ -248,13 +248,13 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { return res } -func (c *CanonicalSet) applyElementsUnionPerLayer() { +func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet { type pair struct { hc *CanonicalSet // hypercube set object is []*interval.CanonicalSet // interval-set list } equivClasses := map[string]*pair{} - for k, v := range c.layers { + for k, v := range layers { if _, ok := equivClasses[v.String()]; ok { equivClasses[v.String()].is = append(equivClasses[v.String()].is, k) } else { @@ -266,11 +266,11 @@ func (c *CanonicalSet) applyElementsUnionPerLayer() { newVal := p.hc newKey := p.is[0] for i := 1; i < len(p.is); i += 1 { - newKey.Union(*p.is[i]) + newKey = newKey.Union(p.is[i]) } newLayers[newKey] = newVal } - c.layers = newLayers + return newLayers } // FromCube returns a new CanonicalSet created from a single input cube @@ -281,13 +281,11 @@ func FromCube(cube []*interval.CanonicalSet) *CanonicalSet { } if len(cube) == 1 { res := NewCanonicalSet(1) - cubeVal := cube[0].Copy() - res.layers[&cubeVal] = NewCanonicalSet(0) + res.layers[cube[0].Copy()] = NewCanonicalSet(0) return res } res := NewCanonicalSet(len(cube)) - cubeVal := cube[0].Copy() - res.layers[&cubeVal] = FromCube(cube[1:]) + res.layers[cube[0].Copy()] = FromCube(cube[1:]) return res } @@ -297,12 +295,16 @@ func FromCube(cube []*interval.CanonicalSet) *CanonicalSet { func FromCubeShort(values ...int64) *CanonicalSet { cube := []*interval.CanonicalSet{} for i := 0; i < len(values); i += 2 { - cube = append(cube, interval.FromInterval(values[i], values[i+1])) + cube = append(cube, interval.CreateSetFromInterval(values[i], values[i+1])) } return FromCube(cube) } -func copyIntervalSet(a *interval.CanonicalSet) *interval.CanonicalSet { - res := a.Copy() - return &res +// CopyCube returns a new slice of intervals copied from input cube +func CopyCube(cube []*interval.CanonicalSet) []*interval.CanonicalSet { + newCube := make([]*interval.CanonicalSet, len(cube)) + for i, intervalSet := range cube { + newCube[i] = intervalSet.Copy() + } + return newCube } diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index 22c9404..454d58f 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -1,192 +1,120 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package hypercube_test import ( "fmt" "testing" + "github.com/stretchr/testify/require" + "github.com/np-guard/models/pkg/hypercube" - "github.com/np-guard/models/pkg/interval" ) func TestHCBasic(t *testing.T) { - cube1 := []*interval.CanonicalSet{interval.FromInterval(1, 100)} - cube2 := []*interval.CanonicalSet{interval.FromInterval(1, 100)} - cube3 := []*interval.CanonicalSet{interval.FromInterval(1, 200)} - cube4 := []*interval.CanonicalSet{interval.FromInterval(1, 100), interval.FromInterval(1, 100)} - cube5 := []*interval.CanonicalSet{interval.FromInterval(1, 100), interval.FromInterval(1, 100)} - cube6 := []*interval.CanonicalSet{interval.FromInterval(1, 100), interval.FromInterval(1, 200)} - - a := hypercube.FromCube(cube1) - b := hypercube.FromCube(cube2) - c := hypercube.FromCube(cube3) - d := hypercube.FromCube(cube4) - e := hypercube.FromCube(cube5) - f := hypercube.FromCube(cube6) - - if !a.Equal(b) { - t.FailNow() - } - if a.Equal(c) { - t.FailNow() - } - if b.Equal(c) { - t.FailNow() - } - //nolint:all - /*if !c.Equal(c) { - t.FailNow() - } - if !a.Equal(a) { - t.FailNow() - } - if !b.Equal(b) { - t.FailNow() - }*/ - if !d.Equal(e) { - t.FailNow() - } - if !e.Equal(d) { - t.FailNow() - } - if d.Equal(f) { - t.FailNow() - } - if f.Equal(d) { - t.FailNow() - } + a := hypercube.FromCubeShort(1, 100) + b := hypercube.FromCubeShort(1, 100) + c := hypercube.FromCubeShort(1, 200) + d := hypercube.FromCubeShort(1, 100, 1, 100) + e := hypercube.FromCubeShort(1, 100, 1, 100) + f := hypercube.FromCubeShort(1, 100, 1, 200) + + require.True(t, a.Equal(b)) + require.True(t, b.Equal(a)) + + require.False(t, a.Equal(c)) + require.False(t, c.Equal(a)) + + require.False(t, a.Equal(d)) + require.False(t, d.Equal(a)) + + require.True(t, d.Equal(e)) + require.True(t, e.Equal(d)) + + require.False(t, d.Equal(f)) + require.False(t, f.Equal(d)) } func TestCopy(t *testing.T) { - cube1 := []*interval.CanonicalSet{interval.FromInterval(1, 100)} - a := hypercube.FromCube(cube1) + a := hypercube.FromCubeShort(1, 100) b := a.Copy() - if !a.Equal(b) { - t.FailNow() - } - if !b.Equal(a) { - t.FailNow() - } - if a == b { - t.FailNow() - } + require.True(t, a.Equal(b)) + require.True(t, b.Equal(a)) + require.True(t, a != b) } func TestString(t *testing.T) { - cube1 := []*interval.CanonicalSet{interval.FromInterval(1, 100)} - cube2 := []*interval.CanonicalSet{interval.FromInterval(1, 100), interval.FromInterval(1, 100)} - a := hypercube.FromCube(cube1) - b := hypercube.FromCube(cube2) - fmt.Println(a.String()) - fmt.Println(b.String()) - fmt.Println("done") + require.Equal(t, "[(1-3)]", hypercube.FromCubeShort(1, 3).String()) + require.Equal(t, "[(1-3),(2-4)]", hypercube.FromCubeShort(1, 3, 2, 4).String()) } func TestOr(t *testing.T) { - cube1 := []*interval.CanonicalSet{interval.FromInterval(1, 100), interval.FromInterval(1, 100)} - cube2 := []*interval.CanonicalSet{interval.FromInterval(1, 90), interval.FromInterval(1, 200)} - a := hypercube.FromCube(cube1) - b := hypercube.FromCube(cube2) + a := hypercube.FromCubeShort(1, 100, 1, 100) + b := hypercube.FromCubeShort(1, 90, 1, 200) c := a.Union(b) - fmt.Println(a.String()) - fmt.Println(b.String()) - fmt.Println(c.String()) - fmt.Println("done") -} - -func addCube1Dim(o *hypercube.CanonicalSet, start, end int64) *hypercube.CanonicalSet { - cube := []*interval.CanonicalSet{interval.FromInterval(start, end)} - a := hypercube.FromCube(cube) - return o.Union(a) + require.Equal(t, "[(1-90),(1-200)]; [(91-100),(1-100)]", c.String()) } -func addCube2Dim(o *hypercube.CanonicalSet, start1, end1, start2, end2 int64) *hypercube.CanonicalSet { - cube := []*interval.CanonicalSet{interval.FromInterval(start1, end1), interval.FromInterval(start2, end2)} - a := hypercube.FromCube(cube) - return o.Union(a) -} - -func addCube3Dim(o *hypercube.CanonicalSet, s1, e1, s2, e2, s3, e3 int64) *hypercube.CanonicalSet { - cube := []*interval.CanonicalSet{ - interval.FromInterval(s1, e1), - interval.FromInterval(s2, e2), - interval.FromInterval(s3, e3)} - a := hypercube.FromCube(cube) - return o.Union(a) +func addCube(o *hypercube.CanonicalSet, bounds ...int64) *hypercube.CanonicalSet { + return o.Union(hypercube.FromCubeShort(bounds...)) } func TestBasic(t *testing.T) { a := hypercube.NewCanonicalSet(1) - a = addCube1Dim(a, 1, 2) - a = addCube1Dim(a, 5, 6) - a = addCube1Dim(a, 3, 4) - b := hypercube.NewCanonicalSet(1) - b = addCube1Dim(b, 1, 6) - if !a.Equal(b) { - t.FailNow() - } + a = addCube(a, 1, 2) + a = addCube(a, 5, 6) + a = addCube(a, 3, 4) + b := hypercube.FromCubeShort(1, 6) + require.True(t, a.Equal(b)) } func TestBasic2(t *testing.T) { a := hypercube.NewCanonicalSet(2) - a = addCube2Dim(a, 1, 2, 1, 5) - a = addCube2Dim(a, 1, 2, 7, 9) - a = addCube2Dim(a, 1, 2, 6, 7) + a = addCube(a, 1, 2, 1, 5) + a = addCube(a, 1, 2, 7, 9) + a = addCube(a, 1, 2, 6, 7) b := hypercube.NewCanonicalSet(2) - b = addCube2Dim(b, 1, 2, 1, 9) - if !a.Equal(b) { - t.FailNow() - } + b = addCube(b, 1, 2, 1, 9) + require.True(t, a.Equal(b)) } func TestNew(t *testing.T) { a := hypercube.NewCanonicalSet(3) - a = addCube3Dim(a, 10, 20, 10, 20, 1, 65535) - a = addCube3Dim(a, 1, 65535, 15, 40, 1, 65535) - a = addCube3Dim(a, 1, 65535, 100, 200, 30, 80) - expectedStr := "[(1-9,21-65535),(100-200),(30-80)]; [(1-9,21-65535),(15-40),(1-65535)]" - expectedStr += "; [(10-20),(10-40),(1-65535)]; [(10-20),(100-200),(30-80)]" - actualStr := a.String() - if actualStr != expectedStr { - t.FailNow() - } - fmt.Println(a.String()) - fmt.Println("done") + a = addCube(a, 10, 20, 10, 20, 1, 65535) + a = addCube(a, 1, 65535, 15, 40, 1, 65535) + a = addCube(a, 1, 65535, 100, 200, 30, 80) + expectedStr := "[(1-9,21-65535),(100-200),(30-80)]; " + + "[(1-9,21-65535),(15-40),(1-65535)]; " + + "[(10-20),(10-40),(1-65535)]; " + + "[(10-20),(100-200),(30-80)]" + require.Equal(t, expectedStr, a.String()) } func checkContained(t *testing.T, a, b *hypercube.CanonicalSet, expected bool) { t.Helper() contained, err := a.ContainedIn(b) - if contained != expected || err != nil { - t.FailNow() - } -} - -func checkEqual(t *testing.T, a, b *hypercube.CanonicalSet, expected bool) { - t.Helper() - res := a.Equal(b) - if res != expected { - t.FailNow() - } + require.Nil(t, err) + require.Equal(t, expected, contained) } func TestContainedIn(t *testing.T) { a := hypercube.FromCubeShort(1, 100, 200, 300) b := hypercube.FromCubeShort(10, 80, 210, 280) checkContained(t, b, a, true) - b = addCube2Dim(b, 10, 200, 210, 280) + b = addCube(b, 10, 200, 210, 280) checkContained(t, b, a, false) } func TestContainedIn2(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) - c = addCube2Dim(c, 150, 180, 20, 300) - c = addCube2Dim(c, 200, 240, 200, 300) - c = addCube2Dim(c, 241, 300, 200, 350) + c = addCube(c, 150, 180, 20, 300) + c = addCube(c, 200, 240, 200, 300) + c = addCube(c, 241, 300, 200, 350) a := hypercube.FromCubeShort(1, 100, 200, 300) - a = addCube2Dim(a, 150, 180, 20, 300) - a = addCube2Dim(a, 200, 240, 200, 300) - a = addCube2Dim(a, 242, 300, 200, 350) + a = addCube(a, 150, 180, 20, 300) + a = addCube(a, 200, 240, 200, 300) + a = addCube(a, 242, 300, 200, 350) d := hypercube.FromCubeShort(210, 220, 210, 280) e := hypercube.FromCubeShort(210, 310, 210, 280) @@ -205,9 +133,9 @@ func TestContainedIn2(t *testing.T) { func TestContainedIn3(t *testing.T) { a := hypercube.FromCubeShort(105, 105, 54, 54) b := hypercube.FromCubeShort(0, 204, 0, 255) - b = addCube2Dim(b, 205, 205, 0, 53) - b = addCube2Dim(b, 205, 205, 55, 255) - b = addCube2Dim(b, 206, 254, 0, 255) + b = addCube(b, 205, 205, 0, 53) + b = addCube(b, 205, 205, 55, 255) + b = addCube(b, 206, 254, 0, 255) checkContained(t, a, b, true) } @@ -223,42 +151,42 @@ func TestContainedIn5(t *testing.T) { checkContained(t, b, a, false) } -func TestEqual(t *testing.T) { +func TestEquals(t *testing.T) { a := hypercube.FromCubeShort(1, 2) b := hypercube.FromCubeShort(1, 2) - checkEqual(t, a, b, true) + require.True(t, a.Equal(b)) c := hypercube.FromCubeShort(1, 2, 1, 5) d := hypercube.FromCubeShort(1, 2, 1, 5) - checkEqual(t, c, d, true) - c = addCube2Dim(c, 1, 2, 7, 9) - c = addCube2Dim(c, 1, 2, 6, 7) - c = addCube2Dim(c, 4, 8, 1, 9) + require.True(t, c.Equal(d)) + c = addCube(c, 1, 2, 7, 9) + c = addCube(c, 1, 2, 6, 7) + c = addCube(c, 4, 8, 1, 9) res := hypercube.FromCubeShort(4, 8, 1, 9) - res = addCube2Dim(res, 1, 2, 1, 9) - checkEqual(t, res, c, true) + res = addCube(res, 1, 2, 1, 9) + require.True(t, res.Equal(c)) - a = addCube1Dim(a, 5, 6) - a = addCube1Dim(a, 3, 4) + a = addCube(a, 5, 6) + a = addCube(a, 3, 4) res1 := hypercube.FromCubeShort(1, 6) - checkEqual(t, res1, a, true) + require.True(t, res1.Equal(a)) - d = addCube2Dim(d, 1, 2, 1, 5) - d = addCube2Dim(d, 5, 6, 1, 5) - d = addCube2Dim(d, 3, 4, 1, 5) + d = addCube(d, 1, 2, 1, 5) + d = addCube(d, 5, 6, 1, 5) + d = addCube(d, 3, 4, 1, 5) res2 := hypercube.FromCubeShort(1, 6, 1, 5) - checkEqual(t, res2, d, true) + require.True(t, res2.Equal(d)) } func TestBasicAddCube(t *testing.T) { a := hypercube.FromCubeShort(1, 2) - a = addCube1Dim(a, 8, 10) + a = addCube(a, 8, 10) b := a - a = addCube1Dim(a, 1, 2) - a = addCube1Dim(a, 6, 10) - a = addCube1Dim(a, 1, 10) + a = addCube(a, 1, 2) + a = addCube(a, 6, 10) + a = addCube(a, 1, 10) res := hypercube.FromCubeShort(1, 10) - checkEqual(t, res, a, true) - checkEqual(t, res, b, false) + require.True(t, res.Equal(a)) + require.NotEqual(t, res, b) } func TestBasicAddHole(t *testing.T) { a := hypercube.FromCubeShort(1, 10) @@ -268,68 +196,77 @@ func TestBasicAddHole(t *testing.T) { e := a.Subtract(hypercube.FromCubeShort(12, 14)) a = a.Subtract(hypercube.FromCubeShort(3, 7)) f := hypercube.FromCubeShort(1, 2) - f = addCube1Dim(f, 8, 10) - checkEqual(t, a, f, true) - checkEqual(t, b, hypercube.FromCubeShort(1, 2), true) - checkEqual(t, c, hypercube.NewCanonicalSet(1), true) - checkEqual(t, d, hypercube.FromCubeShort(6, 10), true) - checkEqual(t, e, hypercube.FromCubeShort(1, 10), true) + f = addCube(f, 8, 10) + require.True(t, a.Equal(f)) + require.True(t, b.Equal(hypercube.FromCubeShort(1, 2))) + require.True(t, c.Equal(hypercube.NewCanonicalSet(1))) + require.True(t, d.Equal(hypercube.FromCubeShort(6, 10))) + require.True(t, e.Equal(hypercube.FromCubeShort(1, 10))) } -func TestAddHoleBasic2(t *testing.T) { - a := hypercube.FromCubeShort(1, 100, 200, 300) - b := a.Copy() - c := a.Copy() - a = a.Subtract(hypercube.FromCubeShort(50, 60, 220, 300)) +func TestAddHoleBasic20(t *testing.T) { + a := hypercube.FromCubeShort(1, 100, 200, 300).Subtract(hypercube.FromCubeShort(50, 60, 220, 300)) resA := hypercube.FromCubeShort(61, 100, 200, 300) - resA = addCube2Dim(resA, 50, 60, 200, 219) - resA = addCube2Dim(resA, 1, 49, 200, 300) - checkEqual(t, a, resA, true) + resA = addCube(resA, 50, 60, 200, 219) + resA = addCube(resA, 1, 49, 200, 300) + require.True(t, a.Equal(resA), fmt.Sprintf("%v != %v", a, resA)) +} - b = b.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) +func TestAddHoleBasic21(t *testing.T) { + b := hypercube.FromCubeShort(1, 100, 200, 300).Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) resB := hypercube.FromCubeShort(50, 100, 251, 300) - resB = addCube2Dim(resB, 1, 49, 200, 300) - checkEqual(t, b, resB, true) - - c = addCube2Dim(c, 400, 700, 200, 300) - c = c.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) - resC := hypercube.FromCubeShort(50, 100, 251, 300) - resC = addCube2Dim(resC, 1, 49, 200, 300) - resC = addCube2Dim(resC, 400, 700, 251, 300) - checkEqual(t, c, resC, true) + resB = addCube(resB, 1, 49, 200, 300) + require.True(t, b.Equal(resB), fmt.Sprintf("%v != %v", b, resB)) +} + +func TestAddHoleBasic22(t *testing.T) { + a := hypercube.FromCubeShort(1, 2, 1, 2) + require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(hypercube.FromCubeShort(1, 1, 1, 1)).String()) + require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(hypercube.FromCubeShort(1, 1, 2, 2)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(hypercube.FromCubeShort(2, 2, 1, 1)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(hypercube.FromCubeShort(2, 2, 2, 2)).String()) +} + +func TestAddHoleBasic23(t *testing.T) { + a := hypercube.FromCubeShort(1, 100, 200, 300) + a = addCube(a, 400, 700, 200, 300) + require.Equal(t, "[(1-100,400-700),(200-300)]", a.String()) + a = a.Subtract(hypercube.FromCubeShort(50, 1000, 0, 250)) + require.Equal(t, "[(1-49),(200-300)]; [(50-100,400-700),(251-300)]", a.String()) } func TestAddHole(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) c = c.Subtract(hypercube.FromCubeShort(50, 60, 220, 300)) d := hypercube.FromCubeShort(1, 49, 200, 300) - d = addCube2Dim(d, 50, 60, 200, 219) - d = addCube2Dim(d, 61, 100, 200, 300) - checkEqual(t, c, d, true) + d = addCube(d, 50, 60, 200, 219) + d = addCube(d, 61, 100, 200, 300) + require.True(t, c.Equal(d)) } func TestAddHole2(t *testing.T) { c := hypercube.FromCubeShort(80, 100, 20, 300) - c = addCube2Dim(c, 250, 400, 20, 300) + c = addCube(c, 250, 400, 20, 300) c = c.Subtract(hypercube.FromCubeShort(30, 300, 100, 102)) d := hypercube.FromCubeShort(80, 100, 20, 99) - d = addCube2Dim(d, 80, 100, 103, 300) - d = addCube2Dim(d, 250, 300, 20, 99) - d = addCube2Dim(d, 250, 300, 103, 300) - d = addCube2Dim(d, 301, 400, 20, 300) - checkEqual(t, c, d, true) + d = addCube(d, 80, 100, 103, 300) + d = addCube(d, 250, 300, 20, 99) + d = addCube(d, 250, 300, 103, 300) + d = addCube(d, 301, 400, 20, 300) + require.True(t, c.Equal(d)) } + func TestAddHole3(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) c = c.Subtract(hypercube.FromCubeShort(1, 100, 200, 300)) - checkEqual(t, c, hypercube.NewCanonicalSet(2), true) + require.Equal(t, c, hypercube.NewCanonicalSet(2)) } func TestIntervalsUnion(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) - c = addCube2Dim(c, 101, 200, 200, 300) + c = addCube(c, 101, 200, 200, 300) d := hypercube.FromCubeShort(1, 200, 200, 300) - checkEqual(t, c, d, true) + require.True(t, c.Equal(d)) if c.String() != d.String() { t.FailNow() } @@ -337,23 +274,23 @@ func TestIntervalsUnion(t *testing.T) { func TestIntervalsUnion2(t *testing.T) { c := hypercube.FromCubeShort(1, 100, 200, 300) - c = addCube2Dim(c, 101, 200, 200, 300) - c = addCube2Dim(c, 201, 300, 200, 300) - c = addCube2Dim(c, 301, 400, 200, 300) - c = addCube2Dim(c, 402, 500, 200, 300) - c = addCube2Dim(c, 500, 600, 200, 700) - c = addCube2Dim(c, 601, 700, 200, 700) + c = addCube(c, 101, 200, 200, 300) + c = addCube(c, 201, 300, 200, 300) + c = addCube(c, 301, 400, 200, 300) + c = addCube(c, 402, 500, 200, 300) + c = addCube(c, 500, 600, 200, 700) + c = addCube(c, 601, 700, 200, 700) d := c.Copy() - d = addCube2Dim(d, 702, 800, 200, 700) + d = addCube(d, 702, 800, 200, 700) cExpected := hypercube.FromCubeShort(1, 400, 200, 300) - cExpected = addCube2Dim(cExpected, 402, 500, 200, 300) - cExpected = addCube2Dim(cExpected, 500, 700, 200, 700) + cExpected = addCube(cExpected, 402, 500, 200, 300) + cExpected = addCube(cExpected, 500, 700, 200, 700) dExpected := cExpected.Copy() - dExpected = addCube2Dim(dExpected, 702, 800, 200, 700) - checkEqual(t, c, cExpected, true) - checkEqual(t, d, dExpected, true) + dExpected = addCube(dExpected, 702, 800, 200, 700) + require.True(t, c.Equal(cExpected)) + require.True(t, d.Equal(dExpected)) } func TestAndSubOr(t *testing.T) { @@ -362,26 +299,26 @@ func TestAndSubOr(t *testing.T) { c := a.Intersect(b) d := hypercube.FromCubeShort(8, 15, 7, 10) - checkEqual(t, c, d, true) + require.True(t, c.Equal(d)) f := a.Union(b) e := hypercube.FromCubeShort(5, 15, 3, 6) - e = addCube2Dim(e, 5, 30, 7, 10) - e = addCube2Dim(e, 8, 30, 11, 20) - checkEqual(t, e, f, true) + e = addCube(e, 5, 30, 7, 10) + e = addCube(e, 8, 30, 11, 20) + require.True(t, e.Equal(f)) g := a.Subtract(b) h := hypercube.FromCubeShort(5, 7, 3, 10) - h = addCube2Dim(h, 8, 15, 3, 6) - checkEqual(t, g, h, true) + h = addCube(h, 8, 15, 3, 6) + require.True(t, g.Equal(h)) } func TestAnd2(t *testing.T) { a := hypercube.FromCubeShort(5, 15, 3, 10) b := hypercube.FromCubeShort(1, 3, 7, 20) - b = addCube2Dim(b, 20, 23, 7, 20) + b = addCube(b, 20, 23, 7, 20) c := a.Intersect(b) - checkEqual(t, c, hypercube.NewCanonicalSet(2), true) + require.Equal(t, c, hypercube.NewCanonicalSet(2)) } func TestOr2(t *testing.T) { @@ -389,7 +326,7 @@ func TestOr2(t *testing.T) { b := hypercube.FromCubeShort(1, 65535, 10054, 10054) a = a.Union(b) expected := hypercube.FromCubeShort(1, 79, 10054, 10054) - expected = addCube2Dim(expected, 80, 100, 10053, 10054) - expected = addCube2Dim(expected, 101, 65535, 10054, 10054) - checkEqual(t, a, expected, true) + expected = addCube(expected, 80, 100, 10053, 10054) + expected = addCube(expected, 101, 65535, 10054, 10054) + require.True(t, a.Equal(expected)) } diff --git a/pkg/interval/interval.go b/pkg/interval/interval.go index 031f642..facdff5 100644 --- a/pkg/interval/interval.go +++ b/pkg/interval/interval.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval import "fmt" @@ -9,27 +11,31 @@ type Interval struct { } // String returns a String representation of Interval object -func (i *Interval) String() string { +func (i Interval) String() string { return fmt.Sprintf("[%v-%v]", i.Start, i.End) } // Equal returns true if current Interval obj is equal to the input Interval -func (i *Interval) Equal(x Interval) bool { +func (i Interval) Equal(x Interval) bool { return i.Start == x.Start && i.End == x.End } -func (i *Interval) overlaps(other Interval) bool { +func (i Interval) Size() int64 { + return i.End - i.Start + 1 +} + +func (i Interval) overlaps(other Interval) bool { return other.End >= i.Start && other.Start <= i.End } -func (i *Interval) isSubset(other Interval) bool { +func (i Interval) isSubset(other Interval) bool { return other.Start <= i.Start && other.End >= i.End } // returns a list with up to 2 intervals -func (i *Interval) subtract(other Interval) []Interval { +func (i Interval) subtract(other Interval) []Interval { if !i.overlaps(other) { - return []Interval{*i} + return []Interval{i} } if i.isSubset(other) { return []Interval{} @@ -44,7 +50,7 @@ func (i *Interval) subtract(other Interval) []Interval { return []Interval{{Start: max(i.Start, other.End+1), End: i.End}} } -func (i *Interval) intersection(other Interval) []Interval { +func (i Interval) intersection(other Interval) []Interval { maxStart := max(i.Start, other.Start) minEnd := min(i.End, other.End) if minEnd < maxStart { diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index 5d52f0e..a4b1b81 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -1,34 +1,45 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval import ( + "errors" "fmt" "slices" "sort" ) -// CanonicalSet is a canonical representation of a set of Interval objects +// CanonicalSet is a set of int64 integers, implemented using an ordered slice of non-overlapping, non-touching interval type CanonicalSet struct { - IntervalSet []Interval // sorted list of non-overlapping intervals + intervalSet []Interval } func NewCanonicalIntervalSet() *CanonicalSet { return &CanonicalSet{ - IntervalSet: []Interval{}, + intervalSet: []Interval{}, } } // IsEmpty returns true if the CanonicalSet is empty func (c *CanonicalSet) IsEmpty() bool { - return len(c.IntervalSet) == 0 + return len(c.intervalSet) == 0 +} + +func (c *CanonicalSet) CalculateSize() int64 { + var res int64 = 0 + for _, r := range c.intervalSet { + res += r.Size() + } + return res } // Equal returns true if the CanonicalSet equals the input CanonicalSet -func (c *CanonicalSet) Equal(other CanonicalSet) bool { - if len(c.IntervalSet) != len(other.IntervalSet) { +func (c *CanonicalSet) Equal(other *CanonicalSet) bool { + if len(c.intervalSet) != len(other.intervalSet) { return false } - for index := range c.IntervalSet { - if !(c.IntervalSet[index].Equal(other.IntervalSet[index])) { + for index := range c.intervalSet { + if !(c.intervalSet[index].Equal(other.intervalSet[index])) { return false } } @@ -37,7 +48,7 @@ func (c *CanonicalSet) Equal(other CanonicalSet) bool { // AddInterval adds a new interval range to the set func (c *CanonicalSet) AddInterval(v Interval) { - set := c.IntervalSet + set := c.intervalSet left := sort.Search(len(set), func(i int) bool { return set[i].End >= v.Start-1 }) @@ -50,20 +61,7 @@ func (c *CanonicalSet) AddInterval(v Interval) { if right > 0 && set[right-1].End >= v.Start { v.End = max(v.End, set[right-1].End) } - c.IntervalSet = slices.Replace(c.IntervalSet, left, right, v) -} - -// AddHole updates the current CanonicalSet object by removing the input Interval from the set -func (c *CanonicalSet) AddHole(hole Interval) { - newIntervalSet := []Interval{} - for _, interval := range c.IntervalSet { - newIntervalSet = append(newIntervalSet, interval.subtract(hole)...) - } - c.IntervalSet = newIntervalSet -} - -func getNumAsStr(num int64) string { - return fmt.Sprintf("%v", num) + c.intervalSet = slices.Replace(c.intervalSet, left, right, v) } // String returns a string representation of the current CanonicalSet object @@ -72,37 +70,40 @@ func (c *CanonicalSet) String() string { return "Empty" } res := "" - for _, interval := range c.IntervalSet { - res += getNumAsStr(interval.Start) + for _, interval := range c.intervalSet { if interval.Start != interval.End { - res += "-" + getNumAsStr(interval.End) + res += fmt.Sprintf("%v-%v", interval.Start, interval.End) + } else { + res += fmt.Sprintf("%v", interval.Start) } res += "," } return res[:len(res)-1] } -// Union updates the CanonicalSet object with the union result of the input CanonicalSet -func (c *CanonicalSet) Union(other CanonicalSet) { - for _, interval := range other.IntervalSet { - c.AddInterval(interval) +// Union returns the union of the two sets +func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { + res := c.Copy() + for _, interval := range other.intervalSet { + res.AddInterval(interval) } + return res } // Copy returns a new copy of the CanonicalSet object -func (c *CanonicalSet) Copy() CanonicalSet { - return CanonicalSet{IntervalSet: append([]Interval(nil), c.IntervalSet...)} +func (c *CanonicalSet) Copy() *CanonicalSet { + return &CanonicalSet{intervalSet: slices.Clone(c.intervalSet)} } func (c *CanonicalSet) Contains(n int64) bool { - i := FromInterval(n, n) - return i.ContainedIn(*c) + i := CreateSetFromInterval(n, n) + return i.ContainedIn(c) } -// ContainedIn returns true of the current CanonicalSet is contained in the input CanonicalSet -func (c *CanonicalSet) ContainedIn(other CanonicalSet) bool { - larger := other.IntervalSet - for _, target := range c.IntervalSet { +// ContainedIn returns true of the current interval.CanonicalSet is contained in the input interval.CanonicalSet +func (c *CanonicalSet) ContainedIn(other *CanonicalSet) bool { + larger := other.intervalSet + for _, target := range c.intervalSet { left := sort.Search(len(larger), func(i int) bool { return larger[i].End >= target.End }) @@ -115,21 +116,21 @@ func (c *CanonicalSet) ContainedIn(other CanonicalSet) bool { return true } -// Intersect updates current CanonicalSet with intersection result of input CanonicalSet -func (c *CanonicalSet) Intersect(other CanonicalSet) { - newIntervalSet := []Interval{} - for _, interval := range c.IntervalSet { - for _, otherInterval := range other.IntervalSet { - newIntervalSet = append(newIntervalSet, interval.intersection(otherInterval)...) +// Intersect returns the intersection of the current set with the input set +func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { + res := NewCanonicalIntervalSet() + for _, interval := range c.intervalSet { + for _, otherInterval := range other.intervalSet { + res.intervalSet = append(res.intervalSet, interval.intersection(otherInterval)...) } } - c.IntervalSet = newIntervalSet + return res } // Overlaps returns true if current CanonicalSet overlaps with input CanonicalSet func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { - for _, selfInterval := range c.IntervalSet { - for _, otherInterval := range other.IntervalSet { + for _, selfInterval := range c.intervalSet { + for _, otherInterval := range other.intervalSet { if selfInterval.overlaps(otherInterval) { return true } @@ -138,20 +139,62 @@ func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { return false } -// Subtract updates current CanonicalSet with subtraction result of input CanonicalSet -func (c *CanonicalSet) Subtract(other CanonicalSet) { - for _, i := range other.IntervalSet { - c.AddHole(i) +// Subtract returns the subtraction result of input CanonicalSet +func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { + res := slices.Clone(c.intervalSet) + for _, hole := range other.intervalSet { + newIntervalSet := []Interval{} + for _, interval := range res { + newIntervalSet = append(newIntervalSet, interval.subtract(hole)...) + } + res = newIntervalSet + } + return &CanonicalSet{ + intervalSet: res, } } func (c *CanonicalSet) IsSingleNumber() bool { - if len(c.IntervalSet) == 1 && c.IntervalSet[0].Start == c.IntervalSet[0].End { + if len(c.intervalSet) == 1 && c.intervalSet[0].Size() == 1 { return true } return false } -func FromInterval(start, end int64) *CanonicalSet { - return &CanonicalSet{IntervalSet: []Interval{{Start: start, End: end}}} +func (c *CanonicalSet) Min() (int64, error) { + if len(c.intervalSet) > 0 { + return c.intervalSet[0].Start, nil + } + return 0, errors.New("empty interval set") +} + +// Split returns a set of canonical set objects, each with a single interval +func (c *CanonicalSet) Split() []*CanonicalSet { + res := make([]*CanonicalSet, len(c.intervalSet)) + for index, ipr := range c.intervalSet { + res[index] = CreateSetFromInterval(ipr.Start, ipr.End) + } + return res +} + +func (c *CanonicalSet) Intervals() []Interval { + return slices.Clone(c.intervalSet) +} + +func (c *CanonicalSet) NumIntervals() int { + return len(c.intervalSet) +} + +func (c *CanonicalSet) Elements() []int64 { + res := []int64{} + for _, interval := range c.intervalSet { + for i := interval.Start; i <= interval.End; i++ { + res = append(res, i) + } + } + return res +} + +func CreateSetFromInterval(start, end int64) *CanonicalSet { + return &CanonicalSet{intervalSet: []Interval{{Start: start, End: end}}} } diff --git a/pkg/interval/intervalset_test.go b/pkg/interval/intervalset_test.go index 3c50ba3..76876dd 100644 --- a/pkg/interval/intervalset_test.go +++ b/pkg/interval/intervalset_test.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval_test import ( @@ -10,6 +12,7 @@ import ( func TestInterval(t *testing.T) { it1 := interval.Interval{3, 7} + require.Equal(t, "[3-7]", it1.String()) } @@ -19,7 +22,7 @@ func TestIntervalSet(t *testing.T) { is1.AddInterval(interval.Interval{0, 1}) is1.AddInterval(interval.Interval{3, 3}) is1.AddInterval(interval.Interval{70, 80}) - is1.AddHole(interval.Interval{7, 9}) + is1 = is1.Subtract(interval.CreateSetFromInterval(7, 9)) require.True(t, is1.Contains(5)) require.False(t, is1.Contains(8)) @@ -28,30 +31,38 @@ func TestIntervalSet(t *testing.T) { is2.AddInterval(interval.Interval{6, 8}) require.Equal(t, "6-8", is2.String()) require.False(t, is2.IsSingleNumber()) - require.False(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) - require.False(t, is2.Equal(*is1)) - require.False(t, is1.Equal(*is2)) + require.False(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) + require.False(t, is2.Equal(is1)) + require.False(t, is1.Equal(is2)) require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is1.Subtract(*is2) - require.False(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) + is1 = is1.Subtract(is2) + require.False(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) require.False(t, is1.Overlaps(is2)) require.False(t, is2.Overlaps(is1)) - is1.Union(*is2) - is1.Union(*interval.FromInterval(7, 9)) - require.True(t, is2.ContainedIn(*is1)) - require.False(t, is1.ContainedIn(*is2)) + is1 = is1.Union(is2).Union(interval.CreateSetFromInterval(7, 9)) + require.True(t, is2.ContainedIn(is1)) + require.False(t, is1.ContainedIn(is2)) require.True(t, is1.Overlaps(is2)) require.True(t, is2.Overlaps(is1)) - is3 := is1.Copy() - is3.Intersect(*is2) - require.True(t, is3.Equal(*is2)) + is3 := is1.Intersect(is2) + require.True(t, is3.Equal(is2)) require.True(t, is2.ContainedIn(is3)) - require.True(t, interval.FromInterval(1, 1).IsSingleNumber()) + require.True(t, interval.CreateSetFromInterval(1, 1).IsSingleNumber()) +} + +func TestIntervalSetSubtract(t *testing.T) { + s := interval.CreateSetFromInterval(1, 100) + s.AddInterval(interval.Interval{Start: 400, End: 700}) + d := *interval.CreateSetFromInterval(50, 100) + d.AddInterval(interval.Interval{Start: 400, End: 700}) + actual := s.Subtract(&d) + expected := interval.CreateSetFromInterval(1, 49) + require.Equal(t, expected.String(), actual.String()) } diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index d25b7a4..d8176fd 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package ipblock import ( @@ -33,7 +35,7 @@ const ( // IPBlock captures a set of IP ranges type IPBlock struct { - ipRange interval.CanonicalSet + ipRange *interval.CanonicalSet } // ToIPRanges returns a string of the ip ranges in the current IPBlock object @@ -50,24 +52,23 @@ func toIPRange(i interval.Interval) string { // toIPRangesList: returns a list of the ip-ranges strings in the current IPBlock object func (b *IPBlock) toIPRangesList() []string { - IPRanges := make([]string, len(b.ipRange.IntervalSet)) - for index := range b.ipRange.IntervalSet { - IPRanges[index] = toIPRange(b.ipRange.IntervalSet[index]) + IPRanges := make([]string, b.ipRange.NumIntervals()) + for index, v := range b.ipRange.Intervals() { + IPRanges[index] = toIPRange(v) } return IPRanges } -// ContainedIn checks if this IP block is contained within another IP block. -func (b *IPBlock) ContainedIn(other *IPBlock) bool { - return b.ipRange.ContainedIn(other.ipRange) +// ContainedIn returns true if the input IPBlock is contained in this IPBlock +func (b *IPBlock) ContainedIn(c *IPBlock) bool { + return b.ipRange.ContainedIn(c.ipRange) } // Intersect returns a new IPBlock from intersection of this IPBlock with input IPBlock func (b *IPBlock) Intersect(c *IPBlock) *IPBlock { - res := &IPBlock{} - res.ipRange = b.ipRange.Copy() - res.ipRange.Intersect(c.ipRange) - return res + return &IPBlock{ + ipRange: b.ipRange.Intersect(c.ipRange), + } } // Equal returns true if this IPBlock equals the input IPBlock @@ -77,18 +78,16 @@ func (b *IPBlock) Equal(c *IPBlock) bool { // Subtract returns a new IPBlock from subtraction of input IPBlock from this IPBlock func (b *IPBlock) Subtract(c *IPBlock) *IPBlock { - res := &IPBlock{} - res.ipRange = b.ipRange.Copy() - res.ipRange.Subtract(c.ipRange) - return res + return &IPBlock{ + ipRange: b.ipRange.Subtract(c.ipRange), + } } // Union returns a new IPBlock from union of input IPBlock with this IPBlock func (b *IPBlock) Union(c *IPBlock) *IPBlock { - res := &IPBlock{} - res.ipRange = b.ipRange.Copy() - res.ipRange.Union(c.ipRange) - return res + return &IPBlock{ + ipRange: b.ipRange.Union(c.ipRange), + } } // Empty returns true if this IPBlock is empty @@ -102,38 +101,31 @@ func rangeIPstr(start, end string) string { // Copy returns a new copy of IPBlock object func (b *IPBlock) Copy() *IPBlock { - res := &IPBlock{} - res.ipRange = b.ipRange.Copy() - return res + return &IPBlock{ + ipRange: b.ipRange.Copy(), + } } func (b *IPBlock) ipCount() int { - res := 0 - for _, r := range b.ipRange.IntervalSet { - res += int(r.End) - int(r.Start) + 1 - } - return res + return int(b.ipRange.CalculateSize()) } // Split returns a set of IpBlock objects, each with a single range of ips func (b *IPBlock) Split() []*IPBlock { - res := make([]*IPBlock, len(b.ipRange.IntervalSet)) - for index, ipr := range b.ipRange.IntervalSet { - newBlock := IPBlock{} - newBlock.ipRange.IntervalSet = append(newBlock.ipRange.IntervalSet, interval.Interval{Start: ipr.Start, End: ipr.End}) - res[index] = &newBlock + res := make([]*IPBlock, b.ipRange.NumIntervals()) + for index, set := range b.ipRange.Split() { + res[index] = &IPBlock{ + ipRange: set, + } } return res } // intToIP4 returns a string of an ip address from an input integer ip value func intToIP4(ipInt int64) string { - // need to do two bit shifting and “0xff” masking - b0 := strconv.FormatInt((ipInt>>ipShift0)&ipByte, ipBase) - b1 := strconv.FormatInt((ipInt>>ipShift1)&ipByte, ipBase) - b2 := strconv.FormatInt((ipInt>>ipShift2)&ipByte, ipBase) - b3 := strconv.FormatInt((ipInt & ipByte), ipBase) - return b0 + "." + b1 + "." + b2 + "." + b3 + var d [4]byte + binary.BigEndian.PutUint32(d[:], uint32(ipInt)) + return net.IPv4(d[0], d[1], d[2], d[3]).String() } // DisjointIPBlocks returns an IPBlock of disjoint ip ranges from 2 input IPBlock objects @@ -150,15 +142,13 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { return ipbList[i].ipCount() < ipbList[j].ipCount() }) // making sure the resulting list does not contain overlapping ipBlocks - blocksWithNoOverlaps := []*IPBlock{} + res := []*IPBlock{} for _, ipb := range ipbList { - blocksWithNoOverlaps = addIntervalToList(ipb, blocksWithNoOverlaps) + res = addIntervalToList(ipb, res) } - res := blocksWithNoOverlaps if len(res) == 0 { - newAll := GetCidrAll() - res = append(res, newAll) + res = append(res, GetCidrAll()) } return res } @@ -167,17 +157,16 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { toAdd := []*IPBlock{} for idx, ipb := range ipbList { - if !ipb.ipRange.Overlaps(&ipbNew.ipRange) { + if !ipb.ipRange.Overlaps(ipbNew.ipRange) { continue } - intersection := ipb.Copy() - intersection.ipRange.Intersect(ipbNew.ipRange) - ipbNew.ipRange.Subtract(intersection.ipRange) - if !ipb.ipRange.Equal(intersection.ipRange) { + intersection := ipb.Intersect(ipbNew) + ipbNew = ipbNew.Subtract(intersection) + if !ipb.Equal(intersection) { toAdd = append(toAdd, intersection) - ipbList[idx].ipRange.Subtract(intersection.ipRange) + ipbList[idx] = ipbList[idx].Subtract(intersection) } - if len(ipbNew.ipRange.IntervalSet) == 0 { + if ipbNew.ipRange.IsEmpty() { break } } @@ -188,7 +177,26 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { // FromCidr returns a new IPBlock object from input CIDR string func FromCidr(cidr string) (*IPBlock, error) { - return FromCidrExcept(cidr, []string{}) + start, end, err := cidrToIPRange(cidr) + if err != nil { + return nil, err + } + return &IPBlock{ + ipRange: interval.CreateSetFromInterval(start, end), + }, nil +} + +// ExceptCidrs returns a new IPBlock with all cidr ranges removed +func (b *IPBlock) ExceptCidrs(cidrExceptions ...string) (*IPBlock, error) { + res := b.Copy() + for i := range cidrExceptions { + hole, err := FromCidr(cidrExceptions[i]) + if err != nil { + return nil, err + } + res = res.Subtract(hole) + } + return res, nil } // PairCIDRsToIPBlocks returns two IPBlock objects from two input CIDR strings @@ -201,7 +209,14 @@ func PairCIDRsToIPBlocks(cidr1, cidr2 string) (ipb1, ipb2 *IPBlock, err error) { return ipb1, ipb2, nil } -// FromCidrOrAddress returns a new IPBlock object from input string of CIDR or IP address +// New returns a new IPBlock object +func New() *IPBlock { + return &IPBlock{ + ipRange: interval.NewCanonicalIntervalSet(), + } +} + +// FromCidr returns a new IPBlock object from input string of CIDR or IP address func FromCidrOrAddress(s string) (*IPBlock, error) { if strings.Contains(s, cidrSeparator) { return FromCidr(s) @@ -211,42 +226,26 @@ func FromCidrOrAddress(s string) (*IPBlock, error) { // FromCidrList returns IPBlock object from multiple CIDRs given as list of strings func FromCidrList(cidrsList []string) (*IPBlock, error) { - res := &IPBlock{ipRange: interval.CanonicalSet{}} + ipRange := interval.NewCanonicalIntervalSet() for _, cidr := range cidrsList { block, err := FromCidr(cidr) if err != nil { return nil, err } - res = res.Union(block) + ipRange = ipRange.Union(block.ipRange) } - return res, nil + return &IPBlock{ipRange: ipRange}, nil } -// FromCidrExcept returns an IPBlock object from input cidr str an exceptions cidr str -func FromCidrExcept(cidr string, exceptions []string) (*IPBlock, error) { - res := IPBlock{ipRange: interval.CanonicalSet{}} - span, err := cidrToInterval(cidr) +// FromIPAddress returns an IPBlock object from input IP address string +func FromIPAddress(ipAddress string) (*IPBlock, error) { + ipNum, err := parseIP(ipAddress) if err != nil { return nil, err } - res.ipRange.AddInterval(*span) - for i := range exceptions { - intervalHole, err := cidrToInterval(exceptions[i]) - if err != nil { - return nil, err - } - res.ipRange.AddHole(*intervalHole) - } - return &res, nil -} - -func ipv4AddressToCidr(ipAddress string) string { - return ipAddress + "/32" -} - -// FromIPAddress returns an IPBlock object from input IP address string -func FromIPAddress(ipAddress string) (*IPBlock, error) { - return FromCidrExcept(ipv4AddressToCidr(ipAddress), []string{}) + return &IPBlock{ + ipRange: interval.CreateSetFromInterval(ipNum, ipNum), + }, nil } func cidrToIPRange(cidr string) (start, end int64, err error) { @@ -267,19 +266,11 @@ func cidrToIPRange(cidr string) (start, end int64, err error) { return } -func cidrToInterval(cidr string) (*interval.Interval, error) { - start, end, err := cidrToIPRange(cidr) - if err != nil { - return nil, err - } - return &interval.Interval{Start: start, End: end}, nil -} - // ToCidrList returns a list of CIDR strings for this IPBlock object func (b *IPBlock) ToCidrList() []string { cidrList := []string{} - for _, interval := range b.ipRange.IntervalSet { - cidrList = append(cidrList, intervalToCidrList(interval.Start, interval.End)...) + for _, ipRange := range b.ipRange.Intervals() { + cidrList = append(cidrList, intervalToCidrList(ipRange)...) } return cidrList } @@ -292,12 +283,12 @@ func (b *IPBlock) ToCidrListString() string { // ListToPrint: returns a uniform to print list s.t. each element contains either a single cidr or an ip range func (b *IPBlock) ListToPrint() []string { cidrsIPRangesList := []string{} - for _, interval := range b.ipRange.IntervalSet { - cidr := intervalToCidrList(interval.Start, interval.End) + for _, ipRange := range b.ipRange.Intervals() { + cidr := intervalToCidrList(ipRange) if len(cidr) == 1 { cidrsIPRangesList = append(cidrsIPRangesList, cidr[0]) } else { - cidrsIPRangesList = append(cidrsIPRangesList, toIPRange(interval)) + cidrsIPRangesList = append(cidrsIPRangesList, toIPRange(ipRange)) } } return cidrsIPRangesList @@ -306,14 +297,15 @@ func (b *IPBlock) ListToPrint() []string { // ToIPAdressString returns the IP Address string for this IPBlock func (b *IPBlock) ToIPAddressString() string { if b.ipRange.IsSingleNumber() { - return intToIP4(b.ipRange.IntervalSet[0].Start) + m, _ := b.ipRange.Min() + return intToIP4(m) } return "" } -func intervalToCidrList(ipStart, ipEnd int64) []string { - start := ipStart - end := ipEnd +func intervalToCidrList(ipRange interval.Interval) []string { + start := ipRange.Start + end := ipRange.End res := []string{} for end >= start { maxSize := maxIPv4Bits @@ -338,25 +330,28 @@ func intervalToCidrList(ipStart, ipEnd int64) []string { return res } +func parseIP(ip string) (int64, error) { + startIP := net.ParseIP(ip) + if startIP == nil { + return 0, fmt.Errorf("%v is not a valid ipv4", ip) + } + return int64(binary.BigEndian.Uint32(startIP.To4())), nil +} + // FromIPRangeStr returns IPBlock object from input IP range string (example: "169.255.0.0-172.15.255.255") func FromIPRangeStr(ipRangeStr string) (*IPBlock, error) { ipAddresses := strings.Split(ipRangeStr, dash) if len(ipAddresses) != 2 { return nil, errors.New("unexpected ipRange str") } - var startIP, endIP *IPBlock - var err error - if startIP, err = FromIPAddress(ipAddresses[0]); err != nil { - return nil, err + startIPNum, err0 := parseIP(ipAddresses[0]) + endIPNum, err1 := parseIP(ipAddresses[1]) + if err0 != nil || err1 != nil { + return nil, errors.Join(err0, err1) } - if endIP, err = FromIPAddress(ipAddresses[1]); err != nil { - return nil, err + res := &IPBlock{ + ipRange: interval.CreateSetFromInterval(startIPNum, endIPNum), } - res := &IPBlock{} - res.ipRange = interval.CanonicalSet{IntervalSet: []interval.Interval{}} - startIPNum := startIP.ipRange.IntervalSet[0].Start - endIPNum := endIP.ipRange.IntervalSet[0].Start - res.ipRange.IntervalSet = append(res.ipRange.IntervalSet, interval.Interval{Start: startIPNum, End: endIPNum}) return res, nil } diff --git a/pkg/ipblock/ipblock_test.go b/pkg/ipblock/ipblock_test.go index 8b8bfc0..e941cc4 100644 --- a/pkg/ipblock/ipblock_test.go +++ b/pkg/ipblock/ipblock_test.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package ipblock_test import ( @@ -21,15 +23,17 @@ func TestOps(t *testing.T) { minus := ipb1.Subtract(ipb2) require.Equal(t, "1.2.3.0-1.2.3.3, 1.2.3.5-1.2.3.255", minus.ToIPRanges()) - minus2, err := ipblock.FromCidrExcept(ipb1.ToCidrListString(), []string{ipb2.ToCidrListString()}) + minus2, err := ipblock.FromCidr(ipb1.ToCidrListString()) + require.Nil(t, err) + minus2, err = minus2.ExceptCidrs(ipb2.ToCidrListString()) require.Nil(t, err) require.Equal(t, minus.ToCidrListString(), minus2.ToCidrListString()) intersect := ipb1.Intersect(ipb2) - require.True(t, intersect.Equal(ipb2)) + require.Equal(t, intersect, ipb2) union := intersect.Union(minus) - require.True(t, union.Equal(ipb1)) + require.Equal(t, union, ipb1) intersect2 := minus.Intersect(intersect) require.True(t, intersect2.Empty()) @@ -107,10 +111,7 @@ func TestPrefixLength(t *testing.T) { } func TestBadPath(t *testing.T) { - _, err := ipblock.FromCidrExcept("not-a-cidr", nil) - require.NotNil(t, err) - - _, err = ipblock.FromCidrExcept("2.5.7.9/24", []string{"5.6.7.8/20", "not-a-cidr"}) + _, err := ipblock.FromCidr("not-a-cidr") require.NotNil(t, err) _, err = ipblock.FromCidrList([]string{"1.2.3.4/20", "not-a-cidr"}) diff --git a/pkg/netp/common.go b/pkg/netp/common.go new file mode 100644 index 0000000..abd2b31 --- /dev/null +++ b/pkg/netp/common.go @@ -0,0 +1,20 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package netp + +type Protocol interface { + // InverseDirection returns the response expected for a request made using this protocol + InverseDirection() Protocol +} + +type AnyProtocol struct{} + +func (t AnyProtocol) InverseDirection() Protocol { return AnyProtocol{} } + +type ProtocolString string + +const ( + ProtocolStringTCP ProtocolString = "TCP" + ProtocolStringUDP ProtocolString = "UDP" + ProtocolStringICMP ProtocolString = "ICMP" +) diff --git a/pkg/netp/icmp.go b/pkg/netp/icmp.go new file mode 100644 index 0000000..f976adc --- /dev/null +++ b/pkg/netp/icmp.go @@ -0,0 +1,128 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package netp + +import ( + "fmt" + "log" +) + +type ICMPTypeCode struct { + // ICMP type allowed. + Type int + + // ICMP code allowed. If omitted, any code is allowed + Code *int +} + +type ICMP struct { + typeCode *ICMPTypeCode +} + +func NewICMP(typeCode *ICMPTypeCode) (ICMP, error) { + err := ValidateICMP(typeCode) + if err != nil { + return ICMP{}, err + } + return ICMP{typeCode: typeCode}, nil +} + +func (t ICMP) ICMPTypeCode() *ICMPTypeCode { + if t.typeCode == nil { + return nil + } + if t.typeCode.Code == nil { + return t.typeCode + } + // avoid aliasing and mutation by someone else + code := *t.typeCode.Code + return &ICMPTypeCode{Type: t.typeCode.Type, Code: &code} +} + +func (t ICMP) InverseDirection() Protocol { + if t.typeCode == nil { + return nil + } + + if invType := inverseICMPType(t.typeCode.Type); invType != undefinedICMP { + return ICMP{typeCode: &ICMPTypeCode{Type: invType, Code: t.typeCode.Code}} + } + return nil +} + +// Based on https://datatracker.ietf.org/doc/html/rfc792 + +const ( + EchoReply = 0 + DestinationUnreachable = 3 + SourceQuench = 4 + Redirect = 5 + Echo = 8 + TimeExceeded = 11 + ParameterProblem = 12 + Timestamp = 13 + TimestampReply = 14 + InformationRequest = 15 + InformationReply = 16 + + undefinedICMP = -2 +) + +// inverseICMPType returns the reply type for request type and vice versa. +// When there is no inverse, returns undefinedICMP +func inverseICMPType(t int) int { + switch t { + case Echo: + return EchoReply + case EchoReply: + return Echo + + case Timestamp: + return TimestampReply + case TimestampReply: + return Timestamp + + case InformationRequest: + return InformationReply + case InformationReply: + return InformationRequest + + case DestinationUnreachable, SourceQuench, Redirect, TimeExceeded, ParameterProblem: + return undefinedICMP + default: + log.Panicf("Impossible ICMP type: %v", t) + } + return undefinedICMP +} + +//nolint:revive // magic numbers are fine here +func ValidateICMP(typeCode *ICMPTypeCode) error { + if typeCode == nil { + return nil + } + maxCodes := map[int]int{ + EchoReply: 0, + DestinationUnreachable: 5, + SourceQuench: 0, + Redirect: 3, + Echo: 0, + TimeExceeded: 1, + ParameterProblem: 0, + Timestamp: 0, + TimestampReply: 0, + InformationRequest: 0, + InformationReply: 0, + } + maxCode, ok := maxCodes[typeCode.Type] + if !ok { + return fmt.Errorf("invalid ICMP type %v", typeCode.Type) + } + if *typeCode.Code > maxCode { + return fmt.Errorf("ICMP code %v is invalid for ICMP type %v", *typeCode.Code, typeCode.Type) + } + return nil +} + +func (t ICMP) ProtocolString() ProtocolString { + return ProtocolStringICMP +} diff --git a/pkg/netp/tcpudp.go b/pkg/netp/tcpudp.go new file mode 100644 index 0000000..29e5c5d --- /dev/null +++ b/pkg/netp/tcpudp.go @@ -0,0 +1,35 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package netp + +import "github.com/np-guard/models/pkg/interval" + +const MinPort = 1 +const MaxPort = 65535 + +type PortRangePair struct { + SrcPort interval.Interval + DstPort interval.Interval +} + +type TCPUDP struct { + IsTCP bool + PortRangePair PortRangePair +} + +func (t TCPUDP) InverseDirection() Protocol { + if !t.IsTCP { + return nil + } + return TCPUDP{ + IsTCP: true, + PortRangePair: PortRangePair{SrcPort: t.PortRangePair.DstPort, DstPort: t.PortRangePair.SrcPort}, + } +} + +func (t TCPUDP) ProtocolString() ProtocolString { + if t.IsTCP { + return ProtocolStringTCP + } + return ProtocolStringUDP +} From 92a9c9d1b0ab133b3d204ca1b1d8b6abb607a9d4 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 17 Mar 2024 14:48:30 +0200 Subject: [PATCH 02/15] Fix merge Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 2 +- pkg/connection/statefulness.go | 5 ++++- pkg/hypercube/hypercubeset.go | 11 +++++++++++ pkg/hypercube/hypercubeset_test.go | 6 +----- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index c83d72e..2d07687 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -197,7 +197,7 @@ func (conn *Set) addConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { code := ProtocolStringToCode(protocol) - cube := hypercube.FromCubeShort(code, code, + cube := hypercube.Cube(code, code, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) conn.connectionProperties = conn.connectionProperties.Union(cube) diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index 3805d8e..3369ca5 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -3,6 +3,8 @@ package connection import ( + "slices" + "github.com/np-guard/models/pkg/hypercube" "github.com/np-guard/models/pkg/netp" ) @@ -75,7 +77,8 @@ func (conn *Set) switchSrcDstPortsOnTCP() *Set { // assuming cube[protocol] contains TCP only // no need to switch if src equals dst if !cube[srcPort].Equal(cube[dstPort]) { - cube = hypercube.CopyCube(cube) + // Shallow clone should be enough, since we do shallow swap: + cube = slices.Clone(cube) cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort] } res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index a242f65..236ff6f 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -300,3 +300,14 @@ func FromCube(cube []*interval.CanonicalSet) *CanonicalSet { res.layers[cube[0].Copy()] = FromCube(cube[1:]) return res } + +// cube returns a new hypercube.CanonicalSet created from a single input cube +// the input cube is given as an ordered list of integer values, where each two values +// represent the range (start,end) for a dimension value +func Cube(values ...int64) *CanonicalSet { + cube := []*interval.CanonicalSet{} + for i := 0; i < len(values); i += 2 { + cube = append(cube, interval.CreateSetFromInterval(values[i], values[i+1])) + } + return FromCube(cube) +} diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index 3b29533..c34a186 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -14,11 +14,7 @@ import ( // the input cube is given as an ordered list of integer values, where each two values // represent the range (start,end) for a dimension value func cube(values ...int64) *hypercube.CanonicalSet { - cube := []*interval.CanonicalSet{} - for i := 0; i < len(values); i += 2 { - cube = append(cube, interval.FromInterval(values[i], values[i+1])) - } - return hypercube.FromCube(cube) + return hypercube.Cube(values...) } func union(set *hypercube.CanonicalSet, sets ...*hypercube.CanonicalSet) *hypercube.CanonicalSet { From 00486c5347371a6a437f5b100cf629ee6337de3c Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Tue, 19 Mar 2024 10:31:14 +0200 Subject: [PATCH 03/15] fix merge Signed-off-by: Elazar Gershuni --- Makefile | 2 +- pkg/connection/connectionset.go | 10 +++--- pkg/hypercube/hypercubeset.go | 6 ++-- pkg/interval/interval.go | 8 +++++ pkg/interval/intervalset.go | 54 +++++++++++------------------ pkg/interval/intervalset_test.go | 28 +++++++-------- pkg/ipblock/ipblock.go | 58 +++++++++++++++----------------- pkg/ipblock/ipblock_test.go | 2 +- pkg/model/data_model.go | 2 +- 9 files changed, 81 insertions(+), 89 deletions(-) diff --git a/Makefile b/Makefile index c02dc03..6fdf75b 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ REPOSITORY := github.com/np-guard/models -JSON_PACKAGE_NAME := model +JSON_PACKAGE_NAME := spec mod: go.mod @echo -- $@ -- diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 2d07687..7e70f0b 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -57,15 +57,15 @@ var dimensionsList = []Dimension{ func entireDimension(dim Dimension) *interval.CanonicalSet { switch dim { case protocol: - return interval.CreateSetFromInterval(minProtocol, maxProtocol) + return interval.New(minProtocol, maxProtocol).ToSet() case srcPort: - return interval.CreateSetFromInterval(MinPort, MaxPort) + return interval.New(MinPort, MaxPort).ToSet() case dstPort: - return interval.CreateSetFromInterval(MinPort, MaxPort) + return interval.New(MinPort, MaxPort).ToSet() case icmpType: - return interval.CreateSetFromInterval(MinICMPtype, MaxICMPtype) + return interval.New(MinICMPtype, MaxICMPtype).ToSet() case icmpCode: - return interval.CreateSetFromInterval(MinICMPcode, MaxICMPcode) + return interval.New(MinICMPcode, MaxICMPcode).ToSet() } return nil } diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index 236ff6f..f014f37 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -166,7 +166,7 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { } func (c *CanonicalSet) getIntervalSetUnion() *interval.CanonicalSet { - res := interval.NewCanonicalIntervalSet() + res := interval.NewCanonicalSet() for k := range c.layers { res = res.Union(k) } @@ -301,13 +301,13 @@ func FromCube(cube []*interval.CanonicalSet) *CanonicalSet { return res } -// cube returns a new hypercube.CanonicalSet created from a single input cube +// Cube returns a new CanonicalSet created from a single input cube // the input cube is given as an ordered list of integer values, where each two values // represent the range (start,end) for a dimension value func Cube(values ...int64) *CanonicalSet { cube := []*interval.CanonicalSet{} for i := 0; i < len(values); i += 2 { - cube = append(cube, interval.CreateSetFromInterval(values[i], values[i+1])) + cube = append(cube, interval.New(values[i], values[i+1]).ToSet()) } return FromCube(cube) } diff --git a/pkg/interval/interval.go b/pkg/interval/interval.go index facdff5..18f1d08 100644 --- a/pkg/interval/interval.go +++ b/pkg/interval/interval.go @@ -20,6 +20,10 @@ func (i Interval) Equal(x Interval) bool { return i.Start == x.Start && i.End == x.End } +func New(start, end int64) Interval { + return Interval{Start: start, End: end} +} + func (i Interval) Size() int64 { return i.End - i.Start + 1 } @@ -58,3 +62,7 @@ func (i Interval) intersection(other Interval) []Interval { } return []Interval{{Start: maxStart, End: minEnd}} } + +func (i Interval) ToSet() *CanonicalSet { + return NewSetFromInterval(i) +} diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index 3e14839..b7d15dd 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -3,7 +3,6 @@ package interval import ( - "errors" "fmt" "log" "slices" @@ -15,7 +14,7 @@ type CanonicalSet struct { intervalSet []Interval } -func NewCanonicalIntervalSet() *CanonicalSet { +func NewCanonicalSet() *CanonicalSet { return &CanonicalSet{ intervalSet: []Interval{}, } @@ -36,7 +35,7 @@ func (c *CanonicalSet) Min() int64 { return c.intervalSet[0].Start } -// IsEmpty returns true if the CanonicalSet is empty +// IsEmpty returns true if the CanonicalSet is empty func (c *CanonicalSet) IsEmpty() bool { return len(c.intervalSet) == 0 } @@ -115,8 +114,8 @@ func (c *CanonicalSet) Union(other *CanonicalSet) *CanonicalSet { if c == other { return res } - for _, interval := range other.intervalSet { - res.AddInterval(interval) + for _, v := range other.intervalSet { + res.AddInterval(v) } return res } @@ -127,11 +126,10 @@ func (c *CanonicalSet) Copy() *CanonicalSet { } func (c *CanonicalSet) Contains(n int64) bool { - i := CreateSetFromInterval(n, n) - return i.ContainedIn(c) + return New(n, n).ToSet().ContainedIn(c) } -// ContainedIn returns true of the current interval.CanonicalSet is contained in the input interval.CanonicalSet +// ContainedIn returns true of the current CanonicalSet is contained in the input CanonicalSet func (c *CanonicalSet) ContainedIn(other *CanonicalSet) bool { if c == other { return true @@ -155,17 +153,19 @@ func (c *CanonicalSet) Intersect(other *CanonicalSet) *CanonicalSet { if c == other { return c.Copy() } - newIntervalSet := []Interval{} + res := NewCanonicalSet() for _, interval := range c.intervalSet { for _, otherInterval := range other.intervalSet { - newIntervalSet = append(newIntervalSet, interval.intersection(otherInterval)...) + for _, span := range interval.intersection(otherInterval) { + res.AddInterval(span) + } } } - c.intervalSet = newIntervalSet + return res } -// Overlaps returns true if current CanonicalSet overlaps with input CanonicalSet -func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { +// Overlap returns true if current CanonicalSet overlaps with input CanonicalSet +func (c *CanonicalSet) Overlap(other *CanonicalSet) bool { if c == other { return !c.IsEmpty() } @@ -179,11 +179,13 @@ func (c *CanonicalSet) Overlaps(other *CanonicalSet) bool { return false } -// Subtract updates current CanonicalSet with subtraction result of input CanonicalSet -func (c *CanonicalSet) Subtract(other *CanonicalSet) { +// Subtract returns the subtraction result of input CanonicalSet +func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { + res := c.Copy() for _, i := range other.intervalSet { - c.AddHole(i) + res.AddHole(i) } + return res } func (c *CanonicalSet) IsSingleNumber() bool { @@ -192,22 +194,6 @@ func (c *CanonicalSet) IsSingleNumber() bool { } return false } -// Split returns a set of canonical set objects, each with a single interval -func (c *CanonicalSet) Split() []*CanonicalSet { - res := make([]*CanonicalSet, len(c.intervalSet)) - for index, ipr := range c.intervalSet { - res[index] = CreateSetFromInterval(ipr.Start, ipr.End) - } - return res -} - -func (c *CanonicalSet) Intervals() []Interval { - return slices.Clone(c.intervalSet) -} - -func (c *CanonicalSet) NumIntervals() int { - return len(c.intervalSet) -} func (c *CanonicalSet) Elements() []int64 { res := []int64{} @@ -219,6 +205,6 @@ func (c *CanonicalSet) Elements() []int64 { return res } -func CreateSetFromInterval(start, end int64) *CanonicalSet { - return &CanonicalSet{intervalSet: []Interval{{Start: start, End: end}}} +func NewSetFromInterval(span Interval) *CanonicalSet { + return &CanonicalSet{intervalSet: []Interval{span}} } diff --git a/pkg/interval/intervalset_test.go b/pkg/interval/intervalset_test.go index 76876dd..b80167c 100644 --- a/pkg/interval/intervalset_test.go +++ b/pkg/interval/intervalset_test.go @@ -17,16 +17,16 @@ func TestInterval(t *testing.T) { } func TestIntervalSet(t *testing.T) { - is1 := interval.NewCanonicalIntervalSet() + is1 := interval.NewCanonicalSet() is1.AddInterval(interval.Interval{5, 10}) is1.AddInterval(interval.Interval{0, 1}) is1.AddInterval(interval.Interval{3, 3}) is1.AddInterval(interval.Interval{70, 80}) - is1 = is1.Subtract(interval.CreateSetFromInterval(7, 9)) + is1 = is1.Subtract(interval.New(7, 9).ToSet()) require.True(t, is1.Contains(5)) require.False(t, is1.Contains(8)) - is2 := interval.NewCanonicalIntervalSet() + is2 := interval.NewCanonicalSet() require.Equal(t, "Empty", is2.String()) is2.AddInterval(interval.Interval{6, 8}) require.Equal(t, "6-8", is2.String()) @@ -35,34 +35,34 @@ func TestIntervalSet(t *testing.T) { require.False(t, is1.ContainedIn(is2)) require.False(t, is2.Equal(is1)) require.False(t, is1.Equal(is2)) - require.True(t, is1.Overlaps(is2)) - require.True(t, is2.Overlaps(is1)) + require.True(t, is1.Overlap(is2)) + require.True(t, is2.Overlap(is1)) is1 = is1.Subtract(is2) require.False(t, is2.ContainedIn(is1)) require.False(t, is1.ContainedIn(is2)) - require.False(t, is1.Overlaps(is2)) - require.False(t, is2.Overlaps(is1)) + require.False(t, is1.Overlap(is2)) + require.False(t, is2.Overlap(is1)) - is1 = is1.Union(is2).Union(interval.CreateSetFromInterval(7, 9)) + is1 = is1.Union(is2).Union(interval.New(7, 9).ToSet()) require.True(t, is2.ContainedIn(is1)) require.False(t, is1.ContainedIn(is2)) - require.True(t, is1.Overlaps(is2)) - require.True(t, is2.Overlaps(is1)) + require.True(t, is1.Overlap(is2)) + require.True(t, is2.Overlap(is1)) is3 := is1.Intersect(is2) require.True(t, is3.Equal(is2)) require.True(t, is2.ContainedIn(is3)) - require.True(t, interval.CreateSetFromInterval(1, 1).IsSingleNumber()) + require.True(t, interval.New(1, 1).ToSet().IsSingleNumber()) } func TestIntervalSetSubtract(t *testing.T) { - s := interval.CreateSetFromInterval(1, 100) + s := interval.New(1, 100).ToSet() s.AddInterval(interval.Interval{Start: 400, End: 700}) - d := *interval.CreateSetFromInterval(50, 100) + d := *interval.New(50, 100).ToSet() d.AddInterval(interval.Interval{Start: 400, End: 700}) actual := s.Subtract(&d) - expected := interval.CreateSetFromInterval(1, 49) + expected := interval.New(1, 49).ToSet() require.Equal(t, expected.String(), actual.String()) } diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index 0975aeb..b541a3a 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -41,7 +41,7 @@ type IPBlock struct { // New returns a new IPBlock object func New() *IPBlock { return &IPBlock{ - ipRange: interval.NewCanonicalIntervalSet(), + ipRange: interval.NewCanonicalSet(), } } @@ -113,7 +113,7 @@ func (b *IPBlock) Union(c *IPBlock) *IPBlock { } // Empty returns true if this IPBlock is empty -func (b *IPBlock) Empty() bool { +func (b *IPBlock) IsEmpty() bool { return b.ipRange.IsEmpty() } @@ -133,9 +133,9 @@ func (b *IPBlock) ipCount() int { // Split returns a set of IpBlock objects, each with a single range of ips func (b *IPBlock) Split() []*IPBlock { res := make([]*IPBlock, b.ipRange.NumIntervals()) - for index, set := range b.ipRange.Split() { + for index, span := range b.ipRange.Intervals() { res[index] = &IPBlock{ - ipRange: set, + ipRange: span.ToSet(), } } return res @@ -177,7 +177,7 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { toAdd := []*IPBlock{} for idx, ipb := range ipbList { - if !ipb.ipRange.Overlaps(ipbNew.ipRange) { + if !ipb.ipRange.Overlap(ipbNew.ipRange) { continue } intersection := ipb.Intersect(ipbNew) @@ -186,7 +186,7 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { toAdd = append(toAdd, intersection) ipbList[idx] = ipbList[idx].Subtract(intersection) } - if ipbNew.ipRange.IsEmpty() { + if ipbNew.IsEmpty() { break } } @@ -197,28 +197,15 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { // FromCidr returns a new IPBlock object from input CIDR string func FromCidr(cidr string) (*IPBlock, error) { - start, end, err := cidrToIPRange(cidr) + ipRange, err := cidrToIPRange(cidr) if err != nil { return nil, err } return &IPBlock{ - ipRange: interval.CreateSetFromInterval(start, end), + ipRange: ipRange.ToSet(), }, nil } -// ExceptCidrs returns a new IPBlock with all cidr ranges removed -func (b *IPBlock) ExceptCidrs(cidrExceptions ...string) (*IPBlock, error) { - res := b.Copy() - for i := range cidrExceptions { - hole, err := FromCidr(cidrExceptions[i]) - if err != nil { - return nil, err - } - res = res.Subtract(hole) - } - return res, nil -} - // PairCIDRsToIPBlocks returns two IPBlock objects from two input CIDR strings func PairCIDRsToIPBlocks(cidr1, cidr2 string) (ipb1, ipb2 *IPBlock, err error) { ipb1, err1 := FromCidr(cidr1) @@ -229,7 +216,7 @@ func PairCIDRsToIPBlocks(cidr1, cidr2 string) (ipb1, ipb2 *IPBlock, err error) { return ipb1, ipb2, nil } -// FromCidr returns a new IPBlock object from input string of CIDR or IP address +// FromCidrOrAddress returns a new IPBlock object from input string of CIDR or IP address func FromCidrOrAddress(s string) (*IPBlock, error) { if strings.Contains(s, cidrSeparator) { return FromCidr(s) @@ -239,7 +226,7 @@ func FromCidrOrAddress(s string) (*IPBlock, error) { // FromCidrList returns IPBlock object from multiple CIDRs given as list of strings func FromCidrList(cidrsList []string) (*IPBlock, error) { - ipRange := interval.NewCanonicalIntervalSet() + ipRange := interval.NewCanonicalSet() for _, cidr := range cidrsList { block, err := FromCidr(cidr) if err != nil { @@ -250,6 +237,19 @@ func FromCidrList(cidrsList []string) (*IPBlock, error) { return &IPBlock{ipRange: ipRange}, nil } +// ExceptCidrs returns a new IPBlock with all cidr ranges removed +func (b *IPBlock) ExceptCidrs(cidrExceptions ...string) (*IPBlock, error) { + res := b.Copy() + for i := range cidrExceptions { + hole, err := FromCidr(cidrExceptions[i]) + if err != nil { + return nil, err + } + res = res.Subtract(hole) + } + return res, nil +} + // FromIPAddress returns an IPBlock object from input IP address string func FromIPAddress(ipAddress string) (*IPBlock, error) { ipNum, err := parseIP(ipAddress) @@ -257,15 +257,15 @@ func FromIPAddress(ipAddress string) (*IPBlock, error) { return nil, err } return &IPBlock{ - ipRange: interval.CreateSetFromInterval(ipNum, ipNum), + ipRange: interval.New(ipNum, ipNum).ToSet(), }, nil } -func cidrToIPRange(cidr string) (start, end int64, err error) { +func cidrToIPRange(cidr string) (interval.Interval, error) { // convert string to IPNet struct _, ipv4Net, err := net.ParseCIDR(cidr) if err != nil { - return + return interval.Interval{}, err } // convert IPNet struct mask and address to uint32 @@ -274,9 +274,7 @@ func cidrToIPRange(cidr string) (start, end int64, err error) { startNum := binary.BigEndian.Uint32(ipv4Net.IP) // find the final address endNum := (startNum & mask) | (mask ^ ipMask) - start = int64(startNum) - end = int64(endNum) - return + return interval.New(int64(startNum), int64(endNum)), nil } // ToCidrList returns a list of CIDR strings for this IPBlock object @@ -362,7 +360,7 @@ func FromIPRangeStr(ipRangeStr string) (*IPBlock, error) { return nil, errors.Join(err0, err1) } res := &IPBlock{ - ipRange: interval.CreateSetFromInterval(startIPNum, endIPNum), + ipRange: interval.New(startIPNum, endIPNum).ToSet(), } return res, nil } diff --git a/pkg/ipblock/ipblock_test.go b/pkg/ipblock/ipblock_test.go index e941cc4..025d93b 100644 --- a/pkg/ipblock/ipblock_test.go +++ b/pkg/ipblock/ipblock_test.go @@ -36,7 +36,7 @@ func TestOps(t *testing.T) { require.Equal(t, union, ipb1) intersect2 := minus.Intersect(intersect) - require.True(t, intersect2.Empty()) + require.True(t, intersect2.IsEmpty()) } func TestConversions(t *testing.T) { diff --git a/pkg/model/data_model.go b/pkg/model/data_model.go index fcb56cd..5218a66 100644 --- a/pkg/model/data_model.go +++ b/pkg/model/data_model.go @@ -1,6 +1,6 @@ // Code generated by github.com/atombender/go-jsonschema, DO NOT EDIT. -package model +package spec import ( "encoding/json" From 29ce19d1189eb287056bfacad21e12064ed8f3b4 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Tue, 19 Mar 2024 12:51:07 +0200 Subject: [PATCH 04/15] minor Signed-off-by: Elazar Gershuni --- pkg/ipblock/ipblock.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index 6e41b63..65e6fbf 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -197,7 +197,7 @@ func addIntervalToList(ipbNew *IPBlock, ipbList []*IPBlock) []*IPBlock { // FromCidr returns a new IPBlock object from input CIDR string func FromCidr(cidr string) (*IPBlock, error) { - span, err := cidrToIPRange(cidr) + span, err := cidrToInterval(cidr) if err != nil { return nil, err } @@ -226,28 +226,28 @@ func FromCidrOrAddress(s string) (*IPBlock, error) { // FromCidrList returns IPBlock object from multiple CIDRs given as list of strings func FromCidrList(cidrsList []string) (*IPBlock, error) { - ipRange := New() + res := New() for _, cidr := range cidrsList { block, err := FromCidr(cidr) if err != nil { return nil, err } - ipRange = ipRange.Union(block) + res = res.Union(block) } - return ipRange, nil + return res, nil } // ExceptCidrs returns a new IPBlock with all cidr ranges removed -func (b *IPBlock) ExceptCidrs(cidrExceptions ...string) (*IPBlock, error) { - res := b.Copy() - for i := range cidrExceptions { - hole, err := FromCidr(cidrExceptions[i]) +func (b *IPBlock) ExceptCidrs(exceptions ...string) (*IPBlock, error) { + holes := interval.NewCanonicalSet() + for i := range exceptions { + intervalHole, err := cidrToInterval(exceptions[i]) if err != nil { return nil, err } - res = res.Subtract(hole) + holes.AddInterval(intervalHole) } - return res, nil + return &IPBlock{ipRange: b.ipRange.Subtract(holes)}, nil } // FromIPAddress returns an IPBlock object from input IP address string @@ -261,7 +261,7 @@ func FromIPAddress(ipAddress string) (*IPBlock, error) { }, nil } -func cidrToIPRange(cidr string) (interval.Interval, error) { +func cidrToInterval(cidr string) (interval.Interval, error) { // convert string to IPNet struct _, ipv4Net, err := net.ParseCIDR(cidr) if err != nil { From 9a3b582567ec4c04fc76290ae59a764858d9ddd9 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 20 Mar 2024 12:15:41 +0200 Subject: [PATCH 05/15] minor --- pkg/connection/connectionset.go | 18 ++++++++++-------- pkg/connection/statefulness.go | 4 ++-- pkg/connection/statefulness_test.go | 6 +++--- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 7e70f0b..c7c5ff4 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -16,10 +16,10 @@ const ( TCPCode = 0 UDPCode = 1 ICMPCode = 2 - MinICMPtype int64 = 0 - MaxICMPtype int64 = netp.InformationReply - MinICMPcode int64 = 0 - MaxICMPcode int64 = 5 + MinICMPType int64 = 0 + MaxICMPType int64 = netp.InformationReply + MinICMPCode int64 = 0 + MaxICMPCode int64 = 5 minProtocol int64 = 0 maxProtocol int64 = 2 MinPort = 1 @@ -63,9 +63,9 @@ func entireDimension(dim Dimension) *interval.CanonicalSet { case dstPort: return interval.New(MinPort, MaxPort).ToSet() case icmpType: - return interval.New(MinICMPtype, MaxICMPtype).ToSet() + return interval.New(MinICMPType, MaxICMPType).ToSet() case icmpCode: - return interval.New(MinICMPcode, MaxICMPcode).ToSet() + return interval.New(MinICMPCode, MaxICMPCode).ToSet() } return nil } @@ -92,6 +92,8 @@ func All() *Set { return newSet(true) } +var all = All() + func None() *Set { return newSet(false) } @@ -215,7 +217,7 @@ func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, conn := None() conn.addConnection(protocol, srcMinP, srcMaxP, dstMinP, dstMaxP, - MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + MinICMPType, MaxICMPType, MinICMPCode, MaxICMPCode) return conn } @@ -415,7 +417,7 @@ func getCubeAsICMPItems(cube []*interval.CanonicalSet) []netp.Protocol { type Details []netp.Protocol -func ConnToJSONRep(c *Set) Details { +func ToJSON(c *Set) Details { if c == nil { return nil // one of the connections in connectionDiff can be empty } diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index 3369ca5..27719d7 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -35,13 +35,13 @@ func newTCPSet() *Set { return TCPorUDPConnection(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) } -// ConnectionWithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. +// WithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. // `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. // The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` // which is not stateful (such that the response direction for this subset is not enabled). // This function also returns a connection object with the exact subset of the stateful part (within TCP) // from the entire connection `conn`, and with the original connections on other protocols. -func (conn *Set) ConnectionWithStatefulness(secondDirectionConn *Set) *Set { +func (conn *Set) WithStatefulness(secondDirectionConn *Set) *Set { connTCP := conn.Intersect(newTCPSet()) if connTCP.IsEmpty() { conn.IsStateful = StatefulTrue diff --git a/pkg/connection/statefulness_test.go b/pkg/connection/statefulness_test.go index 3c5b111..27f627f 100644 --- a/pkg/connection/statefulness_test.go +++ b/pkg/connection/statefulness_test.go @@ -24,8 +24,8 @@ func newUDPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *connect func newICMPconn(t *testing.T) *connection.Set { t.Helper() return connection.ICMPConnection( - connection.MinICMPtype, connection.MaxICMPtype, - connection.MinICMPcode, connection.MaxICMPcode) + connection.MinICMPType, connection.MaxICMPType, + connection.MinICMPCode, connection.MaxICMPCode) } func newTCPUDPSet(t *testing.T, p netp.ProtocolString) *connection.Set { @@ -50,7 +50,7 @@ type statefulnessTest struct { func (tt statefulnessTest) runTest(t *testing.T) { t.Helper() - statefulConn := tt.srcToDst.ConnectionWithStatefulness(tt.dstToSrc) + statefulConn := tt.srcToDst.WithStatefulness(tt.dstToSrc) require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) require.True(t, tt.expectedStatefulConn.Equal(statefulConn)) } From f095699a7cdbf15325183c8d7c1c64f6a4810e20 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 20 Mar 2024 13:05:30 +0200 Subject: [PATCH 06/15] allow all as method --- pkg/connection/connectionset.go | 112 +++++++++++++++++--------------- pkg/connection/statefulness.go | 50 +++++++------- 2 files changed, 83 insertions(+), 79 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index c7c5ff4..1339f8f 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -79,13 +79,13 @@ func getDimensionDomainsList() []*interval.CanonicalSet { } type Set struct { - AllowAll bool + allowAll bool connectionProperties *hypercube.CanonicalSet IsStateful StatefulState } func newSet(all bool) *Set { - return &Set{AllowAll: all, connectionProperties: hypercube.NewCanonicalSet(numDimensions)} + return &Set{allowAll: all, connectionProperties: hypercube.NewCanonicalSet(numDimensions)} } func All() *Set { @@ -98,44 +98,48 @@ func None() *Set { return newSet(false) } -func (conn *Set) Copy() *Set { +func (c *Set) AllowAll() bool { + return c.allowAll +} + +func (c *Set) Copy() *Set { return &Set{ - AllowAll: conn.AllowAll, - connectionProperties: conn.connectionProperties.Copy(), - IsStateful: conn.IsStateful, + allowAll: c.allowAll, + connectionProperties: c.connectionProperties.Copy(), + IsStateful: c.IsStateful, } } -func (conn *Set) Intersect(other *Set) *Set { - if other.AllowAll { - return conn.Copy() +func (c *Set) Intersect(other *Set) *Set { + if other.allowAll { + return c.Copy() } - if conn.AllowAll { + if c.allowAll { return other.Copy() } - return &Set{AllowAll: false, connectionProperties: conn.connectionProperties.Intersect(other.connectionProperties)} + return &Set{allowAll: false, connectionProperties: c.connectionProperties.Intersect(other.connectionProperties)} } -func (conn *Set) IsEmpty() bool { - if conn.AllowAll { +func (c *Set) IsEmpty() bool { + if c.allowAll { return false } - return conn.connectionProperties.IsEmpty() + return c.connectionProperties.IsEmpty() } -func (conn *Set) Union(other *Set) *Set { - if conn.AllowAll || other.AllowAll { +func (c *Set) Union(other *Set) *Set { + if c.allowAll || other.allowAll { return All() } if other.IsEmpty() { - return conn.Copy() + return c.Copy() } - if conn.IsEmpty() { + if c.IsEmpty() { return other.Copy() } res := &Set{ - AllowAll: false, - connectionProperties: conn.connectionProperties.Union(other.connectionProperties), + allowAll: false, + connectionProperties: c.connectionProperties.Union(other.connectionProperties), } res.canonicalize() return res @@ -148,34 +152,34 @@ func getAllPropertiesObject() *hypercube.CanonicalSet { // Subtract // ToDo: Subtract seems to ignore IsStateful (see https://github.com/np-guard/vpc-network-config-analyzer/issues/199): // 1. is the delta connection stateful -// 2. connectionProperties is identical but conn stateful while other is not +// 2. connectionProperties is identical but c stateful while other is not // the 2nd item can be computed here, with enhancement to relevant structure // the 1st can not since we do not know where exactly the statefulness came from -func (conn *Set) Subtract(other *Set) *Set { - if conn.IsEmpty() || other.AllowAll { +func (c *Set) Subtract(other *Set) *Set { + if c.IsEmpty() || other.allowAll { return None() } if other.IsEmpty() { - return conn.Copy() + return c.Copy() } var connProperties *hypercube.CanonicalSet - if conn.AllowAll { + if c.allowAll { connProperties = getAllPropertiesObject() } else { - connProperties = conn.connectionProperties + connProperties = c.connectionProperties } - return &Set{AllowAll: false, connectionProperties: connProperties.Subtract(other.connectionProperties)} + return &Set{allowAll: false, connectionProperties: connProperties.Subtract(other.connectionProperties)} } -// ContainedIn returns true if conn is subset of other -func (conn *Set) ContainedIn(other *Set) bool { - if other.AllowAll { +// ContainedIn returns true if c is subset of other +func (c *Set) ContainedIn(other *Set) bool { + if other.allowAll { return true } - if conn.AllowAll { + if c.allowAll { return false } - res, err := conn.connectionProperties.ContainedIn(other.connectionProperties) + res, err := c.connectionProperties.ContainedIn(other.connectionProperties) if err != nil { log.Fatalf("invalid connection set. %e", err) } @@ -195,48 +199,48 @@ func ProtocolStringToCode(protocol netp.ProtocolString) int64 { return 0 } -func (conn *Set) addConnection(protocol netp.ProtocolString, +func (c *Set) addConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { code := ProtocolStringToCode(protocol) cube := hypercube.Cube(code, code, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) - conn.connectionProperties = conn.connectionProperties.Union(cube) - conn.canonicalize() + c.connectionProperties = c.connectionProperties.Union(cube) + c.canonicalize() } -func (conn *Set) canonicalize() { - if !conn.AllowAll && conn.connectionProperties.Equal(getAllPropertiesObject()) { - conn.AllowAll = true - conn.connectionProperties = hypercube.NewCanonicalSet(numDimensions) +func (c *Set) canonicalize() { + if !c.allowAll && c.connectionProperties.Equal(getAllPropertiesObject()) { + c.allowAll = true + c.connectionProperties = hypercube.NewCanonicalSet(numDimensions) } } func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *Set { - conn := None() - conn.addConnection(protocol, + c := None() + c.addConnection(protocol, srcMinP, srcMaxP, dstMinP, dstMaxP, MinICMPType, MaxICMPType, MinICMPCode, MaxICMPCode) - return conn + return c } func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { - conn := None() - conn.addConnection(netp.ProtocolStringICMP, + c := None() + c.addConnection(netp.ProtocolStringICMP, MinPort, MaxPort, MinPort, MaxPort, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) - return conn + return c } -func (conn *Set) Equal(other *Set) bool { - if conn.AllowAll != other.AllowAll { +func (c *Set) Equal(other *Set) bool { + if c.allowAll != other.allowAll { return false } - if conn.AllowAll { + if c.allowAll { return true } - return conn.connectionProperties.Equal(other.connectionProperties) + return c.connectionProperties.Equal(other.connectionProperties) } func protocolStringFromCode(protocolCode int64) netp.ProtocolString { @@ -323,15 +327,15 @@ func getConnsCubeStr(cube []*interval.CanonicalSet) string { } // String returns a string representation of a Set object -func (conn *Set) String() string { - if conn.AllowAll { +func (c *Set) String() string { + if c.allowAll { return AllConnections - } else if conn.IsEmpty() { + } else if c.IsEmpty() { return NoConnections } resStrings := []string{} // get cubes and cube str per each cube - cubes := conn.connectionProperties.GetCubesList() + cubes := c.connectionProperties.GetCubesList() for _, cube := range cubes { resStrings = append(resStrings, getConnsCubeStr(cube)) } @@ -421,7 +425,7 @@ func ToJSON(c *Set) Details { if c == nil { return nil // one of the connections in connectionDiff can be empty } - if c.AllowAll { + if c.allowAll { return []netp.Protocol{netp.AnyProtocol{}} } var res []netp.Protocol diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index 27719d7..beacd00 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -24,43 +24,43 @@ const ( ) // EnhancedString returns a connection string with possibly added asterisk for stateless connection -func (conn *Set) EnhancedString() string { - if conn.IsStateful == StatefulFalse { - return conn.String() + " *" +func (c *Set) EnhancedString() string { + if c.IsStateful == StatefulFalse { + return c.String() + " *" } - return conn.String() + return c.String() } func newTCPSet() *Set { return TCPorUDPConnection(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) } -// WithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. -// `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. -// The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` +// WithStatefulness updates `c` object with `IsStateful` property, based on input `secondDirectionConn`. +// `c` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `c` is set as `StatefulFalse` if there is at least some subset within TCP from `c` // which is not stateful (such that the response direction for this subset is not enabled). // This function also returns a connection object with the exact subset of the stateful part (within TCP) -// from the entire connection `conn`, and with the original connections on other protocols. -func (conn *Set) WithStatefulness(secondDirectionConn *Set) *Set { - connTCP := conn.Intersect(newTCPSet()) +// from the entire connection `c`, and with the original connections on other protocols. +func (c *Set) WithStatefulness(secondDirectionConn *Set) *Set { + connTCP := c.Intersect(newTCPSet()) if connTCP.IsEmpty() { - conn.IsStateful = StatefulTrue - return conn + c.IsStateful = StatefulTrue + return c } statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConn.Intersect(newTCPSet())) - conn.IsStateful = connTCP.IsStateful - return conn.Subtract(connTCP).Union(statefulCombinedConnTCP) + c.IsStateful = connTCP.IsStateful + return c.Subtract(connTCP).Union(statefulCombinedConnTCP) } -// connTCPWithStatefulness assumes that both `conn` and `secondDirectionConn` are within TCP. -// it assigns IsStateful a value within `conn`, and returns the subset from `conn` which is stateful. -func (conn *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { +// connTCPWithStatefulness assumes that both `c` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `c`, and returns the subset from `c` which is stateful. +func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { // flip src/dst ports before intersection - statefulCombinedConn := conn.Intersect(secondDirectionConn.switchSrcDstPortsOnTCP()) - if conn.Equal(statefulCombinedConn) { - conn.IsStateful = StatefulTrue + statefulCombinedConn := c.Intersect(secondDirectionConn.switchSrcDstPortsOnTCP()) + if c.Equal(statefulCombinedConn) { + c.IsStateful = StatefulTrue } else { - conn.IsStateful = StatefulFalse + c.IsStateful = StatefulFalse } return statefulCombinedConn } @@ -68,12 +68,12 @@ func (conn *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { // switchSrcDstPortsOnTCP returns a new Set object, built from the input Set object. // It assumes the input connection object is only within TCP protocol. // For TCP the src and dst ports on relevant cubes are being switched. -func (conn *Set) switchSrcDstPortsOnTCP() *Set { - if conn.AllowAll || conn.IsEmpty() { - return conn.Copy() +func (c *Set) switchSrcDstPortsOnTCP() *Set { + if c.allowAll || c.IsEmpty() { + return c.Copy() } res := None() - for _, cube := range conn.connectionProperties.GetCubesList() { + for _, cube := range c.connectionProperties.GetCubesList() { // assuming cube[protocol] contains TCP only // no need to switch if src equals dst if !cube[srcPort].Equal(cube[dstPort]) { From 51bab16c50c723f2c0f87075066844d64cfc720c Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 20 Mar 2024 13:24:51 +0200 Subject: [PATCH 07/15] remove allowAll field, rename AllowAll() to IsAll() --- pkg/connection/connectionset.go | 82 ++++++--------------------------- pkg/connection/statefulness.go | 2 +- 2 files changed, 16 insertions(+), 68 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 1339f8f..d10054c 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -79,74 +79,49 @@ func getDimensionDomainsList() []*interval.CanonicalSet { } type Set struct { - allowAll bool connectionProperties *hypercube.CanonicalSet IsStateful StatefulState } -func newSet(all bool) *Set { - return &Set{allowAll: all, connectionProperties: hypercube.NewCanonicalSet(numDimensions)} +func None() *Set { + return &Set{connectionProperties: hypercube.NewCanonicalSet(numDimensions)} } func All() *Set { - return newSet(true) + return &Set{connectionProperties: hypercube.FromCube(getDimensionDomainsList())} } var all = All() -func None() *Set { - return newSet(false) -} - -func (c *Set) AllowAll() bool { - return c.allowAll +func (c *Set) IsAll() bool { + return c.Equal(all) } func (c *Set) Copy() *Set { return &Set{ - allowAll: c.allowAll, connectionProperties: c.connectionProperties.Copy(), IsStateful: c.IsStateful, } } func (c *Set) Intersect(other *Set) *Set { - if other.allowAll { - return c.Copy() - } - if c.allowAll { - return other.Copy() - } - return &Set{allowAll: false, connectionProperties: c.connectionProperties.Intersect(other.connectionProperties)} + return &Set{connectionProperties: c.connectionProperties.Intersect(other.connectionProperties)} } func (c *Set) IsEmpty() bool { - if c.allowAll { - return false - } return c.connectionProperties.IsEmpty() } func (c *Set) Union(other *Set) *Set { - if c.allowAll || other.allowAll { - return All() - } if other.IsEmpty() { return c.Copy() } if c.IsEmpty() { return other.Copy() } - res := &Set{ - allowAll: false, + return &Set{ connectionProperties: c.connectionProperties.Union(other.connectionProperties), } - res.canonicalize() - return res -} - -func getAllPropertiesObject() *hypercube.CanonicalSet { - return hypercube.FromCube(getDimensionDomainsList()) } // Subtract @@ -156,29 +131,17 @@ func getAllPropertiesObject() *hypercube.CanonicalSet { // the 2nd item can be computed here, with enhancement to relevant structure // the 1st can not since we do not know where exactly the statefulness came from func (c *Set) Subtract(other *Set) *Set { - if c.IsEmpty() || other.allowAll { + if c.IsEmpty() { return None() } if other.IsEmpty() { return c.Copy() } - var connProperties *hypercube.CanonicalSet - if c.allowAll { - connProperties = getAllPropertiesObject() - } else { - connProperties = c.connectionProperties - } - return &Set{allowAll: false, connectionProperties: connProperties.Subtract(other.connectionProperties)} + return &Set{connectionProperties: c.connectionProperties.Subtract(other.connectionProperties)} } // ContainedIn returns true if c is subset of other func (c *Set) ContainedIn(other *Set) bool { - if other.allowAll { - return true - } - if c.allowAll { - return false - } res, err := c.connectionProperties.ContainedIn(other.connectionProperties) if err != nil { log.Fatalf("invalid connection set. %e", err) @@ -207,14 +170,6 @@ func (c *Set) addConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) c.connectionProperties = c.connectionProperties.Union(cube) - c.canonicalize() -} - -func (c *Set) canonicalize() { - if !c.allowAll && c.connectionProperties.Equal(getAllPropertiesObject()) { - c.allowAll = true - c.connectionProperties = hypercube.NewCanonicalSet(numDimensions) - } } func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *Set { @@ -234,12 +189,6 @@ func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *S } func (c *Set) Equal(other *Set) bool { - if c.allowAll != other.allowAll { - return false - } - if c.allowAll { - return true - } return c.connectionProperties.Equal(other.connectionProperties) } @@ -328,15 +277,14 @@ func getConnsCubeStr(cube []*interval.CanonicalSet) string { // String returns a string representation of a Set object func (c *Set) String() string { - if c.allowAll { - return AllConnections - } else if c.IsEmpty() { + if c.IsEmpty() { return NoConnections + } else if c.IsAll() { + return AllConnections } - resStrings := []string{} // get cubes and cube str per each cube - cubes := c.connectionProperties.GetCubesList() - for _, cube := range cubes { + resStrings := []string{} + for _, cube := range c.connectionProperties.GetCubesList() { resStrings = append(resStrings, getConnsCubeStr(cube)) } @@ -425,7 +373,7 @@ func ToJSON(c *Set) Details { if c == nil { return nil // one of the connections in connectionDiff can be empty } - if c.allowAll { + if c.IsAll() { return []netp.Protocol{netp.AnyProtocol{}} } var res []netp.Protocol diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index beacd00..ceb1095 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -69,7 +69,7 @@ func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { // It assumes the input connection object is only within TCP protocol. // For TCP the src and dst ports on relevant cubes are being switched. func (c *Set) switchSrcDstPortsOnTCP() *Set { - if c.allowAll || c.IsEmpty() { + if c.IsAll() || c.IsEmpty() { return c.Copy() } res := None() From ed1a6896f0449f84af2735def80d197121439e82 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 20 Mar 2024 13:31:56 +0200 Subject: [PATCH 08/15] less mutation Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index d10054c..0b03666 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -162,30 +162,26 @@ func ProtocolStringToCode(protocol netp.ProtocolString) int64 { return 0 } -func (c *Set) addConnection(protocol netp.ProtocolString, +func cube(protocolString netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP, - icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { - code := ProtocolStringToCode(protocol) - cube := hypercube.Cube(code, code, - srcMinP, srcMaxP, dstMinP, dstMaxP, - icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) - c.connectionProperties = c.connectionProperties.Union(cube) + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { + protocol := ProtocolStringToCode(protocolString) + return &Set{ + connectionProperties: hypercube.Cube(protocol, protocol, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax)} } func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *Set { - c := None() - c.addConnection(protocol, + return cube(protocol, srcMinP, srcMaxP, dstMinP, dstMaxP, MinICMPType, MaxICMPType, MinICMPCode, MaxICMPCode) - return c } func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { - c := None() - c.addConnection(netp.ProtocolStringICMP, + return cube(netp.ProtocolStringICMP, MinPort, MaxPort, MinPort, MaxPort, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) - return c } func (c *Set) Equal(other *Set) bool { From d45105a2dffa01b32b37c1c6ae5fffe750513ed4 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 20 Mar 2024 14:28:40 +0200 Subject: [PATCH 09/15] revert ToJSON to use spec --- pkg/connection/connectionset.go | 128 +++++++++++++++++--------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 0b03666..3c5d3dc 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -10,6 +10,7 @@ import ( "github.com/np-guard/models/pkg/hypercube" "github.com/np-guard/models/pkg/interval" "github.com/np-guard/models/pkg/netp" + "github.com/np-guard/models/pkg/spec" ) const ( @@ -288,105 +289,112 @@ func (c *Set) String() string { return strings.Join(resStrings, "; ") } -func getCubeAsTCPorUDPItems(cube []*interval.CanonicalSet, isTCP bool) []netp.Protocol { - tcpItemsTemp := []netp.Protocol{} +func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtocol) []spec.TcpUdp { + tcpItemsTemp := []spec.TcpUdp{} + tcpItemsFinal := []spec.TcpUdp{} // consider src ports srcPorts := cube[srcPort] - if srcPorts.Equal(entireDimension(srcPort)) { - tcpItemsTemp = append(tcpItemsTemp, netp.TCPUDP{IsTCP: isTCP}) - } else { - // iterate the intervals in the interval-set - for _, portRange := range srcPorts.Intervals() { - tcpRes := netp.TCPUDP{ - IsTCP: isTCP, - PortRangePair: netp.PortRangePair{ - SrcPort: portRange, - DstPort: interval.Interval{Start: netp.MinPort, End: netp.MaxPort}, - }, - } + if !srcPorts.Equal(entireDimension(srcPort)) { + // iterate the interval in the interval-set + for _, interval := range srcPorts.Intervals() { + tcpRes := spec.TcpUdp{Protocol: protocol, MinSourcePort: int(interval.Start), MaxSourcePort: int(interval.End)} tcpItemsTemp = append(tcpItemsTemp, tcpRes) } + } else { + tcpItemsTemp = append(tcpItemsTemp, spec.TcpUdp{Protocol: protocol}) } // consider dst ports dstPorts := cube[dstPort] - if dstPorts.Equal(entireDimension(dstPort)) { - return tcpItemsTemp - } - tcpItemsFinal := []netp.Protocol{} - for _, portRange := range dstPorts.Intervals() { - for _, tcpItemTemp := range tcpItemsTemp { - item, _ := tcpItemTemp.(netp.TCPUDP) - tcpItemsFinal = append(tcpItemsFinal, netp.TCPUDP{ - IsTCP: isTCP, - PortRangePair: netp.PortRangePair{ - SrcPort: item.PortRangePair.SrcPort, - DstPort: portRange, - }, - }) + if !dstPorts.Equal(entireDimension(dstPort)) { + // iterate the interval in the interval-set + for _, interval := range dstPorts.Intervals() { + for _, tcpItemTemp := range tcpItemsTemp { + tcpRes := spec.TcpUdp{ + Protocol: protocol, + MinSourcePort: tcpItemTemp.MinSourcePort, + MaxSourcePort: tcpItemTemp.MaxSourcePort, + MinDestinationPort: int(interval.Start), + MaxDestinationPort: int(interval.End), + } + tcpItemsFinal = append(tcpItemsFinal, tcpRes) + } } + } else { + tcpItemsFinal = tcpItemsTemp } return tcpItemsFinal } -func getCubeAsICMPItems(cube []*interval.CanonicalSet) []netp.Protocol { +func getCubeAsICMPItems(cube []*interval.CanonicalSet) []spec.Icmp { icmpTypes := cube[icmpType] icmpCodes := cube[icmpCode] - if icmpCodes.Equal(entireDimension(icmpCode)) { - if icmpTypes.Equal(entireDimension(icmpType)) { - return []netp.Protocol{netp.ICMP{}} + allTypes := icmpTypes.Equal(entireDimension(icmpType)) + allCodes := icmpCodes.Equal(entireDimension(icmpCode)) + switch { + case allTypes && allCodes: + return []spec.Icmp{{Protocol: spec.IcmpProtocolICMP}} + case allTypes: + // This does not really make sense: not all types can have all codes + res := []spec.Icmp{} + for _, code64 := range icmpCodes.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Code: &code}) } - res := []netp.Protocol{} - for _, t := range icmpTypes.Elements() { - icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: int(t)}) - if err != nil { - log.Panic(err) - } - res = append(res, icmp) + return res + case allCodes: + res := []spec.Icmp{} + for _, type64 := range icmpTypes.Elements() { + t := int(type64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t}) } return res - } - - // iterate both codes and types - res := []netp.Protocol{} - for _, t := range icmpTypes.Elements() { - codes := icmpCodes.Elements() - for i := range codes { - // TODO: merge when all codes for certain type exist - c := int(codes[i]) - icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: int(t), Code: &c}) - if err != nil { - log.Panic(err) + default: + res := []spec.Icmp{} + // iterate both codes and types + for _, type64 := range icmpTypes.Elements() { + t := int(type64) + for _, code64 := range icmpCodes.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t, Code: &code}) } - res = append(res, icmp) } + return res } - return res } -type Details []netp.Protocol +type Details spec.ProtocolList func ToJSON(c *Set) Details { if c == nil { return nil // one of the connections in connectionDiff can be empty } if c.IsAll() { - return []netp.Protocol{netp.AnyProtocol{}} + return Details(spec.ProtocolList{spec.AnyProtocol{Protocol: spec.AnyProtocolProtocolANY}}) } - var res []netp.Protocol + res := spec.ProtocolList{} cubes := c.connectionProperties.GetCubesList() for _, cube := range cubes { protocols := cube[protocol] if protocols.Contains(TCPCode) { - res = append(res, getCubeAsTCPorUDPItems(cube, true)...) + tcpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolTCP) + for _, item := range tcpItems { + res = append(res, item) + } } if protocols.Contains(UDPCode) { - res = append(res, getCubeAsTCPorUDPItems(cube, false)...) + udpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolUDP) + for _, item := range udpItems { + res = append(res, item) + } } if protocols.Contains(ICMPCode) { - res = append(res, getCubeAsICMPItems(cube)...) + icmpItems := getCubeAsICMPItems(cube) + for _, item := range icmpItems { + res = append(res, item) + } } } - return res + return Details(res) } From c71a66c61442703af4993ca52798615253ce5d5b Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 09:47:41 +0200 Subject: [PATCH 10/15] cleanup string generation, avoid leaking internal encoding Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 68 ++++++++++++---------------- pkg/connection/connectionset_test.go | 2 +- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 3c5d3dc..d79b5bf 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -98,6 +98,10 @@ func (c *Set) IsAll() bool { return c.Equal(all) } +func (c *Set) Equal(other *Set) bool { + return c.connectionProperties.Equal(other.connectionProperties) +} + func (c *Set) Copy() *Set { return &Set{ connectionProperties: c.connectionProperties.Copy(), @@ -150,7 +154,7 @@ func (c *Set) ContainedIn(other *Set) bool { return res } -func ProtocolStringToCode(protocol netp.ProtocolString) int64 { +func protocolStringToCode(protocol netp.ProtocolString) int64 { switch protocol { case netp.ProtocolStringTCP: return TCPCode @@ -166,7 +170,7 @@ func ProtocolStringToCode(protocol netp.ProtocolString) int64 { func cube(protocolString netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { - protocol := ProtocolStringToCode(protocolString) + protocol := protocolStringToCode(protocolString) return &Set{ connectionProperties: hypercube.Cube(protocol, protocol, srcMinP, srcMaxP, dstMinP, dstMaxP, @@ -185,10 +189,6 @@ func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *S icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) } -func (c *Set) Equal(other *Set) bool { - return c.connectionProperties.Equal(other.connectionProperties) -} - func protocolStringFromCode(protocolCode int64) netp.ProtocolString { switch protocolCode { case TCPCode: @@ -202,7 +202,8 @@ func protocolStringFromCode(protocolCode int64) netp.ProtocolString { return "" } -func getDimensionString(dimValue *interval.CanonicalSet, dim Dimension) string { +func getDimensionString(cube []*interval.CanonicalSet, dim Dimension) string { + dimValue := cube[dim] if dimValue.Equal(entireDimension(dim)) { // avoid adding dimension str on full dimension values return "" @@ -215,6 +216,8 @@ func getDimensionString(dimValue *interval.CanonicalSet, dim Dimension) string { pList = append(pList, string(protocolStringFromCode(code))) } } + // sort by string values to avoid dependence on internal encoding + sort.Strings(pList) return "protocol: " + strings.Join(pList, ",") case srcPort: return "src-ports: " + dimValue.String() @@ -228,48 +231,37 @@ func getDimensionString(dimValue *interval.CanonicalSet, dim Dimension) string { return "" } -func filterEmptyPropertiesStr(inputList []string) []string { +func joinNonEmpty(inputList ...string) string { res := []string{} for _, propertyStr := range inputList { if propertyStr != "" { res = append(res, propertyStr) } } - return res -} - -func getICMPbasedCubeStr(protocolsValues, icmpTypeValues, icmpCodeValues *interval.CanonicalSet) string { - strList := []string{ - getDimensionString(protocolsValues, protocol), - getDimensionString(icmpTypeValues, icmpType), - getDimensionString(icmpCodeValues, icmpCode), - } - return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) -} - -func getPortBasedCubeStr(protocolsValues, srcPortsValues, dstPortsValues *interval.CanonicalSet) string { - strList := []string{ - getDimensionString(protocolsValues, protocol), - getDimensionString(srcPortsValues, srcPort), - getDimensionString(dstPortsValues, dstPort), - } - return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) -} - -func getMixedProtocolsCubeStr(protocols *interval.CanonicalSet) string { - // TODO: make sure other dimension values are full - return getDimensionString(protocols, protocol) + return strings.Join(res, propertySeparator) } func getConnsCubeStr(cube []*interval.CanonicalSet) string { protocols := cube[protocol] - if (protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) && !protocols.Contains(ICMPCode) { - return getPortBasedCubeStr(protocols, cube[srcPort], cube[dstPort]) - } - if protocols.Contains(ICMPCode) && !(protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) { - return getICMPbasedCubeStr(protocols, cube[icmpType], cube[icmpCode]) + tcpOrUDP := protocols.Contains(TCPCode) || protocols.Contains(UDPCode) + icmp := protocols.Contains(ICMPCode) + switch { + case tcpOrUDP && !icmp: + return joinNonEmpty( + getDimensionString(cube, protocol), + getDimensionString(cube, srcPort), + getDimensionString(cube, dstPort), + ) + case icmp && !tcpOrUDP: + return joinNonEmpty( + getDimensionString(cube, protocol), + getDimensionString(cube, icmpType), + getDimensionString(cube, icmpCode), + ) + default: + // TODO: make sure other dimension values are full + return getDimensionString(cube, protocol) } - return getMixedProtocolsCubeStr(protocols) } // String returns a string representation of a Set object diff --git a/pkg/connection/connectionset_test.go b/pkg/connection/connectionset_test.go index a949dc6..a5cdb22 100644 --- a/pkg/connection/connectionset_test.go +++ b/pkg/connection/connectionset_test.go @@ -33,7 +33,7 @@ func TestBasicSetTCP(t *testing.T) { require.Equal(t, "protocol: TCP", e.String()) c := connection.All().Subtract(e) - require.Equal(t, "protocol: UDP,ICMP", c.String()) + require.Equal(t, "protocol: ICMP,UDP", c.String()) c = c.Union(e) require.Equal(t, "All Connections", c.String()) From de9698106c48b32f7214f77cf08b71aeed427925 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 10:55:12 +0200 Subject: [PATCH 11/15] remove stale comment Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index d79b5bf..8a8f8aa 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -358,7 +358,7 @@ type Details spec.ProtocolList func ToJSON(c *Set) Details { if c == nil { - return nil // one of the connections in connectionDiff can be empty + return nil } if c.IsAll() { return Details(spec.ProtocolList{spec.AnyProtocol{Protocol: spec.AnyProtocolProtocolANY}}) From 6c6ed7690cb0751ea679b07893b6d455fd5eb392 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 12:09:13 +0200 Subject: [PATCH 12/15] address review comments Signed-off-by: Elazar Gershuni --- pkg/hypercube/hypercubeset_test.go | 289 ++++++++++++++--------------- pkg/interval/intervalset.go | 14 +- pkg/ipblock/ipblock.go | 2 +- 3 files changed, 153 insertions(+), 152 deletions(-) diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index 7b9940e..f9dd49a 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -10,13 +10,6 @@ import ( "github.com/np-guard/models/pkg/hypercube" ) -// cube returns a new hypercube.CanonicalSet created from a single input cube -// the input cube is given as an ordered list of integer values, where each two values -// represent the range (start,end) for a dimension value -func cube(values ...int64) *hypercube.CanonicalSet { - return hypercube.Cube(values...) -} - func union(set *hypercube.CanonicalSet, sets ...*hypercube.CanonicalSet) *hypercube.CanonicalSet { for _, c := range sets { set = set.Union(c) @@ -25,12 +18,12 @@ func union(set *hypercube.CanonicalSet, sets ...*hypercube.CanonicalSet) *hyperc } func TestHCBasic(t *testing.T) { - a := cube(1, 100) - b := cube(1, 100) - c := cube(1, 200) - d := cube(1, 100, 1, 100) - e := cube(1, 100, 1, 100) - f := cube(1, 100, 1, 200) + a := hypercube.Cube(1, 100) + b := hypercube.Cube(1, 100) + c := hypercube.Cube(1, 200) + d := hypercube.Cube(1, 100, 1, 100) + e := hypercube.Cube(1, 100, 1, 100) + f := hypercube.Cube(1, 100, 1, 200) require.True(t, a.Equal(b)) require.True(t, b.Equal(a)) @@ -49,7 +42,7 @@ func TestHCBasic(t *testing.T) { } func TestCopy(t *testing.T) { - a := cube(1, 100) + a := hypercube.Cube(1, 100) b := a.Copy() require.True(t, a.Equal(b)) require.True(t, b.Equal(a)) @@ -57,42 +50,42 @@ func TestCopy(t *testing.T) { } func TestString(t *testing.T) { - require.Equal(t, "[(1-3)]", cube(1, 3).String()) - require.Equal(t, "[(1-3),(2-4)]", cube(1, 3, 2, 4).String()) + require.Equal(t, "[(1-3)]", hypercube.Cube(1, 3).String()) + require.Equal(t, "[(1-3),(2-4)]", hypercube.Cube(1, 3, 2, 4).String()) } func TestOr(t *testing.T) { - a := cube(1, 100, 1, 100) - b := cube(1, 90, 1, 200) + a := hypercube.Cube(1, 100, 1, 100) + b := hypercube.Cube(1, 90, 1, 200) c := a.Union(b) require.Equal(t, "[(1-90),(1-200)]; [(91-100),(1-100)]", c.String()) } func TestBasic1(t *testing.T) { a := union( - cube(1, 2), - cube(5, 6), - cube(3, 4), + hypercube.Cube(1, 2), + hypercube.Cube(5, 6), + hypercube.Cube(3, 4), ) - b := cube(1, 6) + b := hypercube.Cube(1, 6) require.True(t, a.Equal(b)) } func TestBasic2(t *testing.T) { a := union( - cube(1, 2, 1, 5), - cube(1, 2, 7, 9), - cube(1, 2, 6, 7), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(1, 2, 7, 9), + hypercube.Cube(1, 2, 6, 7), ) - b := cube(1, 2, 1, 9) + b := hypercube.Cube(1, 2, 1, 9) require.True(t, a.Equal(b)) } func TestNew(t *testing.T) { a := union( - cube(10, 20, 10, 20, 1, 65535), - cube(1, 65535, 15, 40, 1, 65535), - cube(1, 65535, 100, 200, 30, 80), + hypercube.Cube(10, 20, 10, 20, 1, 65535), + hypercube.Cube(1, 65535, 15, 40, 1, 65535), + hypercube.Cube(1, 65535, 100, 200, 30, 80), ) expectedStr := "[(1-9,21-65535),(100-200),(30-80)]; " + "[(1-9,21-65535),(15-40),(1-65535)]; " + @@ -109,39 +102,39 @@ func checkContained(t *testing.T, a, b *hypercube.CanonicalSet, expected bool) { } func TestContainedIn(t *testing.T) { - a := cube(1, 100, 200, 300) - b := cube(10, 80, 210, 280) + a := hypercube.Cube(1, 100, 200, 300) + b := hypercube.Cube(10, 80, 210, 280) checkContained(t, b, a, true) - b = b.Union(cube(10, 200, 210, 280)) + b = b.Union(hypercube.Cube(10, 200, 210, 280)) checkContained(t, b, a, false) } func TestContainedIn1(t *testing.T) { - checkContained(t, cube(1, 3), cube(2, 4), false) - checkContained(t, cube(2, 4), cube(1, 3), false) - checkContained(t, cube(1, 3), cube(1, 4), true) - checkContained(t, cube(1, 4), cube(1, 3), false) + checkContained(t, hypercube.Cube(1, 3), hypercube.Cube(2, 4), false) + checkContained(t, hypercube.Cube(2, 4), hypercube.Cube(1, 3), false) + checkContained(t, hypercube.Cube(1, 3), hypercube.Cube(1, 4), true) + checkContained(t, hypercube.Cube(1, 4), hypercube.Cube(1, 3), false) } func TestContainedIn2(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(150, 180, 20, 300), - cube(200, 240, 200, 300), - cube(241, 300, 200, 350), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(150, 180, 20, 300), + hypercube.Cube(200, 240, 200, 300), + hypercube.Cube(241, 300, 200, 350), ) a := union( - cube(1, 100, 200, 300), - cube(150, 180, 20, 300), - cube(200, 240, 200, 300), - cube(242, 300, 200, 350), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(150, 180, 20, 300), + hypercube.Cube(200, 240, 200, 300), + hypercube.Cube(242, 300, 200, 350), ) - d := cube(210, 220, 210, 280) - e := cube(210, 310, 210, 280) - f := cube(210, 250, 210, 280) - f1 := cube(210, 240, 210, 280) - f2 := cube(241, 250, 210, 280) + d := hypercube.Cube(210, 220, 210, 280) + e := hypercube.Cube(210, 310, 210, 280) + f := hypercube.Cube(210, 250, 210, 280) + f1 := hypercube.Cube(210, 240, 210, 280) + f2 := hypercube.Cube(241, 250, 210, 280) checkContained(t, d, c, true) checkContained(t, e, c, false) @@ -152,213 +145,213 @@ func TestContainedIn2(t *testing.T) { } func TestContainedIn3(t *testing.T) { - a := cube(105, 105, 54, 54) + a := hypercube.Cube(105, 105, 54, 54) b := union( - cube(0, 204, 0, 255), - cube(205, 205, 0, 53), - cube(205, 205, 55, 255), - cube(206, 254, 0, 255), + hypercube.Cube(0, 204, 0, 255), + hypercube.Cube(205, 205, 0, 53), + hypercube.Cube(205, 205, 55, 255), + hypercube.Cube(206, 254, 0, 255), ) checkContained(t, a, b, true) } func TestContainedIn4(t *testing.T) { - a := cube(105, 105, 54, 54) - b := cube(200, 204, 0, 255) + a := hypercube.Cube(105, 105, 54, 54) + b := hypercube.Cube(200, 204, 0, 255) checkContained(t, a, b, false) } func TestContainedIn5(t *testing.T) { - a := cube(100, 200, 54, 65, 60, 300) - b := cube(110, 120, 0, 10, 0, 255) + a := hypercube.Cube(100, 200, 54, 65, 60, 300) + b := hypercube.Cube(110, 120, 0, 10, 0, 255) checkContained(t, b, a, false) } func TestEqual1(t *testing.T) { - a := cube(1, 2) - b := cube(1, 2) + a := hypercube.Cube(1, 2) + b := hypercube.Cube(1, 2) require.True(t, a.Equal(b)) - c := cube(1, 2, 1, 5) - d := cube(1, 2, 1, 5) + c := hypercube.Cube(1, 2, 1, 5) + d := hypercube.Cube(1, 2, 1, 5) require.True(t, c.Equal(d)) } func TestEqual2(t *testing.T) { c := union( - cube(1, 2, 1, 5), - cube(1, 2, 7, 9), - cube(1, 2, 6, 7), - cube(4, 8, 1, 9), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(1, 2, 7, 9), + hypercube.Cube(1, 2, 6, 7), + hypercube.Cube(4, 8, 1, 9), ) res := union( - cube(4, 8, 1, 9), - cube(1, 2, 1, 9), + hypercube.Cube(4, 8, 1, 9), + hypercube.Cube(1, 2, 1, 9), ) require.True(t, res.Equal(c)) d := union( - cube(1, 2, 1, 5), - cube(5, 6, 1, 5), - cube(3, 4, 1, 5), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(5, 6, 1, 5), + hypercube.Cube(3, 4, 1, 5), ) - res2 := cube(1, 6, 1, 5) + res2 := hypercube.Cube(1, 6, 1, 5) require.True(t, res2.Equal(d)) } func TestBasicAddCube(t *testing.T) { a := union( - cube(1, 2), - cube(8, 10), + hypercube.Cube(1, 2), + hypercube.Cube(8, 10), ) b := union( a, - cube(1, 2), - cube(6, 10), - cube(1, 10), + hypercube.Cube(1, 2), + hypercube.Cube(6, 10), + hypercube.Cube(1, 10), ) - res := cube(1, 10) + res := hypercube.Cube(1, 10) require.False(t, res.Equal(a)) require.True(t, res.Equal(b)) } func TestFourHoles(t *testing.T) { - a := cube(1, 2, 1, 2) - require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(cube(1, 1, 1, 1)).String()) - require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(cube(1, 1, 2, 2)).String()) - require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(cube(2, 2, 1, 1)).String()) - require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(cube(2, 2, 2, 2)).String()) + a := hypercube.Cube(1, 2, 1, 2) + require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(hypercube.Cube(1, 1, 1, 1)).String()) + require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(hypercube.Cube(1, 1, 2, 2)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(hypercube.Cube(2, 2, 1, 1)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(hypercube.Cube(2, 2, 2, 2)).String()) } func TestBasicSubtract1(t *testing.T) { - a := cube(1, 10) - require.True(t, a.Subtract(cube(3, 7)).Equal(union(cube(1, 2), cube(8, 10)))) - require.True(t, a.Subtract(cube(3, 20)).Equal(cube(1, 2))) - require.True(t, a.Subtract(cube(0, 20)).IsEmpty()) - require.True(t, a.Subtract(cube(0, 5)).Equal(cube(6, 10))) - require.True(t, a.Subtract(cube(12, 14)).Equal(cube(1, 10))) + a := hypercube.Cube(1, 10) + require.True(t, a.Subtract(hypercube.Cube(3, 7)).Equal(union(hypercube.Cube(1, 2), hypercube.Cube(8, 10)))) + require.True(t, a.Subtract(hypercube.Cube(3, 20)).Equal(hypercube.Cube(1, 2))) + require.True(t, a.Subtract(hypercube.Cube(0, 20)).IsEmpty()) + require.True(t, a.Subtract(hypercube.Cube(0, 5)).Equal(hypercube.Cube(6, 10))) + require.True(t, a.Subtract(hypercube.Cube(12, 14)).Equal(hypercube.Cube(1, 10))) } func TestBasicSubtract2(t *testing.T) { - a := cube(1, 100, 200, 300).Subtract(cube(50, 60, 220, 300)) + a := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 60, 220, 300)) resA := union( - cube(61, 100, 200, 300), - cube(50, 60, 200, 219), - cube(1, 49, 200, 300), + hypercube.Cube(61, 100, 200, 300), + hypercube.Cube(50, 60, 200, 219), + hypercube.Cube(1, 49, 200, 300), ) require.True(t, a.Equal(resA)) - b := cube(1, 100, 200, 300).Subtract(cube(50, 1000, 0, 250)) + b := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 1000, 0, 250)) resB := union( - cube(50, 100, 251, 300), - cube(1, 49, 200, 300), + hypercube.Cube(50, 100, 251, 300), + hypercube.Cube(1, 49, 200, 300), ) require.True(t, b.Equal(resB)) c := union( - cube(1, 100, 200, 300), - cube(400, 700, 200, 300), - ).Subtract(cube(50, 1000, 0, 250)) + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(400, 700, 200, 300), + ).Subtract(hypercube.Cube(50, 1000, 0, 250)) resC := union( - cube(50, 100, 251, 300), - cube(1, 49, 200, 300), - cube(400, 700, 251, 300), + hypercube.Cube(50, 100, 251, 300), + hypercube.Cube(1, 49, 200, 300), + hypercube.Cube(400, 700, 251, 300), ) require.True(t, c.Equal(resC)) - d := cube(1, 100, 200, 300).Subtract(cube(50, 60, 220, 300)) + d := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 60, 220, 300)) dRes := union( - cube(1, 49, 200, 300), - cube(50, 60, 200, 219), - cube(61, 100, 200, 300), + hypercube.Cube(1, 49, 200, 300), + hypercube.Cube(50, 60, 200, 219), + hypercube.Cube(61, 100, 200, 300), ) require.True(t, d.Equal(dRes)) } func TestAddHole2(t *testing.T) { c := union( - cube(80, 100, 20, 300), - cube(250, 400, 20, 300), - ).Subtract(cube(30, 300, 100, 102)) + hypercube.Cube(80, 100, 20, 300), + hypercube.Cube(250, 400, 20, 300), + ).Subtract(hypercube.Cube(30, 300, 100, 102)) d := union( - cube(80, 100, 20, 99), - cube(80, 100, 103, 300), - cube(250, 300, 20, 99), - cube(250, 300, 103, 300), - cube(301, 400, 20, 300), + hypercube.Cube(80, 100, 20, 99), + hypercube.Cube(80, 100, 103, 300), + hypercube.Cube(250, 300, 20, 99), + hypercube.Cube(250, 300, 103, 300), + hypercube.Cube(301, 400, 20, 300), ) require.True(t, c.Equal(d)) } func TestSubtractToEmpty(t *testing.T) { - c := cube(1, 100, 200, 300).Subtract(cube(1, 100, 200, 300)) + c := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(1, 100, 200, 300)) require.True(t, c.IsEmpty()) } func TestUnion1(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(101, 200, 200, 300), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(101, 200, 200, 300), ) - cExpected := cube(1, 200, 200, 300) + cExpected := hypercube.Cube(1, 200, 200, 300) require.True(t, cExpected.Equal(c)) } func TestUnion2(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(101, 200, 200, 300), - cube(201, 300, 200, 300), - cube(301, 400, 200, 300), - cube(402, 500, 200, 300), - cube(500, 600, 200, 700), - cube(601, 700, 200, 700), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(101, 200, 200, 300), + hypercube.Cube(201, 300, 200, 300), + hypercube.Cube(301, 400, 200, 300), + hypercube.Cube(402, 500, 200, 300), + hypercube.Cube(500, 600, 200, 700), + hypercube.Cube(601, 700, 200, 700), ) cExpected := union( - cube(1, 400, 200, 300), - cube(402, 500, 200, 300), - cube(500, 700, 200, 700), + hypercube.Cube(1, 400, 200, 300), + hypercube.Cube(402, 500, 200, 300), + hypercube.Cube(500, 700, 200, 700), ) require.True(t, c.Equal(cExpected)) - d := c.Union(cube(702, 800, 200, 700)) - dExpected := cExpected.Union(cube(702, 800, 200, 700)) + d := c.Union(hypercube.Cube(702, 800, 200, 700)) + dExpected := cExpected.Union(hypercube.Cube(702, 800, 200, 700)) require.True(t, d.Equal(dExpected)) } func TestIntersect(t *testing.T) { - c := cube(5, 15, 3, 10).Intersect(cube(8, 30, 7, 20)) - d := cube(8, 15, 7, 10) + c := hypercube.Cube(5, 15, 3, 10).Intersect(hypercube.Cube(8, 30, 7, 20)) + d := hypercube.Cube(8, 15, 7, 10) require.True(t, c.Equal(d)) } func TestUnionMerge(t *testing.T) { a := union( - cube(5, 15, 3, 6), - cube(5, 30, 7, 10), - cube(8, 30, 11, 20), + hypercube.Cube(5, 15, 3, 6), + hypercube.Cube(5, 30, 7, 10), + hypercube.Cube(8, 30, 11, 20), ) excepted := union( - cube(5, 15, 3, 10), - cube(8, 30, 7, 20), + hypercube.Cube(5, 15, 3, 10), + hypercube.Cube(8, 30, 7, 20), ) require.True(t, excepted.Equal(a)) } func TestSubtract(t *testing.T) { - g := cube(5, 15, 3, 10).Subtract(cube(8, 30, 7, 20)) + g := hypercube.Cube(5, 15, 3, 10).Subtract(hypercube.Cube(8, 30, 7, 20)) h := union( - cube(5, 7, 3, 10), - cube(8, 15, 3, 6), + hypercube.Cube(5, 7, 3, 10), + hypercube.Cube(8, 15, 3, 6), ) require.True(t, g.Equal(h)) } func TestIntersectEmpty(t *testing.T) { - a := cube(5, 15, 3, 10) + a := hypercube.Cube(5, 15, 3, 10) b := union( - cube(1, 3, 7, 20), - cube(20, 23, 7, 20), + hypercube.Cube(1, 3, 7, 20), + hypercube.Cube(20, 23, 7, 20), ) c := a.Intersect(b) require.True(t, c.IsEmpty()) @@ -366,13 +359,13 @@ func TestIntersectEmpty(t *testing.T) { func TestOr2(t *testing.T) { a := union( - cube(1, 79, 10054, 10054), - cube(80, 100, 10053, 10054), - cube(101, 65535, 10054, 10054), + hypercube.Cube(1, 79, 10054, 10054), + hypercube.Cube(80, 100, 10053, 10054), + hypercube.Cube(101, 65535, 10054, 10054), ) expected := union( - cube(80, 100, 10053, 10053), - cube(1, 65535, 10054, 10054), + hypercube.Cube(80, 100, 10053, 10053), + hypercube.Cube(1, 65535, 10054, 10054), ) require.True(t, expected.Equal(a)) } diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index b7d15dd..6eacae9 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -181,6 +181,9 @@ func (c *CanonicalSet) Overlap(other *CanonicalSet) bool { // Subtract returns the subtraction result of input CanonicalSet func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { + if c == other { + return NewCanonicalSet() + } res := c.Copy() for _, i := range other.intervalSet { res.AddHole(i) @@ -195,11 +198,16 @@ func (c *CanonicalSet) IsSingleNumber() bool { return false } +// Elements returns a slice with all the numbers contained in the set. +// USE WITH CARE. It can easily run out of memory for large sets. func (c *CanonicalSet) Elements() []int64 { - res := []int64{} + // allocate memory up front, to fail early + res := make([]int64, c.CalculateSize()) + i := 0 for _, interval := range c.intervalSet { - for i := interval.Start; i <= interval.End; i++ { - res = append(res, i) + for v := interval.Start; v <= interval.End; v++ { + res[i] = v + i++ } } return res diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index 65e6fbf..93f624d 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -168,7 +168,7 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { } if len(res) == 0 { - res = append(res, GetCidrAll()) + res = []*IPBlock{GetCidrAll()} } return res } From c5eb591cb633d6b49afef6424a134c799dd8aa45 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 12:10:02 +0200 Subject: [PATCH 13/15] change Dimension to not encode the order of the dimensions Signed-off-by: Elazar Gershuni --- pkg/connection/connectionset.go | 54 ++++++++++++++++----------------- pkg/connection/statefulness.go | 17 +++-------- pkg/hypercube/hypercubeset.go | 19 ++++++++++++ 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 8a8f8aa..1cbdd1d 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -4,6 +4,7 @@ package connection import ( "log" + "slices" "sort" "strings" @@ -32,15 +33,14 @@ const ( NoConnections = "No Connections" ) -type Dimension int +type Dimension string const ( - protocol Dimension = 0 - srcPort Dimension = 1 - dstPort Dimension = 2 - icmpType Dimension = 3 - icmpCode Dimension = 4 - numDimensions = 5 + protocol Dimension = "protocol" + srcPort Dimension = "srcPort" + dstPort Dimension = "dstPort" + icmpType Dimension = "icmpType" + icmpCode Dimension = "icmpCode" ) const propertySeparator string = " " @@ -71,25 +71,21 @@ func entireDimension(dim Dimension) *interval.CanonicalSet { return nil } -func getDimensionDomainsList() []*interval.CanonicalSet { - res := make([]*interval.CanonicalSet, len(dimensionsList)) - for i := range dimensionsList { - res[i] = entireDimension(dimensionsList[i]) - } - return res -} - type Set struct { connectionProperties *hypercube.CanonicalSet IsStateful StatefulState } func None() *Set { - return &Set{connectionProperties: hypercube.NewCanonicalSet(numDimensions)} + return &Set{connectionProperties: hypercube.NewCanonicalSet(len(dimensionsList))} } func All() *Set { - return &Set{connectionProperties: hypercube.FromCube(getDimensionDomainsList())} + all := make([]*interval.CanonicalSet, len(dimensionsList)) + for i := range dimensionsList { + all[i] = entireDimension(dimensionsList[i]) + } + return &Set{connectionProperties: hypercube.FromCube(all)} } var all = All() @@ -149,7 +145,7 @@ func (c *Set) Subtract(other *Set) *Set { func (c *Set) ContainedIn(other *Set) bool { res, err := c.connectionProperties.ContainedIn(other.connectionProperties) if err != nil { - log.Fatalf("invalid connection set. %e", err) + log.Panicf("invalid connection set. %e", err) } return res } @@ -163,7 +159,7 @@ func protocolStringToCode(protocol netp.ProtocolString) int64 { case netp.ProtocolStringICMP: return ICMPCode } - log.Fatalf("Impossible protocol code %v", protocol) + log.Panicf("Impossible protocol code %v", protocol) return 0 } @@ -198,12 +194,12 @@ func protocolStringFromCode(protocolCode int64) netp.ProtocolString { case ICMPCode: return netp.ProtocolStringICMP } - log.Fatalf("impossible protocol code %v", protocolCode) + log.Panicf("impossible protocol code %v", protocolCode) return "" } func getDimensionString(cube []*interval.CanonicalSet, dim Dimension) string { - dimValue := cube[dim] + dimValue := cubeAt(cube, dim) if dimValue.Equal(entireDimension(dim)) { // avoid adding dimension str on full dimension values return "" @@ -242,7 +238,7 @@ func joinNonEmpty(inputList ...string) string { } func getConnsCubeStr(cube []*interval.CanonicalSet) string { - protocols := cube[protocol] + protocols := cubeAt(cube, protocol) tcpOrUDP := protocols.Contains(TCPCode) || protocols.Contains(UDPCode) icmp := protocols.Contains(ICMPCode) switch { @@ -281,11 +277,15 @@ func (c *Set) String() string { return strings.Join(resStrings, "; ") } +func cubeAt(cube []*interval.CanonicalSet, dim Dimension) *interval.CanonicalSet { + return cube[slices.Index(dimensionsList, dim)] +} + func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtocol) []spec.TcpUdp { tcpItemsTemp := []spec.TcpUdp{} tcpItemsFinal := []spec.TcpUdp{} // consider src ports - srcPorts := cube[srcPort] + srcPorts := cubeAt(cube, srcPort) if !srcPorts.Equal(entireDimension(srcPort)) { // iterate the interval in the interval-set for _, interval := range srcPorts.Intervals() { @@ -296,7 +296,7 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc tcpItemsTemp = append(tcpItemsTemp, spec.TcpUdp{Protocol: protocol}) } // consider dst ports - dstPorts := cube[dstPort] + dstPorts := cubeAt(cube, dstPort) if !dstPorts.Equal(entireDimension(dstPort)) { // iterate the interval in the interval-set for _, interval := range dstPorts.Intervals() { @@ -318,8 +318,8 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc } func getCubeAsICMPItems(cube []*interval.CanonicalSet) []spec.Icmp { - icmpTypes := cube[icmpType] - icmpCodes := cube[icmpCode] + icmpTypes := cubeAt(cube, icmpType) + icmpCodes := cubeAt(cube, icmpCode) allTypes := icmpTypes.Equal(entireDimension(icmpType)) allCodes := icmpCodes.Equal(entireDimension(icmpCode)) switch { @@ -367,7 +367,7 @@ func ToJSON(c *Set) Details { cubes := c.connectionProperties.GetCubesList() for _, cube := range cubes { - protocols := cube[protocol] + protocols := cubeAt(cube, protocol) if protocols.Contains(TCPCode) { tcpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolTCP) for _, item := range tcpItems { diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index ceb1095..8226be5 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -5,7 +5,6 @@ package connection import ( "slices" - "github.com/np-guard/models/pkg/hypercube" "github.com/np-guard/models/pkg/netp" ) @@ -69,19 +68,11 @@ func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { // It assumes the input connection object is only within TCP protocol. // For TCP the src and dst ports on relevant cubes are being switched. func (c *Set) switchSrcDstPortsOnTCP() *Set { - if c.IsAll() || c.IsEmpty() { + if c.IsAll() { return c.Copy() } - res := None() - for _, cube := range c.connectionProperties.GetCubesList() { - // assuming cube[protocol] contains TCP only - // no need to switch if src equals dst - if !cube[srcPort].Equal(cube[dstPort]) { - // Shallow clone should be enough, since we do shallow swap: - cube = slices.Clone(cube) - cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort] - } - res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) + newConn := c.connectionProperties.SwapDimensions(slices.Index(dimensionsList, srcPort), slices.Index(dimensionsList, dstPort)) + return &Set{ + connectionProperties: newConn, } - return res } diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index 635a537..e8e022f 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -4,6 +4,7 @@ package hypercube import ( "errors" + "slices" "sort" "strings" @@ -250,6 +251,24 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { return res } +// SwapDimensions returns a new CanonicalSet object, built from the input CanonicalSet object, +// with dimensions dim1 and dim2 swapped +func (c *CanonicalSet) SwapDimensions(dim1, dim2 int) *CanonicalSet { + if c.IsEmpty() { + return c.Copy() + } + res := NewCanonicalSet(c.dimensions) + for _, cube := range c.GetCubesList() { + if !cube[dim1].Equal(cube[dim2]) { + // Shallow clone should be enough, since we do shallow swap: + cube = slices.Clone(cube) + cube[dim1], cube[dim2] = cube[dim2], cube[dim1] + } + res = res.Union(FromCube(cube)) + } + return res +} + func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet { type pair struct { hc *CanonicalSet // hypercube set object From 5dab35880ee71d309cfdb0af152a2d8d604a7f98 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 15:01:42 +0200 Subject: [PATCH 14/15] handle edge case Signed-off-by: Elazar Gershuni --- pkg/hypercube/hypercubeset.go | 6 ++- pkg/hypercube/hypercubeset_test.go | 85 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index e8e022f..ff2461f 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -4,6 +4,7 @@ package hypercube import ( "errors" + "log" "slices" "sort" "strings" @@ -254,9 +255,12 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { // SwapDimensions returns a new CanonicalSet object, built from the input CanonicalSet object, // with dimensions dim1 and dim2 swapped func (c *CanonicalSet) SwapDimensions(dim1, dim2 int) *CanonicalSet { - if c.IsEmpty() { + if c.IsEmpty() || dim1 == dim2 { return c.Copy() } + if min(dim1, dim2) < 0 || max(dim1, dim2) >= c.dimensions { + log.Panicf("invalid dimensions: %d, %d", dim1, dim2) + } res := NewCanonicalSet(c.dimensions) for _, cube := range c.GetCubesList() { if !cube[dim1].Equal(cube[dim2]) { diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index f9dd49a..6f8f5d1 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -369,3 +369,88 @@ func TestOr2(t *testing.T) { ) require.True(t, expected.Equal(a)) } + +// Assisted by WCA for GP +// Latest GenAI contribution: granite-20B-code-instruct-v2 model +// TestSwapDimensions tests the SwapDimensions method of the CanonicalSet type. +func TestSwapDimensions(t *testing.T) { + tests := []struct { + name string + c *hypercube.CanonicalSet + dim1 int + dim2 int + expected *hypercube.CanonicalSet + }{ + { + name: "empty set", + c: hypercube.NewCanonicalSet(2), + dim1: 0, + dim2: 1, + expected: hypercube.NewCanonicalSet(2), + }, + { + name: "0,0 of 1", + c: hypercube.Cube(1, 2), + dim1: 0, + dim2: 0, + expected: hypercube.Cube(1, 2), + }, + { + name: "0,1 of 2", + c: hypercube.Cube(1, 2, 3, 4), + dim1: 0, + dim2: 1, + expected: hypercube.Cube(3, 4, 1, 2), + }, + { + name: "0,1 of 2, no-op", + c: hypercube.Cube(1, 2, 1, 2), + dim1: 0, + dim2: 1, + expected: hypercube.Cube(1, 2, 1, 2), + }, + { + name: "0,1 of 3", + c: hypercube.Cube(1, 2, 3, 4, 5, 6), + dim1: 0, + dim2: 1, + expected: hypercube.Cube(3, 4, 1, 2, 5, 6), + }, + { + name: "1,2 of 3", + c: hypercube.Cube(1, 2, 3, 4, 5, 6), + dim1: 1, + dim2: 2, + expected: hypercube.Cube(1, 2, 5, 6, 3, 4), + }, + { + name: "0,2 of 3", + c: hypercube.Cube(1, 2, 3, 4, 5, 6), + dim1: 0, + dim2: 2, + expected: hypercube.Cube(5, 6, 3, 4, 1, 2), + }, + { + name: "0,1 of 2, non-cube", + c: union( + hypercube.Cube(1, 3, 7, 20), + hypercube.Cube(20, 23, 7, 20), + ), + dim1: 0, + dim2: 1, + expected: union( + hypercube.Cube(7, 20, 1, 3), + hypercube.Cube(7, 20, 20, 23), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := tt.c.SwapDimensions(tt.dim1, tt.dim2) + require.True(t, tt.expected != actual) + require.True(t, tt.expected.Equal(actual)) + }) + } + require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(0, 1) }) + require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(-1, 0) }) +} From 84681d59afd44aebbd88585ee8454935decb3261 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Thu, 21 Mar 2024 15:08:57 +0200 Subject: [PATCH 15/15] test SwapDimensions Signed-off-by: Elazar Gershuni --- pkg/hypercube/hypercubeset_test.go | 99 ++++++------------------------ 1 file changed, 19 insertions(+), 80 deletions(-) diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index 6f8f5d1..1412fb8 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -370,87 +370,26 @@ func TestOr2(t *testing.T) { require.True(t, expected.Equal(a)) } -// Assisted by WCA for GP -// Latest GenAI contribution: granite-20B-code-instruct-v2 model -// TestSwapDimensions tests the SwapDimensions method of the CanonicalSet type. func TestSwapDimensions(t *testing.T) { - tests := []struct { - name string - c *hypercube.CanonicalSet - dim1 int - dim2 int - expected *hypercube.CanonicalSet - }{ - { - name: "empty set", - c: hypercube.NewCanonicalSet(2), - dim1: 0, - dim2: 1, - expected: hypercube.NewCanonicalSet(2), - }, - { - name: "0,0 of 1", - c: hypercube.Cube(1, 2), - dim1: 0, - dim2: 0, - expected: hypercube.Cube(1, 2), - }, - { - name: "0,1 of 2", - c: hypercube.Cube(1, 2, 3, 4), - dim1: 0, - dim2: 1, - expected: hypercube.Cube(3, 4, 1, 2), - }, - { - name: "0,1 of 2, no-op", - c: hypercube.Cube(1, 2, 1, 2), - dim1: 0, - dim2: 1, - expected: hypercube.Cube(1, 2, 1, 2), - }, - { - name: "0,1 of 3", - c: hypercube.Cube(1, 2, 3, 4, 5, 6), - dim1: 0, - dim2: 1, - expected: hypercube.Cube(3, 4, 1, 2, 5, 6), - }, - { - name: "1,2 of 3", - c: hypercube.Cube(1, 2, 3, 4, 5, 6), - dim1: 1, - dim2: 2, - expected: hypercube.Cube(1, 2, 5, 6, 3, 4), - }, - { - name: "0,2 of 3", - c: hypercube.Cube(1, 2, 3, 4, 5, 6), - dim1: 0, - dim2: 2, - expected: hypercube.Cube(5, 6, 3, 4, 1, 2), - }, - { - name: "0,1 of 2, non-cube", - c: union( - hypercube.Cube(1, 3, 7, 20), - hypercube.Cube(20, 23, 7, 20), - ), - dim1: 0, - dim2: 1, - expected: union( - hypercube.Cube(7, 20, 1, 3), - hypercube.Cube(7, 20, 20, 23), - ), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual := tt.c.SwapDimensions(tt.dim1, tt.dim2) - require.True(t, tt.expected != actual) - require.True(t, tt.expected.Equal(actual)) - }) - } + require.True(t, hypercube.NewCanonicalSet(2).SwapDimensions(0, 1).Equal(hypercube.NewCanonicalSet(2))) + + require.True(t, hypercube.Cube(1, 2).SwapDimensions(0, 0).Equal(hypercube.Cube(1, 2))) + + require.True(t, hypercube.Cube(1, 2, 3, 4).SwapDimensions(0, 1).Equal(hypercube.Cube(3, 4, 1, 2))) + require.True(t, hypercube.Cube(1, 2, 1, 2).SwapDimensions(0, 1).Equal(hypercube.Cube(1, 2, 1, 2))) + + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(0, 1).Equal(hypercube.Cube(3, 4, 1, 2, 5, 6))) + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(1, 2).Equal(hypercube.Cube(1, 2, 5, 6, 3, 4))) + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(0, 2).Equal(hypercube.Cube(5, 6, 3, 4, 1, 2))) + + require.True(t, union( + hypercube.Cube(1, 3, 7, 20), + hypercube.Cube(20, 23, 7, 20), + ).SwapDimensions(0, 1).Equal(union( + hypercube.Cube(7, 20, 1, 3), + hypercube.Cube(7, 20, 20, 23), + ))) + require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(0, 1) }) require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(-1, 0) }) }