diff --git a/pkg/connection/connectionset.go b/pkg/connection/connectionset.go index 8a8f8aa..1cbdd1d 100644 --- a/pkg/connection/connectionset.go +++ b/pkg/connection/connectionset.go @@ -4,6 +4,7 @@ package connection import ( "log" + "slices" "sort" "strings" @@ -32,15 +33,14 @@ const ( NoConnections = "No Connections" ) -type Dimension int +type Dimension string const ( - protocol Dimension = 0 - srcPort Dimension = 1 - dstPort Dimension = 2 - icmpType Dimension = 3 - icmpCode Dimension = 4 - numDimensions = 5 + protocol Dimension = "protocol" + srcPort Dimension = "srcPort" + dstPort Dimension = "dstPort" + icmpType Dimension = "icmpType" + icmpCode Dimension = "icmpCode" ) const propertySeparator string = " " @@ -71,25 +71,21 @@ func entireDimension(dim Dimension) *interval.CanonicalSet { return nil } -func getDimensionDomainsList() []*interval.CanonicalSet { - res := make([]*interval.CanonicalSet, len(dimensionsList)) - for i := range dimensionsList { - res[i] = entireDimension(dimensionsList[i]) - } - return res -} - type Set struct { connectionProperties *hypercube.CanonicalSet IsStateful StatefulState } func None() *Set { - return &Set{connectionProperties: hypercube.NewCanonicalSet(numDimensions)} + return &Set{connectionProperties: hypercube.NewCanonicalSet(len(dimensionsList))} } func All() *Set { - return &Set{connectionProperties: hypercube.FromCube(getDimensionDomainsList())} + all := make([]*interval.CanonicalSet, len(dimensionsList)) + for i := range dimensionsList { + all[i] = entireDimension(dimensionsList[i]) + } + return &Set{connectionProperties: hypercube.FromCube(all)} } var all = All() @@ -149,7 +145,7 @@ func (c *Set) Subtract(other *Set) *Set { func (c *Set) ContainedIn(other *Set) bool { res, err := c.connectionProperties.ContainedIn(other.connectionProperties) if err != nil { - log.Fatalf("invalid connection set. %e", err) + log.Panicf("invalid connection set. %e", err) } return res } @@ -163,7 +159,7 @@ func protocolStringToCode(protocol netp.ProtocolString) int64 { case netp.ProtocolStringICMP: return ICMPCode } - log.Fatalf("Impossible protocol code %v", protocol) + log.Panicf("Impossible protocol code %v", protocol) return 0 } @@ -198,12 +194,12 @@ func protocolStringFromCode(protocolCode int64) netp.ProtocolString { case ICMPCode: return netp.ProtocolStringICMP } - log.Fatalf("impossible protocol code %v", protocolCode) + log.Panicf("impossible protocol code %v", protocolCode) return "" } func getDimensionString(cube []*interval.CanonicalSet, dim Dimension) string { - dimValue := cube[dim] + dimValue := cubeAt(cube, dim) if dimValue.Equal(entireDimension(dim)) { // avoid adding dimension str on full dimension values return "" @@ -242,7 +238,7 @@ func joinNonEmpty(inputList ...string) string { } func getConnsCubeStr(cube []*interval.CanonicalSet) string { - protocols := cube[protocol] + protocols := cubeAt(cube, protocol) tcpOrUDP := protocols.Contains(TCPCode) || protocols.Contains(UDPCode) icmp := protocols.Contains(ICMPCode) switch { @@ -281,11 +277,15 @@ func (c *Set) String() string { return strings.Join(resStrings, "; ") } +func cubeAt(cube []*interval.CanonicalSet, dim Dimension) *interval.CanonicalSet { + return cube[slices.Index(dimensionsList, dim)] +} + func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtocol) []spec.TcpUdp { tcpItemsTemp := []spec.TcpUdp{} tcpItemsFinal := []spec.TcpUdp{} // consider src ports - srcPorts := cube[srcPort] + srcPorts := cubeAt(cube, srcPort) if !srcPorts.Equal(entireDimension(srcPort)) { // iterate the interval in the interval-set for _, interval := range srcPorts.Intervals() { @@ -296,7 +296,7 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc tcpItemsTemp = append(tcpItemsTemp, spec.TcpUdp{Protocol: protocol}) } // consider dst ports - dstPorts := cube[dstPort] + dstPorts := cubeAt(cube, dstPort) if !dstPorts.Equal(entireDimension(dstPort)) { // iterate the interval in the interval-set for _, interval := range dstPorts.Intervals() { @@ -318,8 +318,8 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc } func getCubeAsICMPItems(cube []*interval.CanonicalSet) []spec.Icmp { - icmpTypes := cube[icmpType] - icmpCodes := cube[icmpCode] + icmpTypes := cubeAt(cube, icmpType) + icmpCodes := cubeAt(cube, icmpCode) allTypes := icmpTypes.Equal(entireDimension(icmpType)) allCodes := icmpCodes.Equal(entireDimension(icmpCode)) switch { @@ -367,7 +367,7 @@ func ToJSON(c *Set) Details { cubes := c.connectionProperties.GetCubesList() for _, cube := range cubes { - protocols := cube[protocol] + protocols := cubeAt(cube, protocol) if protocols.Contains(TCPCode) { tcpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolTCP) for _, item := range tcpItems { diff --git a/pkg/connection/statefulness.go b/pkg/connection/statefulness.go index ceb1095..8226be5 100644 --- a/pkg/connection/statefulness.go +++ b/pkg/connection/statefulness.go @@ -5,7 +5,6 @@ package connection import ( "slices" - "github.com/np-guard/models/pkg/hypercube" "github.com/np-guard/models/pkg/netp" ) @@ -69,19 +68,11 @@ func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set { // It assumes the input connection object is only within TCP protocol. // For TCP the src and dst ports on relevant cubes are being switched. func (c *Set) switchSrcDstPortsOnTCP() *Set { - if c.IsAll() || c.IsEmpty() { + if c.IsAll() { return c.Copy() } - res := None() - for _, cube := range c.connectionProperties.GetCubesList() { - // assuming cube[protocol] contains TCP only - // no need to switch if src equals dst - if !cube[srcPort].Equal(cube[dstPort]) { - // Shallow clone should be enough, since we do shallow swap: - cube = slices.Clone(cube) - cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort] - } - res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube)) + newConn := c.connectionProperties.SwapDimensions(slices.Index(dimensionsList, srcPort), slices.Index(dimensionsList, dstPort)) + return &Set{ + connectionProperties: newConn, } - return res } diff --git a/pkg/hypercube/hypercubeset.go b/pkg/hypercube/hypercubeset.go index 635a537..e8e022f 100644 --- a/pkg/hypercube/hypercubeset.go +++ b/pkg/hypercube/hypercubeset.go @@ -4,6 +4,7 @@ package hypercube import ( "errors" + "slices" "sort" "strings" @@ -250,6 +251,24 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet { return res } +// SwapDimensions returns a new CanonicalSet object, built from the input CanonicalSet object, +// with dimensions dim1 and dim2 swapped +func (c *CanonicalSet) SwapDimensions(dim1, dim2 int) *CanonicalSet { + if c.IsEmpty() { + return c.Copy() + } + res := NewCanonicalSet(c.dimensions) + for _, cube := range c.GetCubesList() { + if !cube[dim1].Equal(cube[dim2]) { + // Shallow clone should be enough, since we do shallow swap: + cube = slices.Clone(cube) + cube[dim1], cube[dim2] = cube[dim2], cube[dim1] + } + res = res.Union(FromCube(cube)) + } + return res +} + func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet { type pair struct { hc *CanonicalSet // hypercube set object