diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go new file mode 100644 index 0000000..1cbdd1d --- /dev/null +++ b/pkg/connection/connectionset.go @@ -0,0 +1,392 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection + +import ( + "log" + "slices" + "sort" + "strings" + + "github.com/np-guard/models/pkg/hypercube" + "github.com/np-guard/models/pkg/interval" + "github.com/np-guard/models/pkg/netp" + "github.com/np-guard/models/pkg/spec" +) + +const ( + TCPCode = 0 + UDPCode = 1 + ICMPCode = 2 + MinICMPType int64 = 0 + MaxICMPType int64 = netp.InformationReply + MinICMPCode int64 = 0 + MaxICMPCode int64 = 5 + minProtocol int64 = 0 + maxProtocol int64 = 2 + MinPort = 1 + MaxPort = netp.MaxPort +) + +const ( + AllConnections = "All Connections" + NoConnections = "No Connections" +) + +type Dimension string + +const ( + protocol Dimension = "protocol" + srcPort Dimension = "srcPort" + dstPort Dimension = "dstPort" + icmpType Dimension = "icmpType" + icmpCode Dimension = "icmpCode" +) + +const propertySeparator string = " " + +// dimensionsList is the ordered list of dimensions in the Set object +// this should be the only place where the order is hard-coded +var dimensionsList = []Dimension{ + protocol, + srcPort, + dstPort, + icmpType, + icmpCode, +} + +func entireDimension(dim Dimension) *interval.CanonicalSet { + switch dim { + case protocol: + return interval.New(minProtocol, maxProtocol).ToSet() + case srcPort: + return interval.New(MinPort, MaxPort).ToSet() + case dstPort: + return interval.New(MinPort, MaxPort).ToSet() + case icmpType: + return interval.New(MinICMPType, MaxICMPType).ToSet() + case icmpCode: + return interval.New(MinICMPCode, MaxICMPCode).ToSet() + } + return nil +} + +type Set struct { + connectionProperties *hypercube.CanonicalSet + IsStateful StatefulState +} + +func None() *Set { + return &Set{connectionProperties: hypercube.NewCanonicalSet(len(dimensionsList))} +} + +func All() *Set { + all := make([]*interval.CanonicalSet, len(dimensionsList)) + for i := range dimensionsList { + all[i] = entireDimension(dimensionsList[i]) + } + return &Set{connectionProperties: hypercube.FromCube(all)} +} + +var all = All() + +func (c *Set) IsAll() bool { + return c.Equal(all) +} + +func (c *Set) Equal(other *Set) bool { + return c.connectionProperties.Equal(other.connectionProperties) +} + +func (c *Set) Copy() *Set { + return &Set{ + connectionProperties: c.connectionProperties.Copy(), + IsStateful: c.IsStateful, + } +} + +func (c *Set) Intersect(other *Set) *Set { + return &Set{connectionProperties: c.connectionProperties.Intersect(other.connectionProperties)} +} + +func (c *Set) IsEmpty() bool { + return c.connectionProperties.IsEmpty() +} + +func (c *Set) Union(other *Set) *Set { + if other.IsEmpty() { + return c.Copy() + } + if c.IsEmpty() { + return other.Copy() + } + return &Set{ + connectionProperties: c.connectionProperties.Union(other.connectionProperties), + } +} + +// Subtract +// ToDo: Subtract seems to ignore IsStateful (see https://github.com/np-guard/vpc-network-config-analyzer/issues/199): +// 1. is the delta connection stateful +// 2. connectionProperties is identical but c stateful while other is not +// the 2nd item can be computed here, with enhancement to relevant structure +// the 1st can not since we do not know where exactly the statefulness came from +func (c *Set) Subtract(other *Set) *Set { + if c.IsEmpty() { + return None() + } + if other.IsEmpty() { + return c.Copy() + } + return &Set{connectionProperties: c.connectionProperties.Subtract(other.connectionProperties)} +} + +// ContainedIn returns true if c is subset of other +func (c *Set) ContainedIn(other *Set) bool { + res, err := c.connectionProperties.ContainedIn(other.connectionProperties) + if err != nil { + log.Panicf("invalid connection set. %e", err) + } + return res +} + +func protocolStringToCode(protocol netp.ProtocolString) int64 { + switch protocol { + case netp.ProtocolStringTCP: + return TCPCode + case netp.ProtocolStringUDP: + return UDPCode + case netp.ProtocolStringICMP: + return ICMPCode + } + log.Panicf("Impossible protocol code %v", protocol) + return 0 +} + +func cube(protocolString netp.ProtocolString, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { + protocol := protocolStringToCode(protocolString) + return &Set{ + connectionProperties: hypercube.Cube(protocol, protocol, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax)} +} + +func TCPorUDPConnection(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *Set { + return cube(protocol, + srcMinP, srcMaxP, dstMinP, dstMaxP, + MinICMPType, MaxICMPType, MinICMPCode, MaxICMPCode) +} + +func ICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) *Set { + return cube(netp.ProtocolStringICMP, + MinPort, MaxPort, MinPort, MaxPort, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) +} + +func protocolStringFromCode(protocolCode int64) netp.ProtocolString { + switch protocolCode { + case TCPCode: + return netp.ProtocolStringTCP + case UDPCode: + return netp.ProtocolStringUDP + case ICMPCode: + return netp.ProtocolStringICMP + } + log.Panicf("impossible protocol code %v", protocolCode) + return "" +} + +func getDimensionString(cube []*interval.CanonicalSet, dim Dimension) string { + dimValue := cubeAt(cube, dim) + if dimValue.Equal(entireDimension(dim)) { + // avoid adding dimension str on full dimension values + return "" + } + switch dim { + case protocol: + pList := []string{} + for code := minProtocol; code <= maxProtocol; code++ { + if dimValue.Contains(code) { + pList = append(pList, string(protocolStringFromCode(code))) + } + } + // sort by string values to avoid dependence on internal encoding + sort.Strings(pList) + return "protocol: " + strings.Join(pList, ",") + case srcPort: + return "src-ports: " + dimValue.String() + case dstPort: + return "dst-ports: " + dimValue.String() + case icmpType: + return "icmp-type: " + dimValue.String() + case icmpCode: + return "icmp-code: " + dimValue.String() + } + return "" +} + +func joinNonEmpty(inputList ...string) string { + res := []string{} + for _, propertyStr := range inputList { + if propertyStr != "" { + res = append(res, propertyStr) + } + } + return strings.Join(res, propertySeparator) +} + +func getConnsCubeStr(cube []*interval.CanonicalSet) string { + protocols := cubeAt(cube, protocol) + tcpOrUDP := protocols.Contains(TCPCode) || protocols.Contains(UDPCode) + icmp := protocols.Contains(ICMPCode) + switch { + case tcpOrUDP && !icmp: + return joinNonEmpty( + getDimensionString(cube, protocol), + getDimensionString(cube, srcPort), + getDimensionString(cube, dstPort), + ) + case icmp && !tcpOrUDP: + return joinNonEmpty( + getDimensionString(cube, protocol), + getDimensionString(cube, icmpType), + getDimensionString(cube, icmpCode), + ) + default: + // TODO: make sure other dimension values are full + return getDimensionString(cube, protocol) + } +} + +// String returns a string representation of a Set object +func (c *Set) String() string { + if c.IsEmpty() { + return NoConnections + } else if c.IsAll() { + return AllConnections + } + // get cubes and cube str per each cube + resStrings := []string{} + for _, cube := range c.connectionProperties.GetCubesList() { + resStrings = append(resStrings, getConnsCubeStr(cube)) + } + + sort.Strings(resStrings) + return strings.Join(resStrings, "; ") +} + +func cubeAt(cube []*interval.CanonicalSet, dim Dimension) *interval.CanonicalSet { + return cube[slices.Index(dimensionsList, dim)] +} + +func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtocol) []spec.TcpUdp { + tcpItemsTemp := []spec.TcpUdp{} + tcpItemsFinal := []spec.TcpUdp{} + // consider src ports + srcPorts := cubeAt(cube, srcPort) + if !srcPorts.Equal(entireDimension(srcPort)) { + // iterate the interval in the interval-set + for _, interval := range srcPorts.Intervals() { + tcpRes := spec.TcpUdp{Protocol: protocol, MinSourcePort: int(interval.Start), MaxSourcePort: int(interval.End)} + tcpItemsTemp = append(tcpItemsTemp, tcpRes) + } + } else { + tcpItemsTemp = append(tcpItemsTemp, spec.TcpUdp{Protocol: protocol}) + } + // consider dst ports + dstPorts := cubeAt(cube, dstPort) + if !dstPorts.Equal(entireDimension(dstPort)) { + // iterate the interval in the interval-set + for _, interval := range dstPorts.Intervals() { + for _, tcpItemTemp := range tcpItemsTemp { + tcpRes := spec.TcpUdp{ + Protocol: protocol, + MinSourcePort: tcpItemTemp.MinSourcePort, + MaxSourcePort: tcpItemTemp.MaxSourcePort, + MinDestinationPort: int(interval.Start), + MaxDestinationPort: int(interval.End), + } + tcpItemsFinal = append(tcpItemsFinal, tcpRes) + } + } + } else { + tcpItemsFinal = tcpItemsTemp + } + return tcpItemsFinal +} + +func getCubeAsICMPItems(cube []*interval.CanonicalSet) []spec.Icmp { + icmpTypes := cubeAt(cube, icmpType) + icmpCodes := cubeAt(cube, icmpCode) + allTypes := icmpTypes.Equal(entireDimension(icmpType)) + allCodes := icmpCodes.Equal(entireDimension(icmpCode)) + switch { + case allTypes && allCodes: + return []spec.Icmp{{Protocol: spec.IcmpProtocolICMP}} + case allTypes: + // This does not really make sense: not all types can have all codes + res := []spec.Icmp{} + for _, code64 := range icmpCodes.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Code: &code}) + } + return res + case allCodes: + res := []spec.Icmp{} + for _, type64 := range icmpTypes.Elements() { + t := int(type64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t}) + } + return res + default: + res := []spec.Icmp{} + // iterate both codes and types + for _, type64 := range icmpTypes.Elements() { + t := int(type64) + for _, code64 := range icmpCodes.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t, Code: &code}) + } + } + return res + } +} + +type Details spec.ProtocolList + +func ToJSON(c *Set) Details { + if c == nil { + return nil + } + if c.IsAll() { + return Details(spec.ProtocolList{spec.AnyProtocol{Protocol: spec.AnyProtocolProtocolANY}}) + } + res := spec.ProtocolList{} + + cubes := c.connectionProperties.GetCubesList() + for _, cube := range cubes { + protocols := cubeAt(cube, protocol) + if protocols.Contains(TCPCode) { + tcpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolTCP) + for _, item := range tcpItems { + res = append(res, item) + } + } + if protocols.Contains(UDPCode) { + udpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolUDP) + for _, item := range udpItems { + res = append(res, item) + } + } + if protocols.Contains(ICMPCode) { + icmpItems := getCubeAsICMPItems(cube) + for _, item := range icmpItems { + res = append(res, item) + } + } + } + + return Details(res) +} diff --git a/pkg/connection/connectionset_test.go b/pkg/connection/connectionset_test.go new file mode 100644 index 0000000..a5cdb22 --- /dev/null +++ b/pkg/connection/connectionset_test.go @@ -0,0 +1,58 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/np-guard/models/pkg/connection" + "github.com/np-guard/models/pkg/netp" +) + +const ICMPValue = netp.DestinationUnreachable + +func TestAllConnections(t *testing.T) { + c := connection.All() + require.Equal(t, "All Connections", c.String()) +} + +func TestNoConnections(t *testing.T) { + c := connection.None() + require.Equal(t, "No Connections", c.String()) +} + +func TestBasicSetICMP(t *testing.T) { + c := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + require.Equal(t, "protocol: ICMP icmp-type: 3 icmp-code: 5", c.String()) +} + +func TestBasicSetTCP(t *testing.T) { + e := connection.TCPorUDPConnection(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + require.Equal(t, "protocol: TCP", e.String()) + + c := connection.All().Subtract(e) + require.Equal(t, "protocol: ICMP,UDP", c.String()) + + c = c.Union(e) + require.Equal(t, "All Connections", c.String()) +} + +func TestBasicSet2(t *testing.T) { + except1 := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + + except2 := connection.TCPorUDPConnection(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + + d := connection.All().Subtract(except1).Subtract(except2) + require.Equal(t, ""+ + "protocol: ICMP icmp-type: 0-2,4-16; "+ + "protocol: ICMP icmp-type: 3 icmp-code: 0-4; "+ + "protocol: UDP", d.String()) +} + +func TestBasicSet3(t *testing.T) { + c := connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5) + d := connection.All().Subtract(c).Union(connection.ICMPConnection(ICMPValue, ICMPValue, 5, 5)) + require.Equal(t, "All Connections", d.String()) +} diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go new file mode 100644 index 0000000..8226be5 --- /dev/null +++ b/pkg/connection/statefulness.go @@ -0,0 +1,78 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection + +import ( + "slices" + + "github.com/np-guard/models/pkg/netp" +) + +// default is StatefulUnknown +type StatefulState int + +const ( + // StatefulUnknown is the default value for a Set object, + StatefulUnknown StatefulState = 0 + // StatefulTrue represents a connection object for which any allowed TCP (on all allowed src/dst ports) + // has an allowed response connection + StatefulTrue StatefulState = 1 + // StatefulFalse represents a connection object for which there exists some allowed TCP + // (on any allowed subset from the allowed src/dst ports) that does not have an allowed response connection + StatefulFalse StatefulState = 2 +) + +// EnhancedString returns a connection string with possibly added asterisk for stateless connection +func (c *Set) EnhancedString() string { + if c.IsStateful == StatefulFalse { + return c.String() + " *" + } + return c.String() +} + +func newTCPSet() *Set { + return TCPorUDPConnection(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) +} + +// WithStatefulness updates `c` object with `IsStateful` property, based on input `secondDirectionConn`. +// `c` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `c` is set as `StatefulFalse` if there is at least some subset within TCP from `c` +// which is not stateful (such that the response direction for this subset is not enabled). +// This function also returns a connection object with the exact subset of the stateful part (within TCP) +// from the entire connection `c`, and with the original connections on other protocols. +func (c *Set) WithStatefulness(secondDirectionConn *Set) *Set { + connTCP := c.Intersect(newTCPSet()) + if connTCP.IsEmpty() { + c.IsStateful = StatefulTrue + return c + } + statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConn.Intersect(newTCPSet())) + c.IsStateful = connTCP.IsStateful + return c.Subtract(connTCP).Union(statefulCombinedConnTCP) +} + +// connTCPWithStatefulness assumes that both `c` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `c`, and returns the subset from `c` which is stateful. +func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { + // flip src/dst ports before intersection + statefulCombinedConn := c.Intersect(secondDirectionConn.switchSrcDstPortsOnTCP()) + if c.Equal(statefulCombinedConn) { + c.IsStateful = StatefulTrue + } else { + c.IsStateful = StatefulFalse + } + return statefulCombinedConn +} + +// switchSrcDstPortsOnTCP returns a new Set object, built from the input Set object. +// It assumes the input connection object is only within TCP protocol. +// For TCP the src and dst ports on relevant cubes are being switched. +func (c *Set) switchSrcDstPortsOnTCP() *Set { + if c.IsAll() { + return c.Copy() + } + newConn := c.connectionProperties.SwapDimensions(slices.Index(dimensionsList, srcPort), slices.Index(dimensionsList, dstPort)) + return &Set{ + connectionProperties: newConn, + } +} diff --git a/pkg/connection/statefulness_test.go b/pkg/connection/statefulness_test.go new file mode 100644 index 0000000..27f627f --- /dev/null +++ b/pkg/connection/statefulness_test.go @@ -0,0 +1,148 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package connection_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/np-guard/models/pkg/connection" + "github.com/np-guard/models/pkg/netp" +) + +func newTCPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(netp.ProtocolStringTCP, srcMinP, srcMaxP, dstMinP, dstMaxP) +} + +func newUDPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(netp.ProtocolStringUDP, srcMinP, srcMaxP, dstMinP, dstMaxP) +} + +func newICMPconn(t *testing.T) *connection.Set { + t.Helper() + return connection.ICMPConnection( + connection.MinICMPType, connection.MaxICMPType, + connection.MinICMPCode, connection.MaxICMPCode) +} + +func newTCPUDPSet(t *testing.T, p netp.ProtocolString) *connection.Set { + t.Helper() + return connection.TCPorUDPConnection(p, + connection.MinPort, connection.MaxPort, + connection.MinPort, connection.MaxPort) +} + +type statefulnessTest struct { + name string + srcToDst *connection.Set + dstToSrc *connection.Set + // expectedIsStateful represents the expected IsStateful computed value for srcToDst, + // which should be either StatefulTrue or StatefulFalse, given the input dstToSrc connection. + // the computation applies only to the TCP protocol within those connections. + expectedIsStateful connection.StatefulState + // expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection, + // the stateless part for TCP is srcToDst.Subtract(statefulConn) + expectedStatefulConn *connection.Set +} + +func (tt statefulnessTest) runTest(t *testing.T) { + t.Helper() + statefulConn := tt.srcToDst.WithStatefulness(tt.dstToSrc) + require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) + require.True(t, tt.expectedStatefulConn.Equal(statefulConn)) +} + +func TestAll(t *testing.T) { + var testCasesStatefulness = []statefulnessTest{ + { + name: "tcp_all_ports_on_both_directions", + srcToDst: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + dstToSrc: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPUDPSet(t, netp.ProtocolStringTCP), // TCP all ports + }, + { + name: "first_all_cons_second_tcp_with_ports", + srcToDst: connection.All(), // all connections + dstToSrc: newTCPConn(t, 80, 80, connection.MinPort, connection.MaxPort), // TCP , src-ports: 80, dst-ports: all + + // there is a subset of the tcp connection which is not stateful + expectedIsStateful: connection.StatefulFalse, + + // TCP src-ports: all, dst-port: 80 , union: all non-TCP conns + expectedStatefulConn: connection.All().Subtract(newTCPUDPSet(t, netp.ProtocolStringTCP)).Union( + newTCPConn(t, connection.MinPort, connection.MaxPort, 80, 80)), + }, + { + name: "first_all_conns_second_no_tcp", + srcToDst: connection.All(), // all connections + dstToSrc: newICMPconn(t), // ICMP + expectedIsStateful: connection.StatefulFalse, + // UDP, ICMP (all TCP is considered stateless here) + expectedStatefulConn: connection.All().Subtract(newTCPUDPSet(t, netp.ProtocolStringTCP)), + }, + { + name: "tcp_with_ports_both_directions_exact_match", + srcToDst: newTCPConn(t, 80, 80, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_partial_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: connection.None(), + }, + { + name: "udp_and_tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443).Union(newUDPConn(t, 80, 100, 443, 443)), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: connection.StatefulFalse, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "no_tcp_in_first_direction", + srcToDst: newUDPConn(t, 70, 100, 443, 443), + dstToSrc: newTCPConn(t, 70, 80, 80, 80).Union(newUDPConn(t, 70, 80, 80, 80)), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newUDPConn(t, 70, 100, 443, 443), + }, + { + name: "empty_conn_in_first_direction", + srcToDst: connection.None(), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newTCPUDPSet(t, netp.ProtocolStringUDP)), + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: connection.None(), + }, + { + name: "only_udp_icmp_in_first_direction_and_empty_second_direction", + srcToDst: newTCPUDPSet(t, netp.ProtocolStringUDP).Union(newICMPconn(t)), + dstToSrc: connection.None(), + // stateful analysis does not apply to udp/icmp, thus considered in the result as "stateful" + // (to avoid marking it as stateless in the output) + expectedIsStateful: connection.StatefulTrue, + expectedStatefulConn: newTCPUDPSet(t, netp.ProtocolStringUDP).Union(newICMPconn(t)), + }, + } + t.Parallel() + // explainTests is the list of tests to run + for testIdx := range testCasesStatefulness { + tt := testCasesStatefulness[testIdx] + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.runTest(t) + }) + } +} diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index 0a55181..ff2461f 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -1,7 +1,11 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package hypercube import ( "errors" + "log" + "slices" "sort" "strings" @@ -163,7 +167,7 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { } } -// ContainedIn returns true ic other contained in c +// ContainedIn returns true if c is subset of other func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { if c == other { return true, nil @@ -179,8 +183,7 @@ func (c *CanonicalSet) ContainedIn(other *CanonicalSet) (bool, error) { } isSubsetCount := 0 - for k, v := range c.layers { - currentLayer := k.Copy() + for currentLayer, v := range c.layers { for otherKey, otherVal := range other.layers { commonKey := currentLayer.Intersect(otherKey) remaining := currentLayer.Subtract(commonKey) @@ -249,6 +252,27 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { return res } +// SwapDimensions returns a new CanonicalSet object, built from the input CanonicalSet object, +// with dimensions dim1 and dim2 swapped +func (c *CanonicalSet) SwapDimensions(dim1, dim2 int) *CanonicalSet { + if c.IsEmpty() || dim1 == dim2 { + return c.Copy() + } + if min(dim1, dim2) < 0 || max(dim1, dim2) >= c.dimensions { + log.Panicf("invalid dimensions: %d, %d", dim1, dim2) + } + res := NewCanonicalSet(c.dimensions) + for _, cube := range c.GetCubesList() { + if !cube[dim1].Equal(cube[dim2]) { + // Shallow clone should be enough, since we do shallow swap: + cube = slices.Clone(cube) + cube[dim1], cube[dim2] = cube[dim2], cube[dim1] + } + res = res.Union(FromCube(cube)) + } + return res +} + func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet { type pair struct { hc *CanonicalSet // hypercube set object @@ -289,3 +313,14 @@ func FromCube(cube []*interval.CanonicalSet) *CanonicalSet { res.layers[cube[0].Copy()] = FromCube(cube[1:]) return res } + +// Cube returns a new CanonicalSet created from a single input cube +// the input cube is given as an ordered list of integer values, where each two values +// represent the range (start,end) for a dimension value +func Cube(values ...int64) *CanonicalSet { + cube := []*interval.CanonicalSet{} + for i := 0; i < len(values); i += 2 { + cube = append(cube, interval.New(values[i], values[i+1]).ToSet()) + } + return FromCube(cube) +} diff --git a/pkg/hypercube/hypercubeset_test.go b/pkg/hypercube/hypercubeset_test.go index dc178db..1412fb8 100644 --- a/pkg/hypercube/hypercubeset_test.go +++ b/pkg/hypercube/hypercubeset_test.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package hypercube_test import ( @@ -6,20 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/np-guard/models/pkg/hypercube" - "github.com/np-guard/models/pkg/interval" ) -// cube returns a new hypercube.CanonicalSet created from a single input cube -// the input cube is given as an ordered list of integer values, where each two values -// represent the range (start,end) for a dimension value -func cube(values ...int64) *hypercube.CanonicalSet { - cube := []*interval.CanonicalSet{} - for i := 0; i < len(values); i += 2 { - cube = append(cube, interval.New(values[i], values[i+1]).ToSet()) - } - return hypercube.FromCube(cube) -} - func union(set *hypercube.CanonicalSet, sets ...*hypercube.CanonicalSet) *hypercube.CanonicalSet { for _, c := range sets { set = set.Union(c) @@ -28,12 +18,12 @@ func union(set *hypercube.CanonicalSet, sets ...*hypercube.CanonicalSet) *hyperc } func TestHCBasic(t *testing.T) { - a := cube(1, 100) - b := cube(1, 100) - c := cube(1, 200) - d := cube(1, 100, 1, 100) - e := cube(1, 100, 1, 100) - f := cube(1, 100, 1, 200) + a := hypercube.Cube(1, 100) + b := hypercube.Cube(1, 100) + c := hypercube.Cube(1, 200) + d := hypercube.Cube(1, 100, 1, 100) + e := hypercube.Cube(1, 100, 1, 100) + f := hypercube.Cube(1, 100, 1, 200) require.True(t, a.Equal(b)) require.True(t, b.Equal(a)) @@ -52,7 +42,7 @@ func TestHCBasic(t *testing.T) { } func TestCopy(t *testing.T) { - a := cube(1, 100) + a := hypercube.Cube(1, 100) b := a.Copy() require.True(t, a.Equal(b)) require.True(t, b.Equal(a)) @@ -60,42 +50,42 @@ func TestCopy(t *testing.T) { } func TestString(t *testing.T) { - require.Equal(t, "[(1-3)]", cube(1, 3).String()) - require.Equal(t, "[(1-3),(2-4)]", cube(1, 3, 2, 4).String()) + require.Equal(t, "[(1-3)]", hypercube.Cube(1, 3).String()) + require.Equal(t, "[(1-3),(2-4)]", hypercube.Cube(1, 3, 2, 4).String()) } func TestOr(t *testing.T) { - a := cube(1, 100, 1, 100) - b := cube(1, 90, 1, 200) + a := hypercube.Cube(1, 100, 1, 100) + b := hypercube.Cube(1, 90, 1, 200) c := a.Union(b) require.Equal(t, "[(1-90),(1-200)]; [(91-100),(1-100)]", c.String()) } func TestBasic1(t *testing.T) { a := union( - cube(1, 2), - cube(5, 6), - cube(3, 4), + hypercube.Cube(1, 2), + hypercube.Cube(5, 6), + hypercube.Cube(3, 4), ) - b := cube(1, 6) + b := hypercube.Cube(1, 6) require.True(t, a.Equal(b)) } func TestBasic2(t *testing.T) { a := union( - cube(1, 2, 1, 5), - cube(1, 2, 7, 9), - cube(1, 2, 6, 7), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(1, 2, 7, 9), + hypercube.Cube(1, 2, 6, 7), ) - b := cube(1, 2, 1, 9) + b := hypercube.Cube(1, 2, 1, 9) require.True(t, a.Equal(b)) } func TestNew(t *testing.T) { a := union( - cube(10, 20, 10, 20, 1, 65535), - cube(1, 65535, 15, 40, 1, 65535), - cube(1, 65535, 100, 200, 30, 80), + hypercube.Cube(10, 20, 10, 20, 1, 65535), + hypercube.Cube(1, 65535, 15, 40, 1, 65535), + hypercube.Cube(1, 65535, 100, 200, 30, 80), ) expectedStr := "[(1-9,21-65535),(100-200),(30-80)]; " + "[(1-9,21-65535),(15-40),(1-65535)]; " + @@ -112,39 +102,39 @@ func checkContained(t *testing.T, a, b *hypercube.CanonicalSet, expected bool) { } func TestContainedIn(t *testing.T) { - a := cube(1, 100, 200, 300) - b := cube(10, 80, 210, 280) + a := hypercube.Cube(1, 100, 200, 300) + b := hypercube.Cube(10, 80, 210, 280) checkContained(t, b, a, true) - b = b.Union(cube(10, 200, 210, 280)) + b = b.Union(hypercube.Cube(10, 200, 210, 280)) checkContained(t, b, a, false) } func TestContainedIn1(t *testing.T) { - checkContained(t, cube(1, 3), cube(2, 4), false) - checkContained(t, cube(2, 4), cube(1, 3), false) - checkContained(t, cube(1, 3), cube(1, 4), true) - checkContained(t, cube(1, 4), cube(1, 3), false) + checkContained(t, hypercube.Cube(1, 3), hypercube.Cube(2, 4), false) + checkContained(t, hypercube.Cube(2, 4), hypercube.Cube(1, 3), false) + checkContained(t, hypercube.Cube(1, 3), hypercube.Cube(1, 4), true) + checkContained(t, hypercube.Cube(1, 4), hypercube.Cube(1, 3), false) } func TestContainedIn2(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(150, 180, 20, 300), - cube(200, 240, 200, 300), - cube(241, 300, 200, 350), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(150, 180, 20, 300), + hypercube.Cube(200, 240, 200, 300), + hypercube.Cube(241, 300, 200, 350), ) a := union( - cube(1, 100, 200, 300), - cube(150, 180, 20, 300), - cube(200, 240, 200, 300), - cube(242, 300, 200, 350), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(150, 180, 20, 300), + hypercube.Cube(200, 240, 200, 300), + hypercube.Cube(242, 300, 200, 350), ) - d := cube(210, 220, 210, 280) - e := cube(210, 310, 210, 280) - f := cube(210, 250, 210, 280) - f1 := cube(210, 240, 210, 280) - f2 := cube(241, 250, 210, 280) + d := hypercube.Cube(210, 220, 210, 280) + e := hypercube.Cube(210, 310, 210, 280) + f := hypercube.Cube(210, 250, 210, 280) + f1 := hypercube.Cube(210, 240, 210, 280) + f2 := hypercube.Cube(241, 250, 210, 280) checkContained(t, d, c, true) checkContained(t, e, c, false) @@ -155,213 +145,213 @@ func TestContainedIn2(t *testing.T) { } func TestContainedIn3(t *testing.T) { - a := cube(105, 105, 54, 54) + a := hypercube.Cube(105, 105, 54, 54) b := union( - cube(0, 204, 0, 255), - cube(205, 205, 0, 53), - cube(205, 205, 55, 255), - cube(206, 254, 0, 255), + hypercube.Cube(0, 204, 0, 255), + hypercube.Cube(205, 205, 0, 53), + hypercube.Cube(205, 205, 55, 255), + hypercube.Cube(206, 254, 0, 255), ) checkContained(t, a, b, true) } func TestContainedIn4(t *testing.T) { - a := cube(105, 105, 54, 54) - b := cube(200, 204, 0, 255) + a := hypercube.Cube(105, 105, 54, 54) + b := hypercube.Cube(200, 204, 0, 255) checkContained(t, a, b, false) } func TestContainedIn5(t *testing.T) { - a := cube(100, 200, 54, 65, 60, 300) - b := cube(110, 120, 0, 10, 0, 255) + a := hypercube.Cube(100, 200, 54, 65, 60, 300) + b := hypercube.Cube(110, 120, 0, 10, 0, 255) checkContained(t, b, a, false) } func TestEqual1(t *testing.T) { - a := cube(1, 2) - b := cube(1, 2) + a := hypercube.Cube(1, 2) + b := hypercube.Cube(1, 2) require.True(t, a.Equal(b)) - c := cube(1, 2, 1, 5) - d := cube(1, 2, 1, 5) + c := hypercube.Cube(1, 2, 1, 5) + d := hypercube.Cube(1, 2, 1, 5) require.True(t, c.Equal(d)) } func TestEqual2(t *testing.T) { c := union( - cube(1, 2, 1, 5), - cube(1, 2, 7, 9), - cube(1, 2, 6, 7), - cube(4, 8, 1, 9), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(1, 2, 7, 9), + hypercube.Cube(1, 2, 6, 7), + hypercube.Cube(4, 8, 1, 9), ) res := union( - cube(4, 8, 1, 9), - cube(1, 2, 1, 9), + hypercube.Cube(4, 8, 1, 9), + hypercube.Cube(1, 2, 1, 9), ) require.True(t, res.Equal(c)) d := union( - cube(1, 2, 1, 5), - cube(5, 6, 1, 5), - cube(3, 4, 1, 5), + hypercube.Cube(1, 2, 1, 5), + hypercube.Cube(5, 6, 1, 5), + hypercube.Cube(3, 4, 1, 5), ) - res2 := cube(1, 6, 1, 5) + res2 := hypercube.Cube(1, 6, 1, 5) require.True(t, res2.Equal(d)) } func TestBasicAddCube(t *testing.T) { a := union( - cube(1, 2), - cube(8, 10), + hypercube.Cube(1, 2), + hypercube.Cube(8, 10), ) b := union( a, - cube(1, 2), - cube(6, 10), - cube(1, 10), + hypercube.Cube(1, 2), + hypercube.Cube(6, 10), + hypercube.Cube(1, 10), ) - res := cube(1, 10) + res := hypercube.Cube(1, 10) require.False(t, res.Equal(a)) require.True(t, res.Equal(b)) } func TestFourHoles(t *testing.T) { - a := cube(1, 2, 1, 2) - require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(cube(1, 1, 1, 1)).String()) - require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(cube(1, 1, 2, 2)).String()) - require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(cube(2, 2, 1, 1)).String()) - require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(cube(2, 2, 2, 2)).String()) + a := hypercube.Cube(1, 2, 1, 2) + require.Equal(t, "[(1),(2)]; [(2),(1-2)]", a.Subtract(hypercube.Cube(1, 1, 1, 1)).String()) + require.Equal(t, "[(1),(1)]; [(2),(1-2)]", a.Subtract(hypercube.Cube(1, 1, 2, 2)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(2)]", a.Subtract(hypercube.Cube(2, 2, 1, 1)).String()) + require.Equal(t, "[(1),(1-2)]; [(2),(1)]", a.Subtract(hypercube.Cube(2, 2, 2, 2)).String()) } func TestBasicSubtract1(t *testing.T) { - a := cube(1, 10) - require.True(t, a.Subtract(cube(3, 7)).Equal(union(cube(1, 2), cube(8, 10)))) - require.True(t, a.Subtract(cube(3, 20)).Equal(cube(1, 2))) - require.True(t, a.Subtract(cube(0, 20)).IsEmpty()) - require.True(t, a.Subtract(cube(0, 5)).Equal(cube(6, 10))) - require.True(t, a.Subtract(cube(12, 14)).Equal(cube(1, 10))) + a := hypercube.Cube(1, 10) + require.True(t, a.Subtract(hypercube.Cube(3, 7)).Equal(union(hypercube.Cube(1, 2), hypercube.Cube(8, 10)))) + require.True(t, a.Subtract(hypercube.Cube(3, 20)).Equal(hypercube.Cube(1, 2))) + require.True(t, a.Subtract(hypercube.Cube(0, 20)).IsEmpty()) + require.True(t, a.Subtract(hypercube.Cube(0, 5)).Equal(hypercube.Cube(6, 10))) + require.True(t, a.Subtract(hypercube.Cube(12, 14)).Equal(hypercube.Cube(1, 10))) } func TestBasicSubtract2(t *testing.T) { - a := cube(1, 100, 200, 300).Subtract(cube(50, 60, 220, 300)) + a := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 60, 220, 300)) resA := union( - cube(61, 100, 200, 300), - cube(50, 60, 200, 219), - cube(1, 49, 200, 300), + hypercube.Cube(61, 100, 200, 300), + hypercube.Cube(50, 60, 200, 219), + hypercube.Cube(1, 49, 200, 300), ) require.True(t, a.Equal(resA)) - b := cube(1, 100, 200, 300).Subtract(cube(50, 1000, 0, 250)) + b := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 1000, 0, 250)) resB := union( - cube(50, 100, 251, 300), - cube(1, 49, 200, 300), + hypercube.Cube(50, 100, 251, 300), + hypercube.Cube(1, 49, 200, 300), ) require.True(t, b.Equal(resB)) c := union( - cube(1, 100, 200, 300), - cube(400, 700, 200, 300), - ).Subtract(cube(50, 1000, 0, 250)) + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(400, 700, 200, 300), + ).Subtract(hypercube.Cube(50, 1000, 0, 250)) resC := union( - cube(50, 100, 251, 300), - cube(1, 49, 200, 300), - cube(400, 700, 251, 300), + hypercube.Cube(50, 100, 251, 300), + hypercube.Cube(1, 49, 200, 300), + hypercube.Cube(400, 700, 251, 300), ) require.True(t, c.Equal(resC)) - d := cube(1, 100, 200, 300).Subtract(cube(50, 60, 220, 300)) + d := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(50, 60, 220, 300)) dRes := union( - cube(1, 49, 200, 300), - cube(50, 60, 200, 219), - cube(61, 100, 200, 300), + hypercube.Cube(1, 49, 200, 300), + hypercube.Cube(50, 60, 200, 219), + hypercube.Cube(61, 100, 200, 300), ) require.True(t, d.Equal(dRes)) } func TestAddHole2(t *testing.T) { c := union( - cube(80, 100, 20, 300), - cube(250, 400, 20, 300), - ).Subtract(cube(30, 300, 100, 102)) + hypercube.Cube(80, 100, 20, 300), + hypercube.Cube(250, 400, 20, 300), + ).Subtract(hypercube.Cube(30, 300, 100, 102)) d := union( - cube(80, 100, 20, 99), - cube(80, 100, 103, 300), - cube(250, 300, 20, 99), - cube(250, 300, 103, 300), - cube(301, 400, 20, 300), + hypercube.Cube(80, 100, 20, 99), + hypercube.Cube(80, 100, 103, 300), + hypercube.Cube(250, 300, 20, 99), + hypercube.Cube(250, 300, 103, 300), + hypercube.Cube(301, 400, 20, 300), ) require.True(t, c.Equal(d)) } func TestSubtractToEmpty(t *testing.T) { - c := cube(1, 100, 200, 300).Subtract(cube(1, 100, 200, 300)) + c := hypercube.Cube(1, 100, 200, 300).Subtract(hypercube.Cube(1, 100, 200, 300)) require.True(t, c.IsEmpty()) } func TestUnion1(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(101, 200, 200, 300), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(101, 200, 200, 300), ) - cExpected := cube(1, 200, 200, 300) + cExpected := hypercube.Cube(1, 200, 200, 300) require.True(t, cExpected.Equal(c)) } func TestUnion2(t *testing.T) { c := union( - cube(1, 100, 200, 300), - cube(101, 200, 200, 300), - cube(201, 300, 200, 300), - cube(301, 400, 200, 300), - cube(402, 500, 200, 300), - cube(500, 600, 200, 700), - cube(601, 700, 200, 700), + hypercube.Cube(1, 100, 200, 300), + hypercube.Cube(101, 200, 200, 300), + hypercube.Cube(201, 300, 200, 300), + hypercube.Cube(301, 400, 200, 300), + hypercube.Cube(402, 500, 200, 300), + hypercube.Cube(500, 600, 200, 700), + hypercube.Cube(601, 700, 200, 700), ) cExpected := union( - cube(1, 400, 200, 300), - cube(402, 500, 200, 300), - cube(500, 700, 200, 700), + hypercube.Cube(1, 400, 200, 300), + hypercube.Cube(402, 500, 200, 300), + hypercube.Cube(500, 700, 200, 700), ) require.True(t, c.Equal(cExpected)) - d := c.Union(cube(702, 800, 200, 700)) - dExpected := cExpected.Union(cube(702, 800, 200, 700)) + d := c.Union(hypercube.Cube(702, 800, 200, 700)) + dExpected := cExpected.Union(hypercube.Cube(702, 800, 200, 700)) require.True(t, d.Equal(dExpected)) } func TestIntersect(t *testing.T) { - c := cube(5, 15, 3, 10).Intersect(cube(8, 30, 7, 20)) - d := cube(8, 15, 7, 10) + c := hypercube.Cube(5, 15, 3, 10).Intersect(hypercube.Cube(8, 30, 7, 20)) + d := hypercube.Cube(8, 15, 7, 10) require.True(t, c.Equal(d)) } func TestUnionMerge(t *testing.T) { a := union( - cube(5, 15, 3, 6), - cube(5, 30, 7, 10), - cube(8, 30, 11, 20), + hypercube.Cube(5, 15, 3, 6), + hypercube.Cube(5, 30, 7, 10), + hypercube.Cube(8, 30, 11, 20), ) excepted := union( - cube(5, 15, 3, 10), - cube(8, 30, 7, 20), + hypercube.Cube(5, 15, 3, 10), + hypercube.Cube(8, 30, 7, 20), ) require.True(t, excepted.Equal(a)) } func TestSubtract(t *testing.T) { - g := cube(5, 15, 3, 10).Subtract(cube(8, 30, 7, 20)) + g := hypercube.Cube(5, 15, 3, 10).Subtract(hypercube.Cube(8, 30, 7, 20)) h := union( - cube(5, 7, 3, 10), - cube(8, 15, 3, 6), + hypercube.Cube(5, 7, 3, 10), + hypercube.Cube(8, 15, 3, 6), ) require.True(t, g.Equal(h)) } func TestIntersectEmpty(t *testing.T) { - a := cube(5, 15, 3, 10) + a := hypercube.Cube(5, 15, 3, 10) b := union( - cube(1, 3, 7, 20), - cube(20, 23, 7, 20), + hypercube.Cube(1, 3, 7, 20), + hypercube.Cube(20, 23, 7, 20), ) c := a.Intersect(b) require.True(t, c.IsEmpty()) @@ -369,13 +359,37 @@ func TestIntersectEmpty(t *testing.T) { func TestOr2(t *testing.T) { a := union( - cube(1, 79, 10054, 10054), - cube(80, 100, 10053, 10054), - cube(101, 65535, 10054, 10054), + hypercube.Cube(1, 79, 10054, 10054), + hypercube.Cube(80, 100, 10053, 10054), + hypercube.Cube(101, 65535, 10054, 10054), ) expected := union( - cube(80, 100, 10053, 10053), - cube(1, 65535, 10054, 10054), + hypercube.Cube(80, 100, 10053, 10053), + hypercube.Cube(1, 65535, 10054, 10054), ) require.True(t, expected.Equal(a)) } + +func TestSwapDimensions(t *testing.T) { + require.True(t, hypercube.NewCanonicalSet(2).SwapDimensions(0, 1).Equal(hypercube.NewCanonicalSet(2))) + + require.True(t, hypercube.Cube(1, 2).SwapDimensions(0, 0).Equal(hypercube.Cube(1, 2))) + + require.True(t, hypercube.Cube(1, 2, 3, 4).SwapDimensions(0, 1).Equal(hypercube.Cube(3, 4, 1, 2))) + require.True(t, hypercube.Cube(1, 2, 1, 2).SwapDimensions(0, 1).Equal(hypercube.Cube(1, 2, 1, 2))) + + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(0, 1).Equal(hypercube.Cube(3, 4, 1, 2, 5, 6))) + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(1, 2).Equal(hypercube.Cube(1, 2, 5, 6, 3, 4))) + require.True(t, hypercube.Cube(1, 2, 3, 4, 5, 6).SwapDimensions(0, 2).Equal(hypercube.Cube(5, 6, 3, 4, 1, 2))) + + require.True(t, union( + hypercube.Cube(1, 3, 7, 20), + hypercube.Cube(20, 23, 7, 20), + ).SwapDimensions(0, 1).Equal(union( + hypercube.Cube(7, 20, 1, 3), + hypercube.Cube(7, 20, 20, 23), + ))) + + require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(0, 1) }) + require.Panics(t, func() { hypercube.Cube(1, 2).SwapDimensions(-1, 0) }) +} diff --git a/pkg/interval/interval.go b/pkg/interval/interval.go index b849e58..96328c0 100644 --- a/pkg/interval/interval.go +++ b/pkg/interval/interval.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval import "fmt" @@ -22,6 +24,10 @@ func (i Interval) Equal(x Interval) bool { return i.Start == x.Start && i.End == x.End } +func (i Interval) Size() int64 { + return i.End - i.Start + 1 +} + func (i Interval) overlaps(other Interval) bool { return other.End >= i.Start && other.Start <= i.End } diff --git a/pkg/interval/intervalset.go b/pkg/interval/intervalset.go index a0634b4..6eacae9 100644 --- a/pkg/interval/intervalset.go +++ b/pkg/interval/intervalset.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval import ( @@ -7,9 +9,9 @@ import ( "sort" ) -// CanonicalSet is a canonical representation of a set of Interval objects +// CanonicalSet is a set of int64 integers, implemented using an ordered slice of non-overlapping, non-touching interval type CanonicalSet struct { - intervalSet []Interval // sorted list of non-overlapping intervals + intervalSet []Interval } func NewCanonicalSet() *CanonicalSet { @@ -33,11 +35,19 @@ func (c *CanonicalSet) Min() int64 { return c.intervalSet[0].Start } -// IsEmpty returns true if the CanonicalSet is empty +// IsEmpty returns true if the CanonicalSet is empty func (c *CanonicalSet) IsEmpty() bool { return len(c.intervalSet) == 0 } +func (c *CanonicalSet) CalculateSize() int64 { + var res int64 = 0 + for _, r := range c.intervalSet { + res += r.Size() + } + return res +} + // Equal returns true if the CanonicalSet equals the input CanonicalSet func (c *CanonicalSet) Equal(other *CanonicalSet) bool { if c == other { @@ -81,10 +91,6 @@ func (c *CanonicalSet) AddHole(hole Interval) { c.intervalSet = newIntervalSet } -func getNumAsStr(num int64) string { - return fmt.Sprintf("%v", num) -} - // String returns a string representation of the current CanonicalSet object func (c *CanonicalSet) String() string { if c.IsEmpty() { @@ -92,9 +98,10 @@ func (c *CanonicalSet) String() string { } res := "" for _, interval := range c.intervalSet { - res += getNumAsStr(interval.Start) if interval.Start != interval.End { - res += "-" + getNumAsStr(interval.End) + res += fmt.Sprintf("%v-%v", interval.Start, interval.End) + } else { + res += fmt.Sprintf("%v", interval.Start) } res += "," } @@ -119,8 +126,7 @@ func (c *CanonicalSet) Copy() *CanonicalSet { } func (c *CanonicalSet) Contains(n int64) bool { - i := NewSetFromInterval(New(n, n)) - return i.ContainedIn(c) + return New(n, n).ToSet().ContainedIn(c) } // ContainedIn returns true of the current CanonicalSet is contained in the input CanonicalSet @@ -178,26 +184,35 @@ func (c *CanonicalSet) Subtract(other *CanonicalSet) *CanonicalSet { if c == other { return NewCanonicalSet() } - res := slices.Clone(c.intervalSet) - for _, hole := range other.intervalSet { - newIntervalSet := []Interval{} - for _, interval := range res { - newIntervalSet = append(newIntervalSet, interval.subtract(hole)...) - } - res = newIntervalSet - } - return &CanonicalSet{ - intervalSet: res, + res := c.Copy() + for _, i := range other.intervalSet { + res.AddHole(i) } + return res } func (c *CanonicalSet) IsSingleNumber() bool { - if len(c.intervalSet) == 1 && c.intervalSet[0].Start == c.intervalSet[0].End { + if len(c.intervalSet) == 1 && c.intervalSet[0].Size() == 1 { return true } return false } +// Elements returns a slice with all the numbers contained in the set. +// USE WITH CARE. It can easily run out of memory for large sets. +func (c *CanonicalSet) Elements() []int64 { + // allocate memory up front, to fail early + res := make([]int64, c.CalculateSize()) + i := 0 + for _, interval := range c.intervalSet { + for v := interval.Start; v <= interval.End; v++ { + res[i] = v + i++ + } + } + return res +} + func NewSetFromInterval(span Interval) *CanonicalSet { return &CanonicalSet{intervalSet: []Interval{span}} } diff --git a/pkg/interval/intervalset_test.go b/pkg/interval/intervalset_test.go index 2b9010a..b80167c 100644 --- a/pkg/interval/intervalset_test.go +++ b/pkg/interval/intervalset_test.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package interval_test import ( @@ -10,6 +12,7 @@ import ( func TestInterval(t *testing.T) { it1 := interval.Interval{3, 7} + require.Equal(t, "[3-7]", it1.String()) } @@ -41,8 +44,7 @@ func TestIntervalSet(t *testing.T) { require.False(t, is1.Overlap(is2)) require.False(t, is2.Overlap(is1)) - is1 = is1.Union(is2) - is1.Union(interval.New(7, 9).ToSet()) + is1 = is1.Union(is2).Union(interval.New(7, 9).ToSet()) require.True(t, is2.ContainedIn(is1)) require.False(t, is1.ContainedIn(is2)) require.True(t, is1.Overlap(is2)) @@ -54,3 +56,13 @@ func TestIntervalSet(t *testing.T) { require.True(t, interval.New(1, 1).ToSet().IsSingleNumber()) } + +func TestIntervalSetSubtract(t *testing.T) { + s := interval.New(1, 100).ToSet() + s.AddInterval(interval.Interval{Start: 400, End: 700}) + d := *interval.New(50, 100).ToSet() + d.AddInterval(interval.Interval{Start: 400, End: 700}) + actual := s.Subtract(&d) + expected := interval.New(1, 49).ToSet() + require.Equal(t, expected.String(), actual.String()) +} diff --git a/pkg/ipblock/ipblock.go b/pkg/ipblock/ipblock.go index d460f06..93f624d 100644 --- a/pkg/ipblock/ipblock.go +++ b/pkg/ipblock/ipblock.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package ipblock import ( @@ -125,11 +127,7 @@ func (b *IPBlock) Copy() *IPBlock { } func (b *IPBlock) ipCount() int { - res := 0 - for _, r := range b.ipRange.Intervals() { - res += int(r.End) - int(r.Start) + 1 - } - return res + return int(b.ipRange.CalculateSize()) } // Split returns a set of IpBlock objects, each with a single range of ips @@ -164,15 +162,13 @@ func DisjointIPBlocks(set1, set2 []*IPBlock) []*IPBlock { return ipbList[i].ipCount() < ipbList[j].ipCount() }) // making sure the resulting list does not contain overlapping ipBlocks - blocksWithNoOverlaps := []*IPBlock{} + res := []*IPBlock{} for _, ipb := range ipbList { - blocksWithNoOverlaps = addIntervalToList(ipb, blocksWithNoOverlaps) + res = addIntervalToList(ipb, res) } - res := blocksWithNoOverlaps if len(res) == 0 { - newAll := GetCidrAll() - res = append(res, newAll) + res = []*IPBlock{GetCidrAll()} } return res } @@ -241,8 +237,8 @@ func FromCidrList(cidrsList []string) (*IPBlock, error) { return res, nil } -// Except creates a new IP block that excludes the specified CIDRs from the current IP block -func (b *IPBlock) Except(exceptions ...string) (*IPBlock, error) { +// ExceptCidrs returns a new IPBlock with all cidr ranges removed +func (b *IPBlock) ExceptCidrs(exceptions ...string) (*IPBlock, error) { holes := interval.NewCanonicalSet() for i := range exceptions { intervalHole, err := cidrToInterval(exceptions[i]) @@ -254,20 +250,22 @@ func (b *IPBlock) Except(exceptions ...string) (*IPBlock, error) { return &IPBlock{ipRange: b.ipRange.Subtract(holes)}, nil } -func ipv4AddressToCidr(ipAddress string) string { - return ipAddress + "/32" -} - // FromIPAddress returns an IPBlock object from input IP address string func FromIPAddress(ipAddress string) (*IPBlock, error) { - return FromCidr(ipv4AddressToCidr(ipAddress)) + ipNum, err := parseIP(ipAddress) + if err != nil { + return nil, err + } + return &IPBlock{ + ipRange: interval.New(ipNum, ipNum).ToSet(), + }, nil } -func cidrToIPRange(cidr string) (start, end int64, err error) { +func cidrToInterval(cidr string) (interval.Interval, error) { // convert string to IPNet struct _, ipv4Net, err := net.ParseCIDR(cidr) if err != nil { - return + return interval.Interval{}, err } // convert IPNet struct mask and address to uint32 @@ -276,24 +274,14 @@ func cidrToIPRange(cidr string) (start, end int64, err error) { startNum := binary.BigEndian.Uint32(ipv4Net.IP) // find the final address endNum := (startNum & mask) | (mask ^ ipMask) - start = int64(startNum) - end = int64(endNum) - return -} - -func cidrToInterval(cidr string) (interval.Interval, error) { - start, end, err := cidrToIPRange(cidr) - if err != nil { - return interval.Interval{}, err - } - return interval.Interval{Start: start, End: end}, nil + return interval.New(int64(startNum), int64(endNum)), nil } // ToCidrList returns a list of CIDR strings for this IPBlock object func (b *IPBlock) ToCidrList() []string { cidrList := []string{} - for _, interval := range b.ipRange.Intervals() { - cidrList = append(cidrList, intervalToCidrList(interval.Start, interval.End)...) + for _, ipRange := range b.ipRange.Intervals() { + cidrList = append(cidrList, intervalToCidrList(ipRange)...) } return cidrList } @@ -306,12 +294,12 @@ func (b *IPBlock) ToCidrListString() string { // ListToPrint: returns a uniform to print list s.t. each element contains either a single cidr or an ip range func (b *IPBlock) ListToPrint() []string { cidrsIPRangesList := []string{} - for _, interval := range b.ipRange.Intervals() { - cidr := intervalToCidrList(interval.Start, interval.End) + for _, ipRange := range b.ipRange.Intervals() { + cidr := intervalToCidrList(ipRange) if len(cidr) == 1 { cidrsIPRangesList = append(cidrsIPRangesList, cidr[0]) } else { - cidrsIPRangesList = append(cidrsIPRangesList, toIPRange(interval)) + cidrsIPRangesList = append(cidrsIPRangesList, toIPRange(ipRange)) } } return cidrsIPRangesList @@ -325,9 +313,9 @@ func (b *IPBlock) ToIPAddressString() string { return "" } -func intervalToCidrList(ipStart, ipEnd int64) []string { - start := ipStart - end := ipEnd +func intervalToCidrList(ipRange interval.Interval) []string { + start := ipRange.Start + end := ipRange.End res := []string{} for end >= start { maxSize := maxIPv4Bits @@ -352,22 +340,25 @@ func intervalToCidrList(ipStart, ipEnd int64) []string { return res } +func parseIP(ip string) (int64, error) { + startIP := net.ParseIP(ip) + if startIP == nil { + return 0, fmt.Errorf("%v is not a valid ipv4", ip) + } + return int64(binary.BigEndian.Uint32(startIP.To4())), nil +} + // FromIPRangeStr returns IPBlock object from input IP range string (example: "169.255.0.0-172.15.255.255") func FromIPRangeStr(ipRangeStr string) (*IPBlock, error) { ipAddresses := strings.Split(ipRangeStr, dash) if len(ipAddresses) != 2 { return nil, errors.New("unexpected ipRange str") } - var startIP, endIP *IPBlock - var err error - if startIP, err = FromIPAddress(ipAddresses[0]); err != nil { - return nil, err - } - if endIP, err = FromIPAddress(ipAddresses[1]); err != nil { - return nil, err + startIPNum, err0 := parseIP(ipAddresses[0]) + endIPNum, err1 := parseIP(ipAddresses[1]) + if err0 != nil || err1 != nil { + return nil, errors.Join(err0, err1) } - startIPNum := startIP.ipRange.Min() - endIPNum := endIP.ipRange.Min() res := &IPBlock{ ipRange: interval.New(startIPNum, endIPNum).ToSet(), } diff --git a/pkg/ipblock/ipblock_test.go b/pkg/ipblock/ipblock_test.go index 74ed5f6..bda8b19 100644 --- a/pkg/ipblock/ipblock_test.go +++ b/pkg/ipblock/ipblock_test.go @@ -1,3 +1,5 @@ +// Copyright 2020- IBM Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 package ipblock_test import ( @@ -23,15 +25,15 @@ func TestOps(t *testing.T) { minus2, err := ipblock.FromCidr(ipb1.ToCidrListString()) require.Nil(t, err) - minus2, err = minus2.Except(ipb2.ToCidrListString()) + minus2, err = minus2.ExceptCidrs(ipb2.ToCidrListString()) require.Nil(t, err) require.Equal(t, minus.ToCidrListString(), minus2.ToCidrListString()) intersect := ipb1.Intersect(ipb2) - require.True(t, intersect.Equal(ipb2)) + require.Equal(t, intersect, ipb2) union := intersect.Union(minus) - require.True(t, union.Equal(ipb1)) + require.Equal(t, union, ipb1) intersect2 := minus.Intersect(intersect) require.True(t, intersect2.IsEmpty()) @@ -115,7 +117,7 @@ func TestBadPath(t *testing.T) { _, err = ipblock.FromCidr("2.5.7.9/24") require.Nil(t, err) - _, err = ipblock.New().Except("5.6.7.8/20", "not-a-cidr") + _, err = ipblock.New().ExceptCidrs("5.6.7.8/20", "not-a-cidr") require.NotNil(t, err) _, err = ipblock.FromCidrList([]string{"1.2.3.4/20", "not-a-cidr"})