diff --git a/pkg/connectionset/statefulness.go b/pkg/connectionset/statefulness.go new file mode 100644 index 0000000..48046a2 --- /dev/null +++ b/pkg/connectionset/statefulness.go @@ -0,0 +1,87 @@ +package connectionset + +import "github.com/np-guard/models/pkg/hypercubes" + +const ( + // StatefulUnknown is the default value for a ConnectionSet object, + StatefulUnknown int = iota + // StatefulTrue represents a connection object for which any allowed TCP (on all allowed src/dst ports) + // has an allowed response connection + StatefulTrue + // StatefulFalse represents a connection object for which there exists some allowed TCP + // (on any allowed subset from the allowed src/dst ports) that does not have an allowed response connection + StatefulFalse +) + +// EnhancedString returns a connection string with possibly added asterisk for stateless connection +func (conn *ConnectionSet) EnhancedString() string { + if conn.IsStateful == StatefulFalse { + return conn.String() + " *" + } + return conn.String() +} + +// ConnectionWithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. +// `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` +// which is not stateful (such that the response direction for this subset is not enabled). +// This function also returns a connection object with the exact subset of the stateful part (within TCP) +// from the entire connection `conn`, and with the original connections on other protocols. +func (conn *ConnectionSet) ConnectionWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + connTCP := conn.tcpConn() + if connTCP.IsEmpty() { + conn.IsStateful = StatefulTrue + return conn + } + secondDirectionConnTCP := secondDirectionConn.tcpConn() + statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConnTCP) + conn.IsStateful = connTCP.IsStateful + nonTCP := conn.Subtract(connTCP) + return nonTCP.Union(statefulCombinedConnTCP) +} + +// connTCPWithStatefulness assumes that both `conn` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `conn`, and returns the subset from `conn` which is stateful. +func (conn *ConnectionSet) connTCPWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + secondDirectionSwitchPortsDirection := secondDirectionConn.switchSrcDstPortsOnTCP() + // flip src/dst ports before intersection + statefulCombinedConn := conn.Intersection(secondDirectionSwitchPortsDirection) + if !conn.Equal(statefulCombinedConn) { + conn.IsStateful = StatefulFalse + } else { + conn.IsStateful = StatefulTrue + } + return statefulCombinedConn +} + +// tcpConn returns a new ConnectionSet object, which is the intersection of `conn` with TCP +func (conn *ConnectionSet) tcpConn() *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolTCP, MinPort, MaxPort, MinPort, MaxPort) + return conn.Intersection(res) +} + +// switchSrcDstPortsOnTCP returns a new ConnectionSet object, built from the input ConnectionSet object. +// It assumes the input connection object is only within TCP protocol. +// For TCP the src and dst ports on relevant cubes are being switched. +func (conn *ConnectionSet) switchSrcDstPortsOnTCP() *ConnectionSet { + if conn.AllowAll || conn.IsEmpty() { + return conn.Copy() + } + res := NewConnectionSet(false) + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + // assuming cube[protocol] contains TCP only + srcPorts := cube[srcPort] + dstPorts := cube[dstPort] + // if the entire domain is enabled by both src and dst no need to switch + if !srcPorts.Equal(*getDimensionDomain(srcPort)) || !dstPorts.Equal(*getDimensionDomain(dstPort)) { + newCube := copyCube(cube) + newCube[srcPort], newCube[dstPort] = newCube[dstPort], newCube[srcPort] + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(newCube)) + } else { + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(cube)) + } + } + return res +} diff --git a/pkg/connectionset/statefulness_test.go b/pkg/connectionset/statefulness_test.go new file mode 100644 index 0000000..c10b3ba --- /dev/null +++ b/pkg/connectionset/statefulness_test.go @@ -0,0 +1,142 @@ +package connectionset + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTCPConn(srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolTCP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newUDPConn(srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolUDP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newICMPconn() *ConnectionSet { + res := NewConnectionSet(false) + res.AddICMPConnection(MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + return res +} + +func allButTCP() *ConnectionSet { + res := NewConnectionSet(true) + tcpOnly := res.tcpConn() + return res.Subtract(tcpOnly) +} + +type statefulnessTest struct { + name string + srcToDst *connectionset.ConnectionSet + dstToSrc *connectionset.ConnectionSet + // expectedIsStateful represents the expected IsStateful computed value for srcToDst, + // which should be either StatefulTrue or StatefulFalse, given the input dstToSrc connection. + // the computation applies only to the TCP protocol within those connections. + expectedIsStateful int + // expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection, + // the stateless part for TCP is srcToDst.Subtract(statefuleConn) + expectedStatefulConn *connectionset.ConnectionSet +} + +var testCasesStatefulness = []statefulnessTest{ + { + name: "tcp_all_ports_on_both_directions", + srcToDst: newTCPConn(MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + dstToSrc: newTCPConn(MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + }, + { + name: "first_all_cons_second_tcp_with_ports", + srcToDst: connectionset.NewConnectionSet(true), // all connections + dstToSrc: newTCPConn(80, 80, MinPort, MaxPort), // TCP , src-ports: 80, dst-ports: all + + // there is a subset of the tcp connection which is not stateful + expectedIsStateful: StatefulFalse, + + // TCP src-ports: all, dst-port: 80 , union: all non-TCP conns + expectedStatefulConn: allButTCP().Union(newTCPConn(MinPort, MaxPort, 80, 80)), + }, + { + name: "first_all_conns_second_no_tcp", + srcToDst: NewConnectionSet(true), // all connections + dstToSrc: newICMPconn(), // ICMP + expectedIsStateful: StatefulFalse, + expectedStatefulConn: allButTCP(), // UDP, ICMP (all TCP is considered stateless here) + }, + { + name: "tcp_with_ports_both_directions_exact_match", + srcToDst: newTCPConn(80, 80, 443, 443), + dstToSrc: newTCPConn(443, 443, 80, 80), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_partial_match", + srcToDst: newTCPConn(80, 100, 443, 443), + dstToSrc: newTCPConn(443, 443, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newTCPConn(80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(80, 100, 443, 443), + dstToSrc: newTCPConn(80, 80, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "udp_and_tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(80, 100, 443, 443).Union(newUDPConn(80, 100, 443, 443)), + dstToSrc: newTCPConn(80, 80, 80, 80).Union(newUDPConn(80, 80, 80, 80)), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newUDPConn(80, 100, 443, 443), + }, + { + name: "no_tcp_in_first_direction", + srcToDst: newUDPConn(80, 100, 443, 443), + dstToSrc: newTCPConn(80, 80, 80, 80).Union(newUDPConn(80, 80, 80, 80)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(80, 100, 443, 443), + }, + { + name: "empty_conn_in_first_direction", + srcToDst: NewConnectionSet(false), + dstToSrc: newTCPConn(80, 80, 80, 80).Union(newUDPConn(MinPort, MaxPort, MinPort, MaxPort)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "only_udp_icmp_in_first_direction_and_empty_second_direction", + srcToDst: newUDPConn(MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn()), + dstToSrc: NewConnectionSet(false), + // stateful analysis does not apply to udp/icmp, thus considered in the result as "stateful" + // (to avoid marking it as stateless in the output) + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn()), + }, +} + +func (tt statefulnessTest) runTest(t *testing.T) { + statefuleConn := tt.srcToDst.ConnectionWithStatefulness(tt.dstToSrc) + require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) + require.True(t, tt.expectedStatefulConn.Equal(statefuleConn)) +} + +func TestAll(t *testing.T) { + // explainTests is the list of tests to run + for testIdx := range testCasesStatefulness { + tt := testCasesStatefulness[testIdx] + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.runTest(t) + }) + } + fmt.Println("done") +}