diff --git a/cmd/registration-server/main.go b/cmd/registration-server/main.go index dfe59731..eee17120 100644 --- a/cmd/registration-server/main.go +++ b/cmd/registration-server/main.go @@ -33,20 +33,25 @@ type regServer interface { // config defines the variables and options from the toml config file type config struct { - DNSListenAddr string `toml:"dns_listen_addr"` - Domain string `toml:"domain"` - DNSPrivkeyPath string `toml:"dns_private_key_path"` - APIPort uint16 `toml:"api_port"` - ZMQAuthVerbose bool `toml:"zmq_auth_verbose"` - ZMQAuthType string `toml:"zmq_auth_type"` - ZMQPort uint16 `toml:"zmq_port"` - ZMQBindAddr string `toml:"zmq_bind_addr"` - ZMQPrivateKeyPath string `toml:"zmq_privkey_path"` - StationPublicKeys []string `toml:"station_pubkeys"` - ClientConfPath string `toml:"clientconf_path"` - latestClientConf *pb.ClientConf - LogLevel string `toml:"log_level"` - LogMetricsInterval uint16 `toml:"log_metrics_interval"` + DNSListenAddr string `toml:"dns_listen_addr"` + Domain string `toml:"domain"` + DNSPrivkeyPath string `toml:"dns_private_key_path"` + APIPort uint16 `toml:"api_port"` + ZMQAuthVerbose bool `toml:"zmq_auth_verbose"` + ZMQAuthType string `toml:"zmq_auth_type"` + ZMQPort uint16 `toml:"zmq_port"` + ZMQBindAddr string `toml:"zmq_bind_addr"` + ZMQPrivateKeyPath string `toml:"zmq_privkey_path"` + StationPublicKeys []string `toml:"station_pubkeys"` + ClientConfPath string `toml:"clientconf_path"` + latestClientConf *pb.ClientConf + LogLevel string `toml:"log_level"` + LogMetricsInterval uint16 `toml:"log_metrics_interval"` + EnforceSubnetOverrides bool `toml:"enforce_subnet_overrides"` + PrcntMinRegsToOverride float64 `toml:"prcnt_min_regs_to_override"` + PrcntPrefixRegsToOverride float64 `toml:"prcnt_prefix_regs_to_override"` + OverrideSubnets []regprocessor.Subnet `toml:"override_subnet"` + ExclusionsFromOverride []regprocessor.Subnet `toml:"excluded_subnet_from_overrides"` } var defaultTransports = map[pb.TransportType]lib.Transport{ @@ -192,9 +197,9 @@ func main() { switch conf.ZMQAuthType { case "CURVE": - processor, err = regprocessor.NewRegProcessor(conf.ZMQBindAddr, conf.ZMQPort, zmqPrivkey, conf.ZMQAuthVerbose, conf.StationPublicKeys, metrics) + processor, err = regprocessor.NewRegProcessor(conf.ZMQBindAddr, conf.ZMQPort, zmqPrivkey, conf.ZMQAuthVerbose, conf.StationPublicKeys, metrics, conf.EnforceSubnetOverrides, conf.OverrideSubnets, conf.ExclusionsFromOverride, conf.PrcntMinRegsToOverride, conf.PrcntPrefixRegsToOverride) case "NULL": - processor, err = regprocessor.NewRegProcessorNoAuth(conf.ZMQBindAddr, conf.ZMQPort, metrics) + processor, err = regprocessor.NewRegProcessorNoAuth(conf.ZMQBindAddr, conf.ZMQPort, metrics, conf.EnforceSubnetOverrides, conf.OverrideSubnets, conf.ExclusionsFromOverride, conf.PrcntMinRegsToOverride, conf.PrcntPrefixRegsToOverride) default: log.Fatalf("Unknown ZMQ auth type: %s", conf.ZMQAuthType) } diff --git a/cmd/registration-server/reg_config.toml b/cmd/registration-server/reg_config.toml index bc470647..867bb82b 100644 --- a/cmd/registration-server/reg_config.toml +++ b/cmd/registration-server/reg_config.toml @@ -45,3 +45,32 @@ bidirectional_api_generation = 957 # Path on disk to the latest ClientConfig file that the station should use clientconf_path = "/var/lib/conjure/ClientConf" + +# Whether to apply the below subnet overrides to clients bidirectional api registrations +enforce_subnet_overrides = true + +# Percentage of bidirectional api registrations to override per transport +prcnt_min_regs_to_override = 100 +prcnt_prefix_regs_to_override = 100 + +# Subnets to use when overriding clients bidirectional api registrations +[[override_subnet]] +cidr = "X.X.X.X/32" +weight = 10.7 +port = 443 +transport = "Min_Transport" + +[[override_subnet]] +cidr = "X.X.X.X/24" +weight = 10 +port = 80 +transport = "Prefix_Transport" +prefix_id = 1 + +# Subnets to refrain from overriding when clients bidirectional api registrations pick a v4 phantom inside them +[[excluded_subnet_from_overrides]] +cidr = "X.X.X.X/25" +# For future features that can exclude subnets according to weight, port, or transport +weight = 28.7 +port = 80 +transport = "Min_Transport" diff --git a/pkg/regserver/regprocessor/auth_test.go b/pkg/regserver/regprocessor/auth_test.go index 49ccf7c3..49e605da 100644 --- a/pkg/regserver/regprocessor/auth_test.go +++ b/pkg/regserver/regprocessor/auth_test.go @@ -134,7 +134,7 @@ func TestZMQAuth(t *testing.T) { // messages that we expect the station to hear. in production this will be new registrations, // here we don't care about the message contents. go func() { - regProcessor, err := newRegProcessor(zmqBindAddr, zmqPort, []byte(zmq.Z85decode(serverPrivkeyZ85)), true, stationPublicKeys) + regProcessor, err := newRegProcessor(zmqBindAddr, zmqPort, []byte(zmq.Z85decode(serverPrivkeyZ85)), true, stationPublicKeys, false, nil, nil, 0.0, 0.0) require.Nil(t, err) defer regProcessor.Close() errStation := regProcessor.AddTransport(pb.TransportType_Min, min.Transport{}) diff --git a/pkg/regserver/regprocessor/regprocessor.go b/pkg/regserver/regprocessor/regprocessor.go index d3cbcee7..080d77ed 100644 --- a/pkg/regserver/regprocessor/regprocessor.go +++ b/pkg/regserver/regprocessor/regprocessor.go @@ -11,7 +11,12 @@ import ( "encoding/binary" "errors" "fmt" + "math" + "math/big" + mrand "math/rand" + "net" "sync" + "time" zmq "github.com/pebbe/zmq4" "github.com/refraction-networking/conjure/pkg/core" @@ -20,8 +25,11 @@ import ( "github.com/refraction-networking/conjure/pkg/phantoms" "github.com/refraction-networking/conjure/pkg/regserver/overrides" "github.com/refraction-networking/conjure/pkg/station/lib" + "github.com/refraction-networking/conjure/pkg/transports/wrapping/prefix" + pb "github.com/refraction-networking/conjure/proto" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) var ( @@ -74,10 +82,180 @@ type RegProcessor struct { regOverrides interfaces.Overrides transports map[pb.TransportType]lib.Transport + + enforceSubnetOverrides bool + minOverrideSubnets []Subnet + minOverrideSubnetsCumulativeWeights []float64 + prefixOverrideSubnetsCumulativeWeights []float64 + prefixOverrideSubnets []Subnet + exclusionsFromOverride []Subnet + prcntMinRegsToOverride float64 + prcntPrefixRegsToOverride float64 +} + +type Subnet struct { + CIDR Ipnet `toml:"cidr"` + Weight float64 `toml:"weight"` + Port uint32 `toml:"port"` + Transport string `toml:"transport"` + PrefixId prefix.PrefixID `toml:"prefix_id"` +} + +type Ipnet struct { + *net.IPNet +} + +// UnmarshalText makes CIDR compatible with TOML decoding +func (n *Ipnet) UnmarshalText(text []byte) error { + _, cidr, err := net.ParseCIDR(string(text)) + if err != nil { + return err + } + n.IPNet = cidr + return nil +} + +// helper function to convert IPv4 to uint32 +func ipv4ToUint32(ip net.IP) (uint32, error) { + err := errors.New("Provided IP is not IPv4") + if ip == nil { + return 0, err + } + + ip = ip.To4() + if ip == nil { + return 0, err + } + + return binary.BigEndian.Uint32(ip), nil +} + +// helper function to cenvert uint32 to IPv4 +func uint32ToIPv4(ip *uint32) net.IP { + if ip == nil { + return nil + } + + ipInt := *ip + return net.IPv4( + byte(ipInt>>24), + byte(ipInt>>16), + byte(ipInt>>8), + byte(ipInt), + ) +} + +// helper function that wraps randomInt() +func getRandUint32IPv4(ipNet *net.IPNet) (uint32, error) { + ipUint32, err := ipv4ToUint32(ipNet.IP) + if err != nil { + return 0, errors.New("Failed to convert IPv4 to uint32") + } + + mask := ipNet.Mask + ones, bits := mask.Size() + hosts := uint32(1 << uint32(bits-ones)) + + ip, err := randomInt(ipUint32, ipUint32+hosts) + if err != nil { + return 0, errors.New("Failed to get random IPv4 as uint32 from the given range") + } + return ip, nil +} + +// helper function to get random integers within a range +func randomInt(x, y uint32) (uint32, error) { + rangeSize := y - x + // Generate a random number in the range [0, rangeSize) + randomNum, err := rand.Int(rand.Reader, big.NewInt(int64(rangeSize))) + if err != nil { + return 0, err + } + // Return the random number in the range [x, y] + return x + uint32(randomNum.Int64()), nil +} + +// helper function to override the prefix in the registration response +func overridePrefix(newRegResp *pb.RegistrationResponse, prefixId prefix.PrefixID, dstPort uint32) error { + // Override Phantom dstPort + newRegResp.DstPort = proto.Uint32(dstPort) + // Override Prefix choice and PrefixParam + newPrefix, err := prefix.TryFromID(prefixId) + if err != nil { + return err + } + var fp = newPrefix.FlushPolicy() + var i int32 = int32(newPrefix.ID()) + newparams := &pb.PrefixTransportParams{} + newparams.PrefixId = &i + newparams.CustomFlushPolicy = &fp + newparams.Prefix = newPrefix.Bytes() + anypbParams, err := anypb.New(newparams) + if err != nil { + return err + } + newRegResp.TransportParams = anypbParams + return nil +} + +// helper function to validate override percentages for the Min and Prefix transports set by reg_config.toml +func validateOverridePercentages(prcntMinConnsToOverride float64, prcntPrefixConnsToOverride float64) (float64, float64) { + if prcntMinConnsToOverride > 100.0 || prcntMinConnsToOverride < 0.0 { + fmt.Println("prcnt_min_conns_to_override value in reg_config.toml is out of range [0,100]. Resetting to 50%") + prcntMinConnsToOverride = 50 * 10 + } else { + prcntMinConnsToOverride = math.Round(prcntMinConnsToOverride*100) / 10 + } + if prcntPrefixConnsToOverride > 100.0 || prcntPrefixConnsToOverride < 0.0 { + fmt.Println("prcnt_prefix_conns_to_override value in reg_config.toml is out of range [0,100]. Resetting to 50%") + prcntPrefixConnsToOverride = 50 * 10 + } else { + prcntPrefixConnsToOverride = math.Round(prcntPrefixConnsToOverride*100) / 10 + } + return prcntMinConnsToOverride, prcntPrefixConnsToOverride +} + +// shallow-copy the override subnets into different slices based on transport type. +// could be improved to handle different transports +func splitOverrideSubnets(overrideSubnets []Subnet) ([]Subnet, []Subnet) { + + var minOverrideSubnets []Subnet + var prefixOverrideSubnets []Subnet + for _, subnet := range overrideSubnets { + if subnet.Transport == "Min_Transport" { + minOverrideSubnets = append(minOverrideSubnets, subnet) + } else if subnet.Transport == "Prefix_Transport" { + prefixOverrideSubnets = append(prefixOverrideSubnets, subnet) + } + } + return minOverrideSubnets, prefixOverrideSubnets +} + +// calculate cumulative weights for a given subnets slice +func processOverrideSubnetsWeights(subnets []Subnet) []float64 { + + if len(subnets) == 0 { + return nil + } + + var totalWeight float64 + for _, subnet := range subnets { + totalWeight += subnet.Weight + } + + cumulativeWeights := make([]float64, len(subnets)) + for i, subnet := range subnets { + if i == 0 { + cumulativeWeights[i] = subnet.Weight / totalWeight + } else { + cumulativeWeights[i] = cumulativeWeights[i-1] + (subnet.Weight / totalWeight) + } + } + return cumulativeWeights } // NewRegProcessor initialize a new RegProcessor -func NewRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVerbose bool, stationPublicKeys []string, metrics *metrics.Metrics) (*RegProcessor, error) { +func NewRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVerbose bool, stationPublicKeys []string, metrics *metrics.Metrics, enforceSubnetOverrides bool, overrideSubnets []Subnet, exclusionsFromOverride []Subnet, prcntMinRegsToOverride float64, prcntPrefixRegsToOverride float64) (*RegProcessor, error) { if len(privkey) != ed25519.PrivateKeySize { // We require the 64 byte [private_key][public_key] format to Sign using crypto/ed25519 @@ -89,7 +267,7 @@ func NewRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVer return nil, err } - regProcessor, err := newRegProcessor(zmqBindAddr, zmqPort, privkey, authVerbose, stationPublicKeys) + regProcessor, err := newRegProcessor(zmqBindAddr, zmqPort, privkey, authVerbose, stationPublicKeys, enforceSubnetOverrides, overrideSubnets, exclusionsFromOverride, prcntMinRegsToOverride, prcntPrefixRegsToOverride) if err != nil { return nil, err } @@ -101,7 +279,7 @@ func NewRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVer // initializes the registration processor without the phantom selector which can be added by a // wrapping function before it is returned. This function is required for testing. -func newRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVerbose bool, stationPublicKeys []string) (*RegProcessor, error) { +func newRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVerbose bool, stationPublicKeys []string, enforceSubnetOverrides bool, overrideSubnets []Subnet, exclusionsFromOverride []Subnet, prcntMinRegsToOverride float64, prcntPrefixRegsToOverride float64) (*RegProcessor, error) { sock, err := zmq.NewSocket(zmq.PUB) if err != nil { return nil, fmt.Errorf("%w: %v", ErrZmqSocket, err) @@ -137,19 +315,37 @@ func newRegProcessor(zmqBindAddr string, zmqPort uint16, privkey []byte, authVer regOverrides = interfaces.Overrides([]interfaces.RegOverride{overrides.NewRandPrefixOverride()}) } - return &RegProcessor{ - zmqMutex: sync.Mutex{}, - selectorMutex: sync.RWMutex{}, - sock: sock, - transports: make(map[pb.TransportType]lib.Transport), - authenticated: true, - privkey: privkey, - regOverrides: regOverrides, - }, nil + prcntMinRegsToOverride, prcntPrefixRegsToOverride = validateOverridePercentages(prcntMinRegsToOverride, prcntPrefixRegsToOverride) + + minOverrideSubnets, prefixOverrideSubnets := splitOverrideSubnets(overrideSubnets) + + minOverrideSubnetsCumulativeWeights := processOverrideSubnetsWeights(minOverrideSubnets) + prefixOverrideSubnetsCumulativeWeights := processOverrideSubnetsWeights(prefixOverrideSubnets) + + rp := &RegProcessor{ + zmqMutex: sync.Mutex{}, + selectorMutex: sync.RWMutex{}, + sock: sock, + transports: make(map[pb.TransportType]lib.Transport), + authenticated: true, + privkey: privkey, + regOverrides: regOverrides, + enforceSubnetOverrides: enforceSubnetOverrides, + minOverrideSubnets: minOverrideSubnets, + prefixOverrideSubnets: prefixOverrideSubnets, + minOverrideSubnetsCumulativeWeights: minOverrideSubnetsCumulativeWeights, + prefixOverrideSubnetsCumulativeWeights: prefixOverrideSubnetsCumulativeWeights, + exclusionsFromOverride: make([]Subnet, len(exclusionsFromOverride)), + prcntMinRegsToOverride: prcntMinRegsToOverride, + prcntPrefixRegsToOverride: prcntPrefixRegsToOverride, + } + copy(rp.exclusionsFromOverride, exclusionsFromOverride) + + return rp, nil } // NewRegProcessorNoAuth creates a regprocessor without authentication to zmq address -func NewRegProcessorNoAuth(zmqBindAddr string, zmqPort uint16, metrics *metrics.Metrics) (*RegProcessor, error) { +func NewRegProcessorNoAuth(zmqBindAddr string, zmqPort uint16, metrics *metrics.Metrics, enforceSubnetOverrides bool, overrideSubnets []Subnet, exclusionsFromOverride []Subnet, prcntMinRegsToOverride float64, prcntPrefixRegsToOverride float64) (*RegProcessor, error) { sock, err := zmq.NewSocket(zmq.PUB) if err != nil { return nil, ErrZmqSocket @@ -165,15 +361,33 @@ func NewRegProcessorNoAuth(zmqBindAddr string, zmqPort uint16, metrics *metrics. return nil, err } - return &RegProcessor{ - zmqMutex: sync.Mutex{}, - selectorMutex: sync.RWMutex{}, - ipSelector: phantomSelector, - sock: sock, - metrics: metrics, - transports: make(map[pb.TransportType]lib.Transport), - authenticated: false, - }, nil + prcntMinRegsToOverride, prcntPrefixRegsToOverride = validateOverridePercentages(prcntMinRegsToOverride, prcntPrefixRegsToOverride) + + minOverrideSubnets, prefixOverrideSubnets := splitOverrideSubnets(overrideSubnets) + + minOverrideSubnetsCumulativeWeights := processOverrideSubnetsWeights(minOverrideSubnets) + prefixOverrideSubnetsCumulativeWeights := processOverrideSubnetsWeights(prefixOverrideSubnets) + + rp := &RegProcessor{ + zmqMutex: sync.Mutex{}, + selectorMutex: sync.RWMutex{}, + ipSelector: phantomSelector, + sock: sock, + metrics: metrics, + transports: make(map[pb.TransportType]lib.Transport), + authenticated: false, + enforceSubnetOverrides: enforceSubnetOverrides, + minOverrideSubnets: minOverrideSubnets, + prefixOverrideSubnets: prefixOverrideSubnets, + minOverrideSubnetsCumulativeWeights: minOverrideSubnetsCumulativeWeights, + prefixOverrideSubnetsCumulativeWeights: prefixOverrideSubnetsCumulativeWeights, + exclusionsFromOverride: make([]Subnet, len(exclusionsFromOverride)), + prcntMinRegsToOverride: prcntMinRegsToOverride, + prcntPrefixRegsToOverride: prcntPrefixRegsToOverride, + } + copy(rp.exclusionsFromOverride, exclusionsFromOverride) + + return rp, nil } // Close cleans up the (ZMQ) servers running in the background supporting registration. @@ -355,7 +569,111 @@ func (p *RegProcessor) processBdReq(c2sPayload *pb.C2SWrapper) (*pb.Registration } else { regResp.DstPort = proto.Uint32(443) } + if p.enforceSubnetOverrides { + ipv4FromRegResponse := uint32ToIPv4(regResp.Ipv4Addr) + for _, subnet := range p.exclusionsFromOverride { + // TODO: apply exclusions based on both transport and subnet + if subnet.CIDR.IPNet.Contains(ipv4FromRegResponse) { + // the IPv4 originally chosen by the client exists in a subnet we excluded from overrides + // so do not apply overrides + return regResp, nil + } + } + num, err := randomInt(0, 10000) + if err != nil { + // In case of an error, return the original regResp and + // do not apply overrides + return regResp, nil + } + + // random float64 between 0 and 999 + randNumFloat := float64(num) / 10.0 + + var ipNet *net.IPNet + var dstPortOverride uint32 + + // random float64 between 0 and 1 + mrand.Seed(time.Now().UnixNano()) + randVal := mrand.Float64() + + // ignore prior choices and begin experimental overrides for Min and Prefix transports only + if transportType == pb.TransportType_Min { + if randNumFloat < p.prcntMinRegsToOverride { + if p.minOverrideSubnets == nil { + // reg_conf.toml does not contain subnet overrides for Min transport + return regResp, nil + } + + for i, cumulativeWeight := range p.minOverrideSubnetsCumulativeWeights { + if randVal < cumulativeWeight { + ipNet = p.minOverrideSubnets[i].CIDR.IPNet + //dstPortOverride = p.minOverrideSubnets[i].Port + } + } + + if ipNet == nil { + // problem in choosing a weighted override subnet + // so do not apply overrides + return regResp, nil + } + + ip, err := getRandUint32IPv4(ipNet) + if err != nil { + // failed to get random IPv4 as uint32 from the given range. + // do not apply override and return the original regResp. + return regResp, nil + } + regResp.Ipv4Addr = proto.Uint32(ip) + } + } else if transportType == pb.TransportType_Prefix { + + // Override the Phantom IPv4 for clients with the Prefix transport + // and override the transport type only if c2s.GetDisableRegistrarOverrides() is false + if !c2s.GetDisableRegistrarOverrides() { + if randNumFloat < p.prcntPrefixRegsToOverride { + if p.prefixOverrideSubnets == nil { + // reg_conf.toml does not contain subnet overrides for Prefix transport + return regResp, nil + } + + //newRegResp := &pb.RegistrationResponse{} + var prefixid prefix.PrefixID + for i, cumulativeWeight := range p.prefixOverrideSubnetsCumulativeWeights { + if randVal < cumulativeWeight { + ipNet = p.prefixOverrideSubnets[i].CIDR.IPNet + dstPortOverride = p.prefixOverrideSubnets[i].Port + prefixid = p.prefixOverrideSubnets[i].PrefixId + } + } + + if ipNet == nil { + // problem in choosing a weighted override subnet + // so do not apply overrides + return regResp, nil + } + + ip, err := getRandUint32IPv4(ipNet) + if err != nil { + // failed to get random IPv4 as uint32 from the given range. + // do not apply override and return the original regResp. + return regResp, nil + } + + newRegResp := proto.Clone(regResp).(*pb.RegistrationResponse) + + err = overridePrefix(newRegResp, prefixid, dstPortOverride) + if err != nil { + return regResp, nil + } + newRegResp.Ipv4Addr = proto.Uint32(ip) + + regResp = newRegResp + c2sPayload.RegistrationResponse = regResp + } + } + } + } return regResp, nil }