From 5178baab749f6c1aa42d141e65648535c488e8d7 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Wed, 6 Mar 2024 16:02:21 +0200 Subject: [PATCH] move connectionset from analyzer and protocols from synthesizer Signed-off-by: Elazar Gershuni --- Makefile | 2 +- pkg/connectionset/connectionset.go | 437 ++++++++++++++++++++++++ pkg/connectionset/connectionset_test.go | 52 +++ pkg/connectionset/statefulness.go | 101 ++++++ pkg/connectionset/statefulness_test.go | 149 ++++++++ pkg/intervals/intervalset.go | 10 + pkg/netp/common.go | 18 + pkg/netp/icmp.go | 99 ++++++ pkg/netp/tcpudp.go | 46 +++ 9 files changed, 913 insertions(+), 1 deletion(-) create mode 100644 pkg/connectionset/connectionset.go create mode 100644 pkg/connectionset/connectionset_test.go create mode 100644 pkg/connectionset/statefulness.go create mode 100644 pkg/connectionset/statefulness_test.go create mode 100644 pkg/netp/common.go create mode 100644 pkg/netp/icmp.go create mode 100644 pkg/netp/tcpudp.go diff --git a/Makefile b/Makefile index 9a9eb26..d584912 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -REPOSITORY := github.com/np-guard/common +REPOSITORY := github.com/np-guard/models mod: go.mod @echo -- $@ -- diff --git a/pkg/connectionset/connectionset.go b/pkg/connectionset/connectionset.go new file mode 100644 index 0000000..3c7dae3 --- /dev/null +++ b/pkg/connectionset/connectionset.go @@ -0,0 +1,437 @@ +package connectionset + +import ( + "log" + "sort" + "strings" + + "github.com/np-guard/models/pkg/hypercubes" + "github.com/np-guard/models/pkg/intervals" + "github.com/np-guard/models/pkg/netp" +) + +const ( + ICMPCode = -1 + TCPCode = 0 + UDPCode = 1 + MinICMPtype int64 = 0 + MaxICMPtype int64 = netp.InformationReply + MinICMPcode int64 = 0 + MaxICMPcode int64 = 5 + minProtocol int64 = ICMPCode + maxProtocol int64 = UDPCode + MinPort = 1 + MaxPort = netp.DefaultMaxPort +) + +const ( + AllConnections = "All Connections" + NoConnections = "No Connections" +) + +type Dimension int + +const ( + protocol Dimension = 0 + srcPort Dimension = 1 + dstPort Dimension = 2 + icmpType Dimension = 3 + icmpCode Dimension = 4 + numDimensions = 5 +) + +const propertySeparator string = " " + +// dimensionsList is the ordered list of dimensions in the ConnectionSet object +// this should be the only place where the order is hard-coded +var dimensionsList = []Dimension{ + protocol, + srcPort, + dstPort, + icmpType, + icmpCode, +} + +func entireDimension(dim Dimension) *intervals.CanonicalIntervalSet { + switch dim { + case protocol: + return intervals.CreateFromInterval(minProtocol, maxProtocol) + case srcPort: + return intervals.CreateFromInterval(MinPort, MaxPort) + case dstPort: + return intervals.CreateFromInterval(MinPort, MaxPort) + case icmpType: + return intervals.CreateFromInterval(MinICMPtype, MaxICMPtype) + case icmpCode: + return intervals.CreateFromInterval(MinICMPcode, MaxICMPcode) + } + return nil +} + +func getDimensionDomainsList() []*intervals.CanonicalIntervalSet { + res := make([]*intervals.CanonicalIntervalSet, len(dimensionsList)) + for i := range dimensionsList { + res[i] = entireDimension(dimensionsList[i]) + } + return res +} + +type ConnectionSet struct { + AllowAll bool + connectionProperties *hypercubes.CanonicalHypercubeSet + IsStateful int // default is StatefulUnknown +} + +func NewConnectionSet(all bool) *ConnectionSet { + return &ConnectionSet{AllowAll: all, connectionProperties: hypercubes.NewCanonicalHypercubeSet(numDimensions)} +} + +func NewConnectionSetWithCube(cube *hypercubes.CanonicalHypercubeSet) *ConnectionSet { + res := NewConnectionSet(false) + res.connectionProperties.Union(cube) + if res.isAllConnectionsWithoutAllowAll() { + return NewConnectionSet(true) + } + return res +} + +func (conn *ConnectionSet) Copy() *ConnectionSet { + return &ConnectionSet{ + AllowAll: conn.AllowAll, + connectionProperties: conn.connectionProperties.Copy(), + IsStateful: conn.IsStateful, + } +} + +func (conn *ConnectionSet) Intersection(other *ConnectionSet) *ConnectionSet { + if other.AllowAll { + return conn.Copy() + } + if conn.AllowAll { + return other.Copy() + } + return &ConnectionSet{AllowAll: false, connectionProperties: conn.connectionProperties.Intersection(other.connectionProperties)} +} + +func (conn *ConnectionSet) IsEmpty() bool { + if conn.AllowAll { + return false + } + return conn.connectionProperties.IsEmpty() +} + +func (conn *ConnectionSet) Union(other *ConnectionSet) *ConnectionSet { + if conn.AllowAll || other.AllowAll { + return NewConnectionSet(true) + } + if other.IsEmpty() { + return conn.Copy() + } + if conn.IsEmpty() { + return other.Copy() + } + res := &ConnectionSet{ + AllowAll: false, + connectionProperties: conn.connectionProperties.Union(other.connectionProperties), + } + if res.isAllConnectionsWithoutAllowAll() { + return NewConnectionSet(true) + } + return res +} + +func getAllPropertiesObject() *hypercubes.CanonicalHypercubeSet { + return hypercubes.CreateFromCube(getDimensionDomainsList()) +} + +func (conn *ConnectionSet) isAllConnectionsWithoutAllowAll() bool { + if conn.AllowAll { + return false + } + return conn.connectionProperties.Equals(getAllPropertiesObject()) +} + +// Subtract +// ToDo: Subtract seems to ignore IsStateful (see https://github.com/np-guard/vpc-network-config-analyzer/issues/199): +// 1. is the delta connection stateful +// 2. connectionProperties is identical but conn stateful while other is not +// the 2nd item can be computed here, with enhancement to relevant structure +// the 1st can not since we do not know where exactly the statefulness came from +func (conn *ConnectionSet) Subtract(other *ConnectionSet) *ConnectionSet { + if conn.IsEmpty() || other.IsEmpty() { + return conn + } + if other.AllowAll { + return NewConnectionSet(false) + } + var connProperties *hypercubes.CanonicalHypercubeSet + if conn.AllowAll { + connProperties = getAllPropertiesObject() + } else { + connProperties = conn.connectionProperties + } + return &ConnectionSet{AllowAll: false, connectionProperties: connProperties.Subtraction(other.connectionProperties)} +} + +func (conn *ConnectionSet) ContainedIn(other *ConnectionSet) (bool, error) { + if other.AllowAll { + return true, nil + } + if conn.AllowAll { + return false, nil + } + res, err := conn.connectionProperties.ContainedIn(other.connectionProperties) + return res, err +} + +func ProtocolStringToCode(protocol netp.ProtocolStr) int64 { + switch protocol { + case netp.ProtocolStringTCP: + return TCPCode + case netp.ProtocolStringUDP: + return UDPCode + case netp.ProtocolStringICMP: + return ICMPCode + } + log.Fatalf("Impossible protocol code %v", protocol) + return 0 +} + +func (conn *ConnectionSet) addConnection(protocol netp.ProtocolStr, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { + code := ProtocolStringToCode(protocol) + cube := hypercubes.CreateFromCubeShort(code, code, + srcMinP, srcMaxP, dstMinP, dstMaxP, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) + conn.connectionProperties = conn.connectionProperties.Union(cube) + // check if all connections allowed after this union + if conn.isAllConnectionsWithoutAllowAll() { + conn.AllowAll = true + conn.connectionProperties = hypercubes.NewCanonicalHypercubeSet(numDimensions) + } +} + +func (conn *ConnectionSet) AddTCPorUDPConn(protocol netp.ProtocolStr, srcMinP, srcMaxP, dstMinP, dstMaxP int64) { + conn.addConnection(protocol, + srcMinP, srcMaxP, dstMinP, dstMaxP, + MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) +} + +func (conn *ConnectionSet) AddICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { + conn.addConnection(netp.ProtocolStringICMP, + MinPort, MaxPort, MinPort, MaxPort, + icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) +} + +func (conn *ConnectionSet) Equal(other *ConnectionSet) bool { + if conn.AllowAll != other.AllowAll { + return false + } + if conn.AllowAll { + return true + } + return conn.connectionProperties.Equals(other.connectionProperties) +} + +func getProtocolStr(p int64) netp.ProtocolStr { + switch p { + case TCPCode: + return netp.ProtocolStringTCP + case UDPCode: + return netp.ProtocolStringUDP + case ICMPCode: + return netp.ProtocolStringICMP + } + log.Fatalf("Impossible protocol value %v", p) + return "" +} + +func getDimensionStr(dimValue *intervals.CanonicalIntervalSet, dim Dimension) string { + domainValues := entireDimension(dim) + if dimValue.Equal(*domainValues) { + // avoid adding dimension str on full dimension values + return "" + } + switch dim { + case protocol: + pList := []string{} + for p := minProtocol; p <= maxProtocol; p++ { + pList = append(pList, string(getProtocolStr(p))) + } + return "protocol: " + strings.Join(pList, ",") + case srcPort: + return "src-ports: " + dimValue.String() + case dstPort: + return "dst-ports: " + dimValue.String() + case icmpType: + return "icmp-type: " + dimValue.String() + case icmpCode: + return "icmp-code: " + dimValue.String() + } + return "" +} + +func filterEmptyPropertiesStr(inputList []string) []string { + res := []string{} + for _, propertyStr := range inputList { + if propertyStr != "" { + res = append(res, propertyStr) + } + } + return res +} + +func getICMPbasedCubeStr(protocolsValues, icmpTypeValues, icmpCodeValues *intervals.CanonicalIntervalSet) string { + strList := []string{ + getDimensionStr(protocolsValues, protocol), + getDimensionStr(icmpTypeValues, icmpType), + getDimensionStr(icmpCodeValues, icmpCode), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getPortBasedCubeStr(protocolsValues, srcPortsValues, dstPortsValues *intervals.CanonicalIntervalSet) string { + strList := []string{ + getDimensionStr(protocolsValues, protocol), + getDimensionStr(srcPortsValues, srcPort), + getDimensionStr(dstPortsValues, dstPort), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getMixedProtocolsCubeStr(protocols *intervals.CanonicalIntervalSet) string { + // TODO: make sure other dimension values are full + return getDimensionStr(protocols, protocol) +} + +func getConnsCubeStr(cube []*intervals.CanonicalIntervalSet) string { + protocols := cube[protocol] + if (protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) && !protocols.Contains(ICMPCode) { + return getPortBasedCubeStr(cube[protocol], cube[srcPort], cube[dstPort]) + } + if protocols.Contains(ICMPCode) && !(protocols.Contains(TCPCode) || protocols.Contains(UDPCode)) { + return getICMPbasedCubeStr(cube[protocol], cube[icmpType], cube[icmpCode]) + } + return getMixedProtocolsCubeStr(protocols) +} + +// String returns a string representation of a ConnectionSet object +func (conn *ConnectionSet) String() string { + if conn.AllowAll { + return AllConnections + } else if conn.IsEmpty() { + return NoConnections + } + resStrings := []string{} + // get cubes and cube str per each cube + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + resStrings = append(resStrings, getConnsCubeStr(cube)) + } + + sort.Strings(resStrings) + return strings.Join(resStrings, "; ") +} + +func getCubeAsTCPItems(cube []*intervals.CanonicalIntervalSet, protocol netp.TransportLayerProtocolName) []netp.Protocol { + tcpItemsTemp := []netp.Protocol{} + // consider src ports + srcPorts := cube[srcPort] + if srcPorts.Equal(*entireDimension(srcPort)) { + tcpItemsTemp = append(tcpItemsTemp, netp.TCPUDP{Protocol: protocol}) + } else { + // iterate the intervals in the interval-set + for _, interval := range srcPorts.IntervalSet { + tcpRes := netp.TCPUDP{ + Protocol: protocol, + PortRangePair: netp.PortRangePair{ + SrcPort: netp.PortRange{Min: int(interval.Start), Max: int(interval.End)}, + }, + } + tcpItemsTemp = append(tcpItemsTemp, tcpRes) + } + } + // consider dst ports + dstPorts := cube[dstPort] + if dstPorts.Equal(*entireDimension(dstPort)) { + return tcpItemsTemp + } + tcpItemsFinal := []netp.Protocol{} + for _, interval := range dstPorts.IntervalSet { + for _, tcpItemTemp := range tcpItemsTemp { + item, _ := tcpItemTemp.(netp.TCPUDP) + tcpItemsFinal = append(tcpItemsFinal, netp.TCPUDP{ + Protocol: protocol, + PortRangePair: netp.PortRangePair{ + SrcPort: item.PortRangePair.SrcPort, + DstPort: netp.PortRange{Min: int(interval.Start), Max: int(interval.End)}, + }, + }) + } + } + return tcpItemsFinal +} + +func getCubeAsICMPItems(cube []*intervals.CanonicalIntervalSet) []netp.Protocol { + icmpTypes := cube[icmpType] + icmpCodes := cube[icmpCode] + if icmpCodes.Equal(*entireDimension(icmpCode)) { + if icmpTypes.Equal(*entireDimension(icmpType)) { + return []netp.Protocol{netp.ICMP{}} + } + res := []netp.Protocol{} + for _, t := range icmpTypes.Elements() { + res = append(res, netp.ICMP{ICMPCodeType: &netp.ICMPCodeType{Type: t}}) + } + return res + } + + // iterate both codes and types + res := []netp.Protocol{} + for _, t := range icmpTypes.Elements() { + codes := icmpCodes.Elements() + for i := range codes { + c := codes[i] + if netp.ValidateICMP(t, c) == nil { + res = append(res, netp.ICMP{ICMPCodeType: &netp.ICMPCodeType{Type: t, Code: &c}}) + } + } + } + return res +} + +type ConnDetails []netp.Protocol + +func ConnToJSONRep(c *ConnectionSet) ConnDetails { + if c == nil { + return nil // one of the connections in connectionDiff can be empty + } + if c.AllowAll { + return []netp.Protocol{} + } + var res []netp.Protocol + + cubes := c.connectionProperties.GetCubesList() + for _, cube := range cubes { + protocols := cube[protocol] + if protocols.Contains(TCPCode) { + res = append(res, getCubeAsTCPItems(cube, netp.TCP)...) + } + if protocols.Contains(UDPCode) { + res = append(res, getCubeAsTCPItems(cube, netp.UDP)...) + } + if protocols.Contains(ICMPCode) { + res = append(res, getCubeAsICMPItems(cube)...) + } + } + + return res +} + +// NewTCPConnectionSet returns a ConnectionSet object with TCPCode protocol (all ports) +func NewTCPConnectionSet() *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) + return res +} diff --git a/pkg/connectionset/connectionset_test.go b/pkg/connectionset/connectionset_test.go new file mode 100644 index 0000000..9154679 --- /dev/null +++ b/pkg/connectionset/connectionset_test.go @@ -0,0 +1,52 @@ +package connectionset + +import ( + "fmt" + "testing" + + "github.com/np-guard/models/pkg/netp" +) + +// TODO: Add test assertions +func TestBasicConnectionSet(t *testing.T) { + c := NewConnectionSet(false) + fmt.Println(c.String()) + c.AddICMPConnection(7, 7, 5, 5) + fmt.Println(c.String()) + + d := NewConnectionSet(true) + fmt.Println(d.String()) + e := NewConnectionSet(false) + e.AddTCPorUDPConn(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + d = d.Subtract(e) + fmt.Println(d.String()) + d = d.Union(e) + fmt.Println(d.String()) + + fmt.Println("done") +} + +func TestBasicConnectionSet2(t *testing.T) { + c := NewConnectionSet(false) + c.AddICMPConnection(7, 7, 5, 5) + d := NewConnectionSet(true) + e := NewConnectionSet(false) + e.AddTCPorUDPConn(netp.ProtocolStringTCP, 1, 65535, 1, 65535) + d = d.Subtract(e) + d = d.Subtract(c) + fmt.Println(d.String()) + + fmt.Println("done") +} + +func TestBasicConnectionSet3(t *testing.T) { + c := NewConnectionSet(false) + c.AddICMPConnection(7, 7, 5, 5) + d := NewConnectionSet(true) + d = d.Subtract(c) + d.AddICMPConnection(7, 7, 5, 5) + + fmt.Println(d.String()) + + fmt.Println("done") +} diff --git a/pkg/connectionset/statefulness.go b/pkg/connectionset/statefulness.go new file mode 100644 index 0000000..95c4328 --- /dev/null +++ b/pkg/connectionset/statefulness.go @@ -0,0 +1,101 @@ +package connectionset + +import ( + "github.com/np-guard/models/pkg/hypercubes" + "github.com/np-guard/models/pkg/intervals" + "github.com/np-guard/models/pkg/netp" +) + +const ( + // StatefulUnknown is the default value for a ConnectionSet object, + StatefulUnknown int = iota + // StatefulTrue represents a connection object for which any allowed TCP (on all allowed src/dst ports) + // has an allowed response connection + StatefulTrue + // StatefulFalse represents a connection object for which there exists some allowed TCP + // (on any allowed subset from the allowed src/dst ports) that does not have an allowed response connection + StatefulFalse +) + +// EnhancedString returns a connection string with possibly added asterisk for stateless connection +func (conn *ConnectionSet) EnhancedString() string { + if conn.IsStateful == StatefulFalse { + return conn.String() + " *" + } + return conn.String() +} + +// ConnectionWithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. +// `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` +// which is not stateful (such that the response direction for this subset is not enabled). +// This function also returns a connection object with the exact subset of the stateful part (within TCP) +// from the entire connection `conn`, and with the original connections on other protocols. +func (conn *ConnectionSet) ConnectionWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + connTCP := conn.tcpConn() + if connTCP.IsEmpty() { + conn.IsStateful = StatefulTrue + return conn + } + secondDirectionConnTCP := secondDirectionConn.tcpConn() + statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConnTCP) + conn.IsStateful = connTCP.IsStateful + nonTCP := conn.Subtract(connTCP) + return nonTCP.Union(statefulCombinedConnTCP) +} + +// connTCPWithStatefulness assumes that both `conn` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `conn`, and returns the subset from `conn` which is stateful. +func (conn *ConnectionSet) connTCPWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + secondDirectionSwitchPortsDirection := secondDirectionConn.switchSrcDstPortsOnTCP() + // flip src/dst ports before intersection + statefulCombinedConn := conn.Intersection(secondDirectionSwitchPortsDirection) + if !conn.Equal(statefulCombinedConn) { + conn.IsStateful = StatefulFalse + } else { + conn.IsStateful = StatefulTrue + } + return statefulCombinedConn +} + +// tcpConn returns a new ConnectionSet object, which is the intersection of `conn` with TCP +func (conn *ConnectionSet) tcpConn() *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(netp.ProtocolStringTCP, MinPort, MaxPort, MinPort, MaxPort) + return conn.Intersection(res) +} + +// switchSrcDstPortsOnTCP returns a new ConnectionSet object, built from the input ConnectionSet object. +// It assumes the input connection object is only within TCP protocol. +// For TCP the src and dst ports on relevant cubes are being switched. +func (conn *ConnectionSet) switchSrcDstPortsOnTCP() *ConnectionSet { + if conn.AllowAll || conn.IsEmpty() { + return conn.Copy() + } + res := NewConnectionSet(false) + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + // assuming cube[protocol] contains TCP only + srcPorts := cube[srcPort] + dstPorts := cube[dstPort] + // if the entire domain is enabled by both src and dst no need to switch + if !srcPorts.Equal(*entireDimension(srcPort)) || !dstPorts.Equal(*entireDimension(dstPort)) { + newCube := copyCube(cube) + newCube[srcPort], newCube[dstPort] = newCube[dstPort], newCube[srcPort] + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(newCube)) + } else { + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(cube)) + } + } + return res +} + +// copyCube returns a new slice of intervals copied from input cube +func copyCube(cube []*intervals.CanonicalIntervalSet) []*intervals.CanonicalIntervalSet { + newCube := make([]*intervals.CanonicalIntervalSet, len(cube)) + for i, interval := range cube { + newInterval := interval.Copy() + newCube[i] = &newInterval + } + return newCube +} diff --git a/pkg/connectionset/statefulness_test.go b/pkg/connectionset/statefulness_test.go new file mode 100644 index 0000000..ae935df --- /dev/null +++ b/pkg/connectionset/statefulness_test.go @@ -0,0 +1,149 @@ +package connectionset + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/np-guard/models/pkg/netp" +) + +func newTCPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddTCPorUDPConn(netp.ProtocolStringTCP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newUDPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddTCPorUDPConn(netp.ProtocolStringUDP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newICMPconn(t *testing.T) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddICMPConnection(MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + return res +} + +func allButTCP(t *testing.T) *ConnectionSet { + t.Helper() + res := NewConnectionSet(true) + tcpOnly := res.tcpConn() + return res.Subtract(tcpOnly) +} + +type statefulnessTest struct { + name string + srcToDst *ConnectionSet + dstToSrc *ConnectionSet + // expectedIsStateful represents the expected IsStateful computed value for srcToDst, + // which should be either StatefulTrue or StatefulFalse, given the input dstToSrc connection. + // the computation applies only to the TCP protocol within those connections. + expectedIsStateful int + // expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection, + // the stateless part for TCP is srcToDst.Subtract(statefulConn) + expectedStatefulConn *ConnectionSet +} + +func (tt statefulnessTest) runTest(t *testing.T) { + t.Helper() + statefulConn := tt.srcToDst.ConnectionWithStatefulness(tt.dstToSrc) + require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) + require.True(t, tt.expectedStatefulConn.Equal(statefulConn)) +} + +func TestAll(t *testing.T) { + var testCasesStatefulness = []statefulnessTest{ + { + name: "tcp_all_ports_on_both_directions", + srcToDst: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + dstToSrc: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + }, + { + name: "first_all_cons_second_tcp_with_ports", + srcToDst: NewConnectionSet(true), // all connections + dstToSrc: newTCPConn(t, 80, 80, MinPort, MaxPort), // TCP , src-ports: 80, dst-ports: all + + // there is a subset of the tcp connection which is not stateful + expectedIsStateful: StatefulFalse, + + // TCP src-ports: all, dst-port: 80 , union: all non-TCP conns + expectedStatefulConn: allButTCP(t).Union(newTCPConn(t, MinPort, MaxPort, 80, 80)), + }, + { + name: "first_all_conns_second_no_tcp", + srcToDst: NewConnectionSet(true), // all connections + dstToSrc: newICMPconn(t), // ICMP + expectedIsStateful: StatefulFalse, + expectedStatefulConn: allButTCP(t), // UDP, ICMP (all TCP is considered stateless here) + }, + { + name: "tcp_with_ports_both_directions_exact_match", + srcToDst: newTCPConn(t, 80, 80, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_partial_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "udp_and_tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443).Union(newUDPConn(t, 80, 100, 443, 443)), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "no_tcp_in_first_direction", + srcToDst: newUDPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "empty_conn_in_first_direction", + srcToDst: NewConnectionSet(false), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "only_udp_icmp_in_first_direction_and_empty_second_direction", + srcToDst: newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn(t)), + dstToSrc: NewConnectionSet(false), + // stateful analysis does not apply to udp/icmp, thus considered in the result as "stateful" + // (to avoid marking it as stateless in the output) + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn(t)), + }, + } + t.Parallel() + // explainTests is the list of tests to run + for testIdx := range testCasesStatefulness { + tt := testCasesStatefulness[testIdx] + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.runTest(t) + }) + } + fmt.Println("done") +} diff --git a/pkg/intervals/intervalset.go b/pkg/intervals/intervalset.go index 400a88d..71b49c5 100644 --- a/pkg/intervals/intervalset.go +++ b/pkg/intervals/intervalset.go @@ -254,6 +254,16 @@ func (c *CanonicalIntervalSet) IsSingleNumber() bool { return false } +func (c *CanonicalIntervalSet) Elements() []int { + res := []int{} + for _, interval := range c.IntervalSet { + for i := interval.Start; i <= interval.End; i++ { + res = append(res, int(i)) + } + } + return res +} + func CreateFromInterval(start, end int64) *CanonicalIntervalSet { return &CanonicalIntervalSet{IntervalSet: []Interval{{Start: start, End: end}}} } diff --git a/pkg/netp/common.go b/pkg/netp/common.go new file mode 100644 index 0000000..bed6427 --- /dev/null +++ b/pkg/netp/common.go @@ -0,0 +1,18 @@ +package netp + +type Protocol interface { + // InverseDirection returns the response expected for a request made using this protocol + InverseDirection() Protocol +} + +type AnyProtocol struct{} + +func (t AnyProtocol) InverseDirection() Protocol { return AnyProtocol{} } + +type ProtocolStr string + +const ( + ProtocolStringTCP ProtocolStr = "TCP" + ProtocolStringUDP ProtocolStr = "UDP" + ProtocolStringICMP ProtocolStr = "ICMP" +) diff --git a/pkg/netp/icmp.go b/pkg/netp/icmp.go new file mode 100644 index 0000000..528ce52 --- /dev/null +++ b/pkg/netp/icmp.go @@ -0,0 +1,99 @@ +package netp + +import ( + "fmt" + "log" +) + +type ICMPCodeType struct { + // ICMP type allowed. + Type int + + // ICMP code allowed. If omitted, any code is allowed + Code *int +} + +type ICMP struct { + *ICMPCodeType +} + +func (t ICMP) InverseDirection() Protocol { + if t.ICMPCodeType == nil { + return nil + } + + if invType := inverseICMPType(t.Type); invType != undefinedICMP { + return ICMP{ICMPCodeType: &ICMPCodeType{Type: invType, Code: t.Code}} + } + return nil +} + +// Based on https://datatracker.ietf.org/doc/html/rfc792 + +const ( + EchoReply = 0 + DestinationUnreachable = 3 + SourceQuench = 4 + Redirect = 5 + Echo = 8 + TimeExceeded = 11 + ParameterProblem = 12 + Timestamp = 13 + TimestampReply = 14 + InformationRequest = 15 + InformationReply = 16 + + undefinedICMP = -2 +) + +// inverseICMPType returns the reply type for request type and vice versa. +// When there is no inverse, returns undefinedICMP +func inverseICMPType(t int) int { + switch t { + case Echo: + return EchoReply + case EchoReply: + return Echo + + case Timestamp: + return TimestampReply + case TimestampReply: + return Timestamp + + case InformationRequest: + return InformationReply + case InformationReply: + return InformationRequest + + case DestinationUnreachable, SourceQuench, Redirect, TimeExceeded, ParameterProblem: + return undefinedICMP + default: + log.Panicf("Impossible ICMP type: %v", t) + } + return undefinedICMP +} + +//nolint:revive // magic numbers are fine here +func ValidateICMP(t, c int) error { + maxCodes := map[int]int{ + EchoReply: 0, + DestinationUnreachable: 5, + SourceQuench: 0, + Redirect: 3, + Echo: 0, + TimeExceeded: 1, + ParameterProblem: 0, + Timestamp: 0, + TimestampReply: 0, + InformationRequest: 0, + InformationReply: 0, + } + maxCode, ok := maxCodes[t] + if !ok { + return fmt.Errorf("invalid ICMP type %v", t) + } + if c > maxCode { + return fmt.Errorf("ICMP code %v is invalid for ICMP type %v", c, t) + } + return nil +} diff --git a/pkg/netp/tcpudp.go b/pkg/netp/tcpudp.go new file mode 100644 index 0000000..8777e41 --- /dev/null +++ b/pkg/netp/tcpudp.go @@ -0,0 +1,46 @@ +package netp + +import "log" + +type TransportLayerProtocolName string + +const ( + TCP TransportLayerProtocolName = "TCP" + UDP TransportLayerProtocolName = "UDP" +) + +const DefaultMinPort = 1 +const DefaultMaxPort = 65535 + +type PortRange struct { + // Minimal port; default is DefaultMinPort + Min int + + // Maximal port; default is DefaultMaxPort + Max int +} + +type PortRangePair struct { + SrcPort PortRange + DstPort PortRange +} + +type TCPUDP struct { + Protocol TransportLayerProtocolName + PortRangePair PortRangePair +} + +func (t TCPUDP) InverseDirection() Protocol { + switch t.Protocol { + case TCP: + return TCPUDP{ + Protocol: TCP, + PortRangePair: PortRangePair{SrcPort: t.PortRangePair.DstPort, DstPort: t.PortRangePair.SrcPort}, + } + case UDP: + return nil + default: + log.Panicf("Impossible protocol: %v", t.Protocol) + } + return nil +}