Skip to content

Commit

Permalink
change Dimension to not encode the order of the dimensions
Browse files Browse the repository at this point in the history
Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
Elazar Gershuni committed Mar 21, 2024
1 parent 6c6ed76 commit c5eb591
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 40 deletions.
54 changes: 27 additions & 27 deletions pkg/connection/connectionset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package connection

import (
"log"
"slices"
"sort"
"strings"

Expand Down Expand Up @@ -32,15 +33,14 @@ const (
NoConnections = "No Connections"
)

type Dimension int
type Dimension string

const (
protocol Dimension = 0
srcPort Dimension = 1
dstPort Dimension = 2
icmpType Dimension = 3
icmpCode Dimension = 4
numDimensions = 5
protocol Dimension = "protocol"
srcPort Dimension = "srcPort"
dstPort Dimension = "dstPort"
icmpType Dimension = "icmpType"
icmpCode Dimension = "icmpCode"
)

const propertySeparator string = " "
Expand Down Expand Up @@ -71,25 +71,21 @@ func entireDimension(dim Dimension) *interval.CanonicalSet {
return nil
}

func getDimensionDomainsList() []*interval.CanonicalSet {
res := make([]*interval.CanonicalSet, len(dimensionsList))
for i := range dimensionsList {
res[i] = entireDimension(dimensionsList[i])
}
return res
}

type Set struct {
connectionProperties *hypercube.CanonicalSet
IsStateful StatefulState
}

func None() *Set {
return &Set{connectionProperties: hypercube.NewCanonicalSet(numDimensions)}
return &Set{connectionProperties: hypercube.NewCanonicalSet(len(dimensionsList))}
}

func All() *Set {
return &Set{connectionProperties: hypercube.FromCube(getDimensionDomainsList())}
all := make([]*interval.CanonicalSet, len(dimensionsList))
for i := range dimensionsList {
all[i] = entireDimension(dimensionsList[i])
}
return &Set{connectionProperties: hypercube.FromCube(all)}
}

var all = All()
Expand Down Expand Up @@ -149,7 +145,7 @@ func (c *Set) Subtract(other *Set) *Set {
func (c *Set) ContainedIn(other *Set) bool {
res, err := c.connectionProperties.ContainedIn(other.connectionProperties)
if err != nil {
log.Fatalf("invalid connection set. %e", err)
log.Panicf("invalid connection set. %e", err)
}
return res
}
Expand All @@ -163,7 +159,7 @@ func protocolStringToCode(protocol netp.ProtocolString) int64 {
case netp.ProtocolStringICMP:
return ICMPCode
}
log.Fatalf("Impossible protocol code %v", protocol)
log.Panicf("Impossible protocol code %v", protocol)
return 0
}

Expand Down Expand Up @@ -198,12 +194,12 @@ func protocolStringFromCode(protocolCode int64) netp.ProtocolString {
case ICMPCode:
return netp.ProtocolStringICMP
}
log.Fatalf("impossible protocol code %v", protocolCode)
log.Panicf("impossible protocol code %v", protocolCode)
return ""
}

func getDimensionString(cube []*interval.CanonicalSet, dim Dimension) string {
dimValue := cube[dim]
dimValue := cubeAt(cube, dim)
if dimValue.Equal(entireDimension(dim)) {
// avoid adding dimension str on full dimension values
return ""
Expand Down Expand Up @@ -242,7 +238,7 @@ func joinNonEmpty(inputList ...string) string {
}

func getConnsCubeStr(cube []*interval.CanonicalSet) string {
protocols := cube[protocol]
protocols := cubeAt(cube, protocol)
tcpOrUDP := protocols.Contains(TCPCode) || protocols.Contains(UDPCode)
icmp := protocols.Contains(ICMPCode)
switch {
Expand Down Expand Up @@ -281,11 +277,15 @@ func (c *Set) String() string {
return strings.Join(resStrings, "; ")
}

func cubeAt(cube []*interval.CanonicalSet, dim Dimension) *interval.CanonicalSet {
return cube[slices.Index(dimensionsList, dim)]
}

func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtocol) []spec.TcpUdp {
tcpItemsTemp := []spec.TcpUdp{}
tcpItemsFinal := []spec.TcpUdp{}
// consider src ports
srcPorts := cube[srcPort]
srcPorts := cubeAt(cube, srcPort)
if !srcPorts.Equal(entireDimension(srcPort)) {
// iterate the interval in the interval-set
for _, interval := range srcPorts.Intervals() {
Expand All @@ -296,7 +296,7 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc
tcpItemsTemp = append(tcpItemsTemp, spec.TcpUdp{Protocol: protocol})
}
// consider dst ports
dstPorts := cube[dstPort]
dstPorts := cubeAt(cube, dstPort)
if !dstPorts.Equal(entireDimension(dstPort)) {
// iterate the interval in the interval-set
for _, interval := range dstPorts.Intervals() {
Expand All @@ -318,8 +318,8 @@ func getCubeAsTCPItems(cube []*interval.CanonicalSet, protocol spec.TcpUdpProtoc
}

func getCubeAsICMPItems(cube []*interval.CanonicalSet) []spec.Icmp {
icmpTypes := cube[icmpType]
icmpCodes := cube[icmpCode]
icmpTypes := cubeAt(cube, icmpType)
icmpCodes := cubeAt(cube, icmpCode)
allTypes := icmpTypes.Equal(entireDimension(icmpType))
allCodes := icmpCodes.Equal(entireDimension(icmpCode))
switch {
Expand Down Expand Up @@ -367,7 +367,7 @@ func ToJSON(c *Set) Details {

cubes := c.connectionProperties.GetCubesList()
for _, cube := range cubes {
protocols := cube[protocol]
protocols := cubeAt(cube, protocol)
if protocols.Contains(TCPCode) {
tcpItems := getCubeAsTCPItems(cube, spec.TcpUdpProtocolTCP)
for _, item := range tcpItems {
Expand Down
17 changes: 4 additions & 13 deletions pkg/connection/statefulness.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package connection
import (
"slices"

"github.com/np-guard/models/pkg/hypercube"
"github.com/np-guard/models/pkg/netp"
)

Expand Down Expand Up @@ -69,19 +68,11 @@ func (c *Set) connTCPWithStatefulness(secondDirectionConn *Set) *Set {
// It assumes the input connection object is only within TCP protocol.
// For TCP the src and dst ports on relevant cubes are being switched.
func (c *Set) switchSrcDstPortsOnTCP() *Set {
if c.IsAll() || c.IsEmpty() {
if c.IsAll() {
return c.Copy()
}
res := None()
for _, cube := range c.connectionProperties.GetCubesList() {
// assuming cube[protocol] contains TCP only
// no need to switch if src equals dst
if !cube[srcPort].Equal(cube[dstPort]) {
// Shallow clone should be enough, since we do shallow swap:
cube = slices.Clone(cube)
cube[srcPort], cube[dstPort] = cube[dstPort], cube[srcPort]
}
res.connectionProperties = res.connectionProperties.Union(hypercube.FromCube(cube))
newConn := c.connectionProperties.SwapDimensions(slices.Index(dimensionsList, srcPort), slices.Index(dimensionsList, dstPort))
return &Set{
connectionProperties: newConn,
}
return res
}
19 changes: 19 additions & 0 deletions pkg/hypercube/hypercubeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package hypercube

import (
"errors"
"slices"
"sort"
"strings"

Expand Down Expand Up @@ -250,6 +251,24 @@ func (c *CanonicalSet) GetCubesList() [][]*interval.CanonicalSet {
return res
}

// SwapDimensions returns a new CanonicalSet object, built from the input CanonicalSet object,
// with dimensions dim1 and dim2 swapped
func (c *CanonicalSet) SwapDimensions(dim1, dim2 int) *CanonicalSet {
if c.IsEmpty() {
return c.Copy()
}
res := NewCanonicalSet(c.dimensions)
for _, cube := range c.GetCubesList() {
if !cube[dim1].Equal(cube[dim2]) {
// Shallow clone should be enough, since we do shallow swap:
cube = slices.Clone(cube)
cube[dim1], cube[dim2] = cube[dim2], cube[dim1]
}
res = res.Union(FromCube(cube))
}
return res
}

func getElementsUnionPerLayer(layers map[*interval.CanonicalSet]*CanonicalSet) map[*interval.CanonicalSet]*CanonicalSet {
type pair struct {
hc *CanonicalSet // hypercube set object
Expand Down

0 comments on commit c5eb591

Please sign in to comment.