diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index d0dab19..1e4c4d0 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -50,12 +50,14 @@ func NewUDPSet() *Set { } // ICMPConnection returns a set of connections containing the ICMP protocol with specified type,code values -func ICMPConnection(icmpType, icmpCode *int64) (*Set, error) { - icmp, err := netp.ICMPFromTypeAndCode64(icmpType, icmpCode) - if err != nil { - return nil, err - } - return netset.NewICMPTransport(icmp), nil +func ICMPConnection(icmpType, icmpCode int64) *Set { + return netset.NewICMPTransport(icmpType, icmpType, icmpCode, icmpCode) +} + +// ICMPConnectionTypeCodeRanges returns a set of connections containing the ICMP +// protocol with specified type,code ranges values +func ICMPConnectionTypeCodeRanges(minIcmpType, maxICMPType, minCode, maxCode int64) *Set { + return netset.NewICMPTransport(minIcmpType, maxICMPType, minCode, maxCode) } // All returns a set of all protocols (TCP,UPD,ICMP) in the set (with all possible properties values) diff --git a/pkg/connection/json.go b/pkg/connection/json.go index 1cdb2bd..ff253ff 100644 --- a/pkg/connection/json.go +++ b/pkg/connection/json.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0 package connection import ( + "github.com/np-guard/models/pkg/interval" "github.com/np-guard/models/pkg/netp" "github.com/np-guard/models/pkg/netset" "github.com/np-guard/models/pkg/spec" @@ -52,6 +53,40 @@ func getCubeAsTCPItems(srcPorts, dstPorts *netset.PortSet, p int64) []spec.TcpUd type Details spec.ProtocolList +func getCubeAsICMPItems(typesSet, codesSet *interval.CanonicalSet) []spec.Icmp { + allTypes := typesSet.Equal(netset.AllICMPTypes()) + allCodes := codesSet.Equal(netset.AllICMPCodes()) + switch { + case allTypes && allCodes: + return []spec.Icmp{{Protocol: spec.IcmpProtocolICMP}} + case allTypes: + res := []spec.Icmp{} + for _, code64 := range codesSet.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Code: &code}) + } + return res + case allCodes: + res := []spec.Icmp{} + for _, type64 := range typesSet.Elements() { + t := int(type64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t}) + } + return res + default: + res := []spec.Icmp{} + // iterate both codes and types + for _, type64 := range typesSet.Elements() { + t := int(type64) + for _, code64 := range codesSet.Elements() { + code := int(code64) + res = append(res, spec.Icmp{Protocol: spec.IcmpProtocolICMP, Type: &t, Code: &code}) + } + } + return res + } +} + // ToJSON returns a `Details` object for JSON representation of the input connection Set. func ToJSON(c *Set) Details { if c == nil { @@ -72,19 +107,10 @@ func ToJSON(c *Set) Details { } } for _, item := range c.ICMPSet().Partitions() { - if item.TypeCode != nil { - t := item.TypeCode.Type - res = append(res, spec.Icmp{ - Protocol: spec.IcmpProtocolICMP, - Type: &t, - Code: item.TypeCode.Code, - }) - } else { - res = append(res, spec.Icmp{ - Protocol: spec.IcmpProtocolICMP, - }) + icmpItems := getCubeAsICMPItems(item.Left, item.Right) + for _, item := range icmpItems { + res = append(res, item) } } - return Details(res) } diff --git a/pkg/netp/icmp.go b/pkg/netp/icmp.go index abdaa85..174a2ae 100644 --- a/pkg/netp/icmp.go +++ b/pkg/netp/icmp.go @@ -12,6 +12,14 @@ import ( "slices" ) +// general non-strict ICMP type, code ranges +const ( + MinICMPType int64 = 0 + MaxICMPType int64 = 254 + MinICMPCode int64 = 0 + MaxICMPCode int64 = 255 +) + type ICMPTypeCode struct { // ICMP type allowed. Type int diff --git a/pkg/netset/icmpset.go b/pkg/netset/icmpset.go index 31c5a1c..96b6995 100644 --- a/pkg/netset/icmpset.go +++ b/pkg/netset/icmpset.go @@ -8,249 +8,116 @@ package netset import ( "fmt" - "log" "sort" "strings" + "github.com/np-guard/models/pkg/ds" + "github.com/np-guard/models/pkg/interval" "github.com/np-guard/models/pkg/netp" ) -// ICMPSet is a set of ICMP values, encoded as a bitset -type ICMPSet uint32 +type TypeSet = interval.CanonicalSet +type CodeSet = interval.CanonicalSet -// Encoding for ICMP types and codes, enumerating the possible pairs of values. -// For example: -// * 0 is the pair (type=DestinationUnreachable, code=0). -// * 2 is the pair (type=DestinationUnreachable, code=2). -// * 7 is the pair (type=Redirect, code=1). -// The idea is to use a simple bitset for the set of _valid_ ICMP values. -const ( - encodedDestinationUnreachable = 0 - encodedRedirect = 6 - encodedTimeExceeded = 10 - encodedParameterProblem = 12 - encodedTimestamp = 13 - encodedTimestampReply = 14 - encodedInformationRequest = 15 - encodedInformationReply = 16 - encodedEcho = 17 - encodedEchoReply = 18 - encodedSourceQuench = 19 - last = 19 -) - -func encode(t, code int) int { - switch t { - case netp.DestinationUnreachable: - return encodedDestinationUnreachable + code - case netp.Redirect: - return encodedRedirect + code - case netp.TimeExceeded: - return encodedTimeExceeded + code - case netp.ParameterProblem: - return encodedParameterProblem - case netp.Timestamp: - return encodedTimestamp - case netp.TimestampReply: - return encodedTimestampReply - case netp.InformationRequest: - return encodedInformationRequest - case netp.InformationReply: - return encodedInformationReply - case netp.Echo: - return encodedEcho - case netp.EchoReply: - return encodedEchoReply - case netp.SourceQuench: - return encodedSourceQuench - default: - log.Panicf("Invalid ICMP type %v", t) - return t - } +type ICMPSet struct { + props ds.Product[*TypeSet, *CodeSet] } -//lint:ignore U1000 should be used in the future -func decode(encodedCode int) (netp.ICMP, error) { - t := encodedCode - switch { - case encodedCode < encodedRedirect: - t = encodedDestinationUnreachable - case encodedCode < encodedTimeExceeded: - t = encodedRedirect - case encodedCode < netp.ParameterProblem: - t = encodedTimeExceeded - case encodedCode == encodedEcho: - t = netp.Echo - case encodedCode == encodedEchoReply: - t = netp.EchoReply - case encodedCode == encodedSourceQuench: - t = netp.SourceQuench - } - code := encodedCode - t - return netp.NewICMP(&netp.ICMPTypeCode{Type: t, Code: &code}) +func (c *ICMPSet) Equal(other *ICMPSet) bool { + return c.props.Equal(other.props) } -func (s *ICMPSet) IsSubset(other *ICMPSet) bool { - return ((*s) | (*other)) == (*other) +func (c *ICMPSet) Hash() int { + return c.props.Hash() } -func (s *ICMPSet) Union(other *ICMPSet) *ICMPSet { - var res = (*s) | (*other) - return &res +func (c *ICMPSet) Copy() *ICMPSet { + return &ICMPSet{props: c.props.Copy()} } -func (s *ICMPSet) Intersect(other *ICMPSet) *ICMPSet { - var res = (*s) & (*other) - return &res +func (c *ICMPSet) Intersect(other *ICMPSet) *ICMPSet { + return &ICMPSet{props: c.props.Intersect(other.props)} } -func (s *ICMPSet) Subtract(other *ICMPSet) *ICMPSet { - var res = (*s) & ^(*other) - return &res +func (c *ICMPSet) Partitions() []ds.Pair[*TypeSet, *CodeSet] { + return c.props.Partitions() } -func (s *ICMPSet) Equal(other *ICMPSet) bool { - return *s == *other +func (c *ICMPSet) IsEmpty() bool { + return c.props.IsEmpty() } -func (s *ICMPSet) Copy() *ICMPSet { - var res = *s - return &res +func (c *ICMPSet) Union(other *ICMPSet) *ICMPSet { + return &ICMPSet{props: c.props.Union(other.props)} } -func (s *ICMPSet) Hash() int { - return int(*s) +func (c *ICMPSet) Size() int { + return c.props.Size() } -func (s *ICMPSet) Size() int { - res := 0 - for i := 0; i <= last; i++ { - if s.Contains(i) { - res++ - } - } - return res +// Subtract returns the subtraction of the other from c +func (c *ICMPSet) Subtract(other *ICMPSet) *ICMPSet { + return &ICMPSet{props: c.props.Subtract(other.props)} } -func (s *ICMPSet) IsEmpty() bool { - return s.Equal(EmptyICMPSet()) +// IsSubset returns true if c is subset of other +func (c *ICMPSet) IsSubset(other *ICMPSet) bool { + return c.props.IsSubset(other.props) } -func (s *ICMPSet) Contains(i int) bool { - return ((1 << i) & (*s)) != 0 +// icmpPropsPathLeft creates a new ICMPSet, implemented using CartesianPairLeft. +func icmpPropsPathLeft(typesSet *TypeSet, codeSet *CodeSet) *ICMPSet { + return &ICMPSet{props: ds.CartesianPairLeft(typesSet, codeSet)} } -// collect returns a list of ICMP values for a given type, collecting into a single ICMP value with nil Code if all codes are present. -func (s *ICMPSet) collect(old int) []netp.ICMP { - var res []netp.ICMP - for code := 0; code <= netp.MaxCode(old); code++ { - if s.Contains(encode(old, code)) { - icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: old, Code: &code}) - if err != nil { - log.Panicf("collection failed for type %v, code %v", old, &code) - } - res = append(res, icmp) - } - } - if len(res) == netp.MaxCode(old)+1 { - res = []netp.ICMP{{TypeCode: &netp.ICMPTypeCode{Type: old, Code: nil}}} - } - return res +func NewICMPSet(minType, maxType, minCode, maxCode int64) *ICMPSet { + return icmpPropsPathLeft( + interval.New(minType, maxType).ToSet(), + interval.New(minCode, maxCode).ToSet(), + ) } -// Partitions returns a list of ICMP values. -// if all codes for a given type are present, it adds a single ICMP value with nil Code. -// If all ICMP values are present, a single ICMP value with nil TypeCode is returned. -func (s *ICMPSet) Partitions() []netp.ICMP { - all := ICMPSet(allCodes) - if all.IsSubset(s) { - return []netp.ICMP{{TypeCode: nil}} - } - var res []netp.ICMP - for _, t := range netp.Types() { - res = append(res, s.collect(t)...) - } - return res +func EmptyICMPSet() *ICMPSet { + return &ICMPSet{props: ds.NewProductLeft[*TypeSet, *CodeSet]()} } -func fromIndex(i int) *ICMPSet { - var res ICMPSet = 1 << i - return &res +func AllICMPSet() *ICMPSet { + return icmpPropsPathLeft( + AllICMPTypes(), + AllICMPCodes(), + ) } -func (s *ICMPSet) IsAll() bool { - return s.Equal(AllICMPSet()) +func AllICMPCodes() *CodeSet { + return interval.New(netp.MinICMPCode, netp.MaxICMPCode).ToSet() } -// constants for sets of ICMP codes, grouped by types. -// For example, allDestinationUnreachable is the set of all ICMP codes for DestinationUnreachable type. -const ( - allDestinationUnreachable = 0b00000000000000111111 - allRedirect = 0b00000000001111000000 - allTimeExceeded = 0b00000000110000000000 - allOther = 0b11111111000000000000 - allCodes = allDestinationUnreachable | allRedirect | allTimeExceeded | allOther -) - -func EmptyICMPSet() *ICMPSet { - var res ICMPSet = 0 - return &res +func AllICMPTypes() *TypeSet { + return interval.New(netp.MinICMPType, netp.MaxICMPType).ToSet() } -func AllICMPSet() *ICMPSet { - res := ICMPSet(allCodes) - return &res -} +var allICMP = AllICMPSet() -func NewICMPSet(t netp.ICMP) *ICMPSet { - if t.TypeCode == nil { - return AllICMPSet() - } - if t.TypeCode.Code != nil { - return fromIndex(encode(t.TypeCode.Type, *t.TypeCode.Code)) - } - var res ICMPSet - switch t.TypeCode.Type { - case netp.DestinationUnreachable: - res = allDestinationUnreachable - case netp.Redirect: - res = allRedirect - case netp.TimeExceeded: - res = allTimeExceeded - default: - res = *fromIndex(encode(t.TypeCode.Type, 0)) - } - return &res +func (c *ICMPSet) IsAll() bool { + return c.Equal(allICMP) } -func getICMPCubeStr(cube netp.ICMP) string { - tc := cube.ICMPTypeCode() - if tc == nil { - return "" +func getICMPCubeStr(cube ds.Pair[*TypeSet, *CodeSet]) string { + if cube.Right.Equal(AllICMPCodes()) { + return fmt.Sprintf("ICMP type: %s", cube.Left.String()) } - if tc.Code == nil { - if netp.HasSingleCode(tc.Type) { - return fmt.Sprintf("icmp-type: %d icmp-code: 0", tc.Type) - } - return fmt.Sprintf("icmp-type: %d", tc.Type) - } - return fmt.Sprintf("icmp-type: %d icmp-code: %d", tc.Type, *tc.Code) + return fmt.Sprintf("ICMP type: %s code: %s", cube.Left.String(), cube.Right.String()) } -func (s *ICMPSet) String() string { - if s.IsEmpty() { - return "" +func (c *ICMPSet) String() string { + if c.IsAll() { + return string(netp.ProtocolStringICMP) } - cubes := s.Partitions() + cubes := c.Partitions() var resStrings = make([]string, len(cubes)) for i, cube := range cubes { resStrings[i] = getICMPCubeStr(cube) } sort.Strings(resStrings) - str := string(netp.ProtocolStringICMP) - last := strings.Join(resStrings, semicolon) - if last != "" { - str += " " + last - } - return str + return strings.Join(resStrings, " | ") } diff --git a/pkg/netset/icmpset_test.go b/pkg/netset/icmpset_test.go index e5c577e..ab9c776 100644 --- a/pkg/netset/icmpset_test.go +++ b/pkg/netset/icmpset_test.go @@ -16,13 +16,13 @@ import ( "github.com/np-guard/models/pkg/netset" ) -func TestBasicICMPSet(t *testing.T) { +func TestBasicICMPSetStrict(t *testing.T) { // create ICMPSet objects i1 := 8 - all := netset.AllICMPSet() + all := netset.AllICMPSetStrict() obj1, err := netp.ICMPFromTypeAndCode(&i1, nil) require.Nil(t, err) - icmpset := netset.NewICMPSet(obj1) + icmpset := netset.NewICMPSetStrict(obj1) // test basic functions, operations fmt.Println(icmpset) // ICMP icmp-type: 8 icmp-code: 0 @@ -34,3 +34,22 @@ func TestBasicICMPSet(t *testing.T) { require.True(t, icmpset.IsSubset(all)) fmt.Println("done") } + +func TestBasicICMPSet(t *testing.T) { + icmpset := netset.NewICMPSet(8, 8, 0, 255) // ICMP type: 8 + icmpset1 := netset.NewICMPSet(8, 8, 0, 0) // ICMP type: 8 code: 0 + fmt.Println(icmpset) + fmt.Println(icmpset1) + + require.True(t, icmpset1.IsSubset(icmpset)) + require.False(t, icmpset.IsSubset(icmpset1)) + require.True(t, icmpset1.Union(icmpset).Equal(icmpset)) + + require.False(t, icmpset.IsAll()) + require.False(t, icmpset.IsEmpty()) + + require.False(t, icmpset1.IsAll()) + require.False(t, icmpset1.IsEmpty()) + + fmt.Println("done") +} diff --git a/pkg/netset/rfcicmpset.go b/pkg/netset/rfcicmpset.go new file mode 100644 index 0000000..f3c08a8 --- /dev/null +++ b/pkg/netset/rfcicmpset.go @@ -0,0 +1,256 @@ +/* +Copyright 2023- IBM Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package netset + +import ( + "fmt" + "log" + "sort" + "strings" + + "github.com/np-guard/models/pkg/netp" +) + +// RFCICMPSet is a set of _valid_ (by RFC) ICMP values, encoded as a bitset +type RFCICMPSet uint32 + +// Encoding for ICMP types and codes, enumerating the possible pairs of values. +// For example: +// * 0 is the pair (type=DestinationUnreachable, code=0). +// * 2 is the pair (type=DestinationUnreachable, code=2). +// * 7 is the pair (type=Redirect, code=1). +// The idea is to use a simple bitset for the set of _valid_ ICMP values. +const ( + encodedDestinationUnreachable = 0 + encodedRedirect = 6 + encodedTimeExceeded = 10 + encodedParameterProblem = 12 + encodedTimestamp = 13 + encodedTimestampReply = 14 + encodedInformationRequest = 15 + encodedInformationReply = 16 + encodedEcho = 17 + encodedEchoReply = 18 + encodedSourceQuench = 19 + last = 19 +) + +func encode(t, code int) int { + switch t { + case netp.DestinationUnreachable: + return encodedDestinationUnreachable + code + case netp.Redirect: + return encodedRedirect + code + case netp.TimeExceeded: + return encodedTimeExceeded + code + case netp.ParameterProblem: + return encodedParameterProblem + case netp.Timestamp: + return encodedTimestamp + case netp.TimestampReply: + return encodedTimestampReply + case netp.InformationRequest: + return encodedInformationRequest + case netp.InformationReply: + return encodedInformationReply + case netp.Echo: + return encodedEcho + case netp.EchoReply: + return encodedEchoReply + case netp.SourceQuench: + return encodedSourceQuench + default: + log.Panicf("Invalid ICMP type %v", t) + return t + } +} + +//lint:ignore U1000 should be used in the future +func decode(encodedCode int) (netp.ICMP, error) { + t := encodedCode + switch { + case encodedCode < encodedRedirect: + t = encodedDestinationUnreachable + case encodedCode < encodedTimeExceeded: + t = encodedRedirect + case encodedCode < netp.ParameterProblem: + t = encodedTimeExceeded + case encodedCode == encodedEcho: + t = netp.Echo + case encodedCode == encodedEchoReply: + t = netp.EchoReply + case encodedCode == encodedSourceQuench: + t = netp.SourceQuench + } + code := encodedCode - t + return netp.NewICMP(&netp.ICMPTypeCode{Type: t, Code: &code}) +} + +func (s *RFCICMPSet) IsSubset(other *RFCICMPSet) bool { + return ((*s) | (*other)) == (*other) +} + +func (s *RFCICMPSet) Union(other *RFCICMPSet) *RFCICMPSet { + var res = (*s) | (*other) + return &res +} + +func (s *RFCICMPSet) Intersect(other *RFCICMPSet) *RFCICMPSet { + var res = (*s) & (*other) + return &res +} + +func (s *RFCICMPSet) Subtract(other *RFCICMPSet) *RFCICMPSet { + var res = (*s) & ^(*other) + return &res +} + +func (s *RFCICMPSet) Equal(other *RFCICMPSet) bool { + return *s == *other +} + +func (s *RFCICMPSet) Copy() *RFCICMPSet { + var res = *s + return &res +} + +func (s *RFCICMPSet) Hash() int { + return int(*s) +} + +func (s *RFCICMPSet) Size() int { + res := 0 + for i := 0; i <= last; i++ { + if s.Contains(i) { + res++ + } + } + return res +} + +func (s *RFCICMPSet) IsEmpty() bool { + return s.Equal(EmptyICMPSetStrict()) +} + +func (s *RFCICMPSet) Contains(i int) bool { + return ((1 << i) & (*s)) != 0 +} + +// collect returns a list of ICMP values for a given type, collecting into a single ICMP value with nil Code if all codes are present. +func (s *RFCICMPSet) collect(old int) []netp.ICMP { + var res []netp.ICMP + for code := 0; code <= netp.MaxCode(old); code++ { + if s.Contains(encode(old, code)) { + icmp, err := netp.NewICMP(&netp.ICMPTypeCode{Type: old, Code: &code}) + if err != nil { + log.Panicf("collection failed for type %v, code %v", old, &code) + } + res = append(res, icmp) + } + } + if len(res) == netp.MaxCode(old)+1 { + res = []netp.ICMP{{TypeCode: &netp.ICMPTypeCode{Type: old, Code: nil}}} + } + return res +} + +// Partitions returns a list of ICMP values. +// if all codes for a given type are present, it adds a single ICMP value with nil Code. +// If all ICMP values are present, a single ICMP value with nil TypeCode is returned. +func (s *RFCICMPSet) Partitions() []netp.ICMP { + all := RFCICMPSet(allCodes) + if all.IsSubset(s) { + return []netp.ICMP{{TypeCode: nil}} + } + var res []netp.ICMP + for _, t := range netp.Types() { + res = append(res, s.collect(t)...) + } + return res +} + +func fromIndex(i int) *RFCICMPSet { + var res RFCICMPSet = 1 << i + return &res +} + +func (s *RFCICMPSet) IsAll() bool { + return s.Equal(AllICMPSetStrict()) +} + +// constants for sets of ICMP codes, grouped by types. +// For example, allDestinationUnreachable is the set of all ICMP codes for DestinationUnreachable type. +const ( + allDestinationUnreachable = 0b00000000000000111111 + allRedirect = 0b00000000001111000000 + allTimeExceeded = 0b00000000110000000000 + allOther = 0b11111111000000000000 + allCodes = allDestinationUnreachable | allRedirect | allTimeExceeded | allOther +) + +func EmptyICMPSetStrict() *RFCICMPSet { + var res RFCICMPSet = 0 + return &res +} + +func AllICMPSetStrict() *RFCICMPSet { + res := RFCICMPSet(allCodes) + return &res +} + +func NewICMPSetStrict(t netp.ICMP) *RFCICMPSet { + if t.TypeCode == nil { + return AllICMPSetStrict() + } + if t.TypeCode.Code != nil { + return fromIndex(encode(t.TypeCode.Type, *t.TypeCode.Code)) + } + var res RFCICMPSet + switch t.TypeCode.Type { + case netp.DestinationUnreachable: + res = allDestinationUnreachable + case netp.Redirect: + res = allRedirect + case netp.TimeExceeded: + res = allTimeExceeded + default: + res = *fromIndex(encode(t.TypeCode.Type, 0)) + } + return &res +} + +func getRFCICMPCubeStr(cube netp.ICMP) string { + tc := cube.ICMPTypeCode() + if tc == nil { + return "" + } + if tc.Code == nil { + if netp.HasSingleCode(tc.Type) { + return fmt.Sprintf("icmp-type: %d icmp-code: 0", tc.Type) + } + return fmt.Sprintf("icmp-type: %d", tc.Type) + } + return fmt.Sprintf("icmp-type: %d icmp-code: %d", tc.Type, *tc.Code) +} + +func (s *RFCICMPSet) String() string { + if s.IsEmpty() { + return "" + } + cubes := s.Partitions() + var resStrings = make([]string, len(cubes)) + for i, cube := range cubes { + resStrings[i] = getRFCICMPCubeStr(cube) + } + sort.Strings(resStrings) + str := string(netp.ProtocolStringICMP) + last := strings.Join(resStrings, semicolon) + if last != "" { + str += " " + last + } + return str +} diff --git a/pkg/netset/tcpudpset.go b/pkg/netset/tcpudpset.go index 3baea26..40ab37f 100644 --- a/pkg/netset/tcpudpset.go +++ b/pkg/netset/tcpudpset.go @@ -93,8 +93,8 @@ func (c *TCPUDPSet) IsSubset(other *TCPUDPSet) bool { return c.props.IsSubset(other.props) } -// pathLeft creates a new TCPUDPSet, implemented using LeftTriple. -func pathLeft(protocol *ProtocolSet, srcPort, dstPort *PortSet) *TCPUDPSet { +// tcpudpPathLeft creates a new TCPUDPSet, implemented using LeftTriple. +func tcpudpPathLeft(protocol *ProtocolSet, srcPort, dstPort *PortSet) *TCPUDPSet { return &TCPUDPSet{props: ds.CartesianLeftTriple(protocol, srcPort, dstPort)} } @@ -103,7 +103,7 @@ func EmptyTCPorUDPSet() *TCPUDPSet { } func AllTCPUDPSet() *TCPUDPSet { - return pathLeft( + return tcpudpPathLeft( AllTCPUDPProtocolSet(), AllPorts(), AllPorts(), @@ -111,7 +111,7 @@ func AllTCPUDPSet() *TCPUDPSet { } func NewAllTCPOnlySet() *TCPUDPSet { - return pathLeft( + return tcpudpPathLeft( interval.New(TCPCode, TCPCode).ToSet(), AllPorts(), AllPorts(), @@ -119,7 +119,7 @@ func NewAllTCPOnlySet() *TCPUDPSet { } func NewAllUDPOnlySet() *TCPUDPSet { - return pathLeft( + return tcpudpPathLeft( interval.New(UDPCode, UDPCode).ToSet(), AllPorts(), AllPorts(), @@ -128,7 +128,7 @@ func NewAllUDPOnlySet() *TCPUDPSet { func NewTCPorUDPSet(protocolString netp.ProtocolString, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *TCPUDPSet { protocol := protocolStringToCode(protocolString) - return pathLeft( + return tcpudpPathLeft( interval.New(protocol, protocol).ToSet(), interval.New(srcMinP, srcMaxP).ToSet(), interval.New(dstMinP, dstMaxP).ToSet(), diff --git a/pkg/netset/transportset.go b/pkg/netset/transportset.go index a5288e2..140a824 100644 --- a/pkg/netset/transportset.go +++ b/pkg/netset/transportset.go @@ -28,21 +28,13 @@ func NewTCPorUDPTransport(protocol netp.ProtocolString, srcMinP, srcMaxP, dstMin )} } -func NewICMPTransport(tc netp.ICMP) *TransportSet { +func NewICMPTransport(minType, maxType, minCode, maxCode int64) *TransportSet { return &TransportSet{ds.NewDisjoint( EmptyTCPorUDPSet(), - NewICMPSet(tc), + NewICMPSet(minType, maxType, minCode, maxCode), )} } -func NewICMPTransportFromTypeCode(icmpType, icmpCode int64) (*TransportSet, error) { - icmp, err := netp.ICMPFromTypeAndCode64(&icmpType, &icmpCode) - if err != nil { - return nil, err - } - return NewICMPTransport(icmp), nil -} - func AllOrNothingTransport(allTcpudp, allIcmp bool) *TransportSet { var tcpudp *TCPUDPSet var icmp *ICMPSet diff --git a/pkg/netset/transportset_test.go b/pkg/netset/transportset_test.go index 0193c37..9ef86ab 100644 --- a/pkg/netset/transportset_test.go +++ b/pkg/netset/transportset_test.go @@ -35,7 +35,8 @@ func TestAllConnectionsTransportSet(t *testing.T) { require.True(t, tcpudpPartitions[0].S2.Equal(netset.AllPorts())) require.True(t, tcpudpPartitions[0].S3.Equal(netset.AllPorts())) // all icmp - require.Nil(t, icmpPartitions[0].TypeCode) + require.True(t, icmpPartitions[0].Left.Equal(netset.AllICMPTypes())) + require.True(t, icmpPartitions[0].Right.Equal(netset.AllICMPCodes())) } func TestNoConnectionsTransportSet(t *testing.T) { @@ -50,10 +51,9 @@ func TestNoConnectionsTransportSet(t *testing.T) { } func TestBasicSetICMPTransportSet(t *testing.T) { - c, err := netset.NewICMPTransportFromTypeCode(ICMPValue, 5) - require.Nil(t, err) - fmt.Println(c) // ICMP icmp-type: 3 icmp-code: 5 - require.Equal(t, "ICMP icmp-type: 3 icmp-code: 5", c.String()) + c := netset.NewICMPTransport(ICMPValue, ICMPValue, 5, 5) + fmt.Println(c) // "ICMP type: 3 code: 5" + require.Equal(t, "ICMP type: 3 code: 5", c.String()) } func TestBasicSetTCPTransportSet(t *testing.T) { @@ -78,32 +78,30 @@ func TestBasicSetTCPTransportSet(t *testing.T) { } func TestBasicSet2TransportSet(t *testing.T) { - except1, err := netset.NewICMPTransportFromTypeCode(ICMPValue, 5) - require.Nil(t, err) + except1 := netset.NewICMPTransport(ICMPValue, ICMPValue, 5, 5) except2 := netset.NewTCPorUDPTransport(netp.ProtocolStringTCP, 1, 65535, 1, 65535) d := netset.AllTransportSet().Subtract(except1).Subtract(except2) - fmt.Println(d) - // ICMP icmp-type: 0 icmp-code: 0;icmp-type: 11;icmp-type: 12 icmp-code: 0;icmp-type: 13 icmp-code: 0; - // icmp-type: 14 icmp-code: 0;icmp-type: 15 icmp-code: 0;icmp-type: 16 icmp-code: 0;icmp-type: 3 icmp-code: 0; - // icmp-type: 3 icmp-code: 1;icmp-type: 3 icmp-code: 2;icmp-type: 3 icmp-code: 3;icmp-type: 3 icmp-code: 4; - // icmp-type: 4 icmp-code: 0;icmp-type: 5;icmp-type: 8 icmp-code: 0;UDP + fmt.Println(d) // ICMP type: 0-2,4-254 | ICMP type: 3 code: 0-4,6-255;UDP - require.Equal(t, 15, len(d.ICMPSet().Partitions())) + require.Equal(t, 2, len(d.ICMPSet().Partitions())) require.Equal(t, 1, len(d.TCPUDPSet().Partitions())) - fmt.Println("done") /* string from older version: "protocol: ICMP icmp-type: 0-2,4-16; "+ "protocol: ICMP icmp-type: 3 icmp-code: 0-4; "+ "protocol: UDP", d.String()) + + from icmp-strict version: + // ICMP icmp-type: 0 icmp-code: 0;icmp-type: 11;icmp-type: 12 icmp-code: 0;icmp-type: 13 icmp-code: 0; + // icmp-type: 14 icmp-code: 0;icmp-type: 15 icmp-code: 0;icmp-type: 16 icmp-code: 0;icmp-type: 3 icmp-code: 0; + // icmp-type: 3 icmp-code: 1;icmp-type: 3 icmp-code: 2;icmp-type: 3 icmp-code: 3;icmp-type: 3 icmp-code: 4; + // icmp-type: 4 icmp-code: 0;icmp-type: 5;icmp-type: 8 icmp-code: 0;UDP */ } func TestBasicSet3TransportSet(t *testing.T) { - c, err := netset.NewICMPTransportFromTypeCode(ICMPValue, 5) - c1, _ := netset.NewICMPTransportFromTypeCode(ICMPValue, 5) - require.Nil(t, err) - d := netset.AllTransportSet().Subtract(c).Union(c1) + c := netset.NewICMPTransport(ICMPValue, ICMPValue, 5, 5) + d := netset.AllTransportSet().Subtract(c).Union(netset.NewICMPTransport(ICMPValue, ICMPValue, 5, 5)) require.Equal(t, netset.AllConnections, d.String()) }