diff --git a/x/configurl/config.go b/x/configurl/config.go index 0c3c0b97..e338c950 100644 --- a/x/configurl/config.go +++ b/x/configurl/config.go @@ -15,108 +15,84 @@ package configurl import ( + "context" "errors" "fmt" "net/url" - "strconv" "strings" - - "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" ) -// ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is -// extensible by registering wrappers for config subtypes. -type ConfigToDialer struct { - // Base StreamDialer to create direct stream connections. If you need direct stream connections, this must not be nil. - BaseStreamDialer transport.StreamDialer - // Base PacketDialer to create direct packet connections. If you need direct packet connections, this must not be nil. - BasePacketDialer transport.PacketDialer - sdBuilders map[string]NewStreamDialerFunc - pdBuilders map[string]NewPacketDialerFunc +// Config is a pre-parsed generic config created from pipe-separated URLs. +type Config struct { + URL url.URL + BaseConfig *Config } -// NewStreamDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed. -type NewStreamDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) - -// NewPacketDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed. -type NewPacketDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.PacketDialer, error) - -// NewDefaultConfigToDialer creates a [ConfigToDialer] with a set of default wrappers already registered. -func NewDefaultConfigToDialer() *ConfigToDialer { - p := new(ConfigToDialer) - p.BaseStreamDialer = &transport.TCPDialer{} - p.BasePacketDialer = &transport.UDPDialer{} - - // Please keep the list in alphabetical order. - p.RegisterStreamDialerType("do53", wrapStreamDialerWithDO53) - - p.RegisterStreamDialerType("doh", wrapStreamDialerWithDOH) - - p.RegisterStreamDialerType("override", wrapStreamDialerWithOverride) - p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride) - - p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5) - p.RegisterPacketDialerType("socks5", wrapPacketDialerWithSOCKS5) - - p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit) +// BuildFunc is a function that creates an instance of ObjectType given a [Config]. +type BuildFunc[ObjectType any] func(ctx context.Context, config *Config) (ObjectType, error) - p.RegisterStreamDialerType("ss", wrapStreamDialerWithShadowsocks) - p.RegisterPacketDialerType("ss", wrapPacketDialerWithShadowsocks) - - p.RegisterStreamDialerType("tls", wrapStreamDialerWithTLS) +// TypeRegistry registers config types. +type TypeRegistry[ObjectType any] interface { + RegisterType(subtype string, newInstance BuildFunc[ObjectType]) +} - p.RegisterStreamDialerType("tlsfrag", func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - lenStr := wrapConfig.Opaque - fixedLen, err := strconv.Atoi(lenStr) - if err != nil { - return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) - } - return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) - }) +// ExtensibleProvider creates instances of ObjectType in a way that can be extended via its [TypeRegistry] interface. +type ExtensibleProvider[ObjectType comparable] struct { + // Instance to return when config is nil. + BaseInstance ObjectType + builders map[string]BuildFunc[ObjectType] +} - p.RegisterStreamDialerType("ws", wrapStreamDialerWithWebSocket) - p.RegisterPacketDialerType("ws", wrapPacketDialerWithWebSocket) +var ( + _ BuildFunc[any] = (*ExtensibleProvider[any])(nil).NewInstance + _ TypeRegistry[any] = (*ExtensibleProvider[any])(nil) +) - return p +// NewExtensibleProvider creates an [ExtensibleProvider] with the given base instance. +func NewExtensibleProvider[ObjectType comparable](baseInstance ObjectType) ExtensibleProvider[ObjectType] { + return ExtensibleProvider[ObjectType]{ + BaseInstance: baseInstance, + builders: make(map[string]BuildFunc[ObjectType]), + } } -// RegisterStreamDialerType will register a wrapper for stream dialers under the given subtype. -func (p *ConfigToDialer) RegisterStreamDialerType(subtype string, newDialer NewStreamDialerFunc) error { - if p.sdBuilders == nil { - p.sdBuilders = make(map[string]NewStreamDialerFunc) +func (p *ExtensibleProvider[ObjectType]) ensureBuildersMap() map[string]BuildFunc[ObjectType] { + if p.builders == nil { + p.builders = make(map[string]BuildFunc[ObjectType]) } + return p.builders +} - if _, found := p.sdBuilders[subtype]; found { - return fmt.Errorf("config parser %v for StreamDialer added twice", subtype) - } - p.sdBuilders[subtype] = newDialer - return nil +// RegisterType will register a factory for the given subtype. +func (p *ExtensibleProvider[ObjectType]) RegisterType(subtype string, newInstance BuildFunc[ObjectType]) { + p.ensureBuildersMap()[subtype] = newInstance } -// RegisterPacketDialerType will register a wrapper for packet dialers under the given subtype. -func (p *ConfigToDialer) RegisterPacketDialerType(subtype string, newDialer NewPacketDialerFunc) error { - if p.pdBuilders == nil { - p.pdBuilders = make(map[string]NewPacketDialerFunc) +// NewInstance creates a new instance of ObjectType according to the config. +func (p *ExtensibleProvider[ObjectType]) NewInstance(ctx context.Context, config *Config) (ObjectType, error) { + var zero ObjectType + if config == nil { + if p.BaseInstance == zero { + return zero, errors.New("base instance is not configured") + } + return p.BaseInstance, nil } - if _, found := p.pdBuilders[subtype]; found { - return fmt.Errorf("config parser %v for StreamDialer added twice", subtype) + newInstance, ok := p.ensureBuildersMap()[config.URL.Scheme] + if !ok { + return zero, fmt.Errorf("config type '%v' is not registered", config.URL.Scheme) } - p.pdBuilders[subtype] = newDialer - return nil + return newInstance(ctx, config) } -func parseConfig(configText string) ([]*url.URL, error) { +// ParseConfig will parse a config given as a string and return the structured [Config]. +func ParseConfig(configText string) (*Config, error) { parts := strings.Split(strings.TrimSpace(configText), "|") if len(parts) == 1 && parts[0] == "" { - return []*url.URL{}, nil + return nil, nil } - urls := make([]*url.URL, 0, len(parts)) + + var config *Config = nil for _, part := range parts { part = strings.TrimSpace(part) if part == "" { @@ -130,143 +106,7 @@ func parseConfig(configText string) ([]*url.URL, error) { if err != nil { return nil, fmt.Errorf("part is not a valid URL: %w", err) } - urls = append(urls, url) - } - return urls, nil -} - -// NewStreamDialer creates a [Dialer] according to transportConfig, using dialer as the -// base [Dialer]. The given dialer must not be nil. -func (p *ConfigToDialer) NewStreamDialer(transportConfig string) (transport.StreamDialer, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - return p.newStreamDialer(parts) -} - -// NewPacketDialer creates a [Dialer] according to transportConfig, using dialer as the -// base [Dialer]. The given dialer must not be nil. -func (p *ConfigToDialer) NewPacketDialer(transportConfig string) (transport.PacketDialer, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - return p.newPacketDialer(parts) -} - -func (p *ConfigToDialer) newStreamDialer(configParts []*url.URL) (transport.StreamDialer, error) { - if len(configParts) == 0 { - if p.BaseStreamDialer == nil { - return nil, fmt.Errorf("base StreamDialer must not be nil") - } - return p.BaseStreamDialer, nil - } - thisURL := configParts[len(configParts)-1] - innerConfig := configParts[:len(configParts)-1] - newDialer, ok := p.sdBuilders[thisURL.Scheme] - if !ok { - return nil, fmt.Errorf("config scheme '%v' is not supported for Stream Dialers", thisURL.Scheme) - } - newSD := func() (transport.StreamDialer, error) { - return p.newStreamDialer(innerConfig) - } - newPD := func() (transport.PacketDialer, error) { - return p.newPacketDialer(innerConfig) - } - return newDialer(newSD, newPD, thisURL) -} - -func (p *ConfigToDialer) newPacketDialer(configParts []*url.URL) (transport.PacketDialer, error) { - if len(configParts) == 0 { - if p.BasePacketDialer == nil { - return nil, fmt.Errorf("base PacketDialer must not be nil") - } - return p.BasePacketDialer, nil - } - thisURL := configParts[len(configParts)-1] - innerConfig := configParts[:len(configParts)-1] - newDialer, ok := p.pdBuilders[thisURL.Scheme] - if !ok { - return nil, fmt.Errorf("config scheme '%v' is not supported for Packet Dialers", thisURL.Scheme) - } - newSD := func() (transport.StreamDialer, error) { - return p.newStreamDialer(innerConfig) - } - newPD := func() (transport.PacketDialer, error) { - return p.newPacketDialer(innerConfig) - } - return newDialer(newSD, newPD, thisURL) -} - -// NewpacketListener creates a new [transport.PacketListener] according to the given config, -// the config must contain only one "ss://" segment. -// TODO: make NewPacketListener configurable. -func NewPacketListener(transportConfig string) (transport.PacketListener, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - if len(parts) == 0 { - return nil, errors.New("config is required") - } - if len(parts) > 1 { - return nil, errors.New("multi-part config is not supported") - } - - url := parts[0] - // Please keep scheme list sorted. - switch strings.ToLower(url.Scheme) { - case "ss": - // TODO: support nested dialer, the last part must be "ss://" - return newShadowsocksPacketListenerFromURL(url) - default: - return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme) - } -} - -func SanitizeConfig(transportConfig string) (string, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return "", err - } - - // Do nothing if the config is empty - if len(parts) == 0 { - return "", nil - } - - // Iterate through each part - textParts := make([]string, len(parts)) - for i, u := range parts { - scheme := strings.ToLower(u.Scheme) - switch scheme { - case "ss": - textParts[i], err = sanitizeShadowsocksURL(u) - if err != nil { - return "", err - } - case "socks5": - textParts[i], err = sanitizeSocks5URL(u) - if err != nil { - return "", err - } - case "override", "split", "tls", "tlsfrag": - // No sanitization needed - textParts[i] = u.String() - default: - textParts[i] = scheme + "://UNKNOWN" - } - } - // Join the parts back into a string - return strings.Join(textParts, "|"), nil -} - -func sanitizeSocks5URL(u *url.URL) (string, error) { - const redactedPlaceholder = "REDACTED" - if u.User != nil { - u.User = url.User(redactedPlaceholder) - return u.String(), nil + config = &Config{URL: *url, BaseConfig: config} } - return u.String(), nil + return config, nil } diff --git a/x/configurl/dns.go b/x/configurl/dns.go index 205dc8fd..2e41ff59 100644 --- a/x/configurl/dns.go +++ b/x/configurl/dns.go @@ -27,16 +27,46 @@ import ( "golang.org/x/net/dns/dnsmessage" ) -func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - pd, err := innerPD() - if err != nil { - return nil, err - } - query := configURL.Opaque +func registerDO53StreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + if config == nil { + return nil, fmt.Errorf("emtpy do53 config") + } + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + resolver, err := newDO53Resolver(config.URL, sd, pd) + if err != nil { + return nil, err + } + return dns.NewStreamDialer(resolver, sd) + }) +} + +func registerDOHStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + if config == nil { + return nil, fmt.Errorf("emtpy doh config") + } + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + resolver, err := newDOHResolver(config.URL, sd) + if err != nil { + return nil, err + } + return dns.NewStreamDialer(resolver, sd) + }) +} + +func newDO53Resolver(config url.URL, sd transport.StreamDialer, pd transport.PacketDialer) (dns.Resolver, error) { + query := config.Opaque values, err := url.ParseQuery(query) if err != nil { return nil, err @@ -75,19 +105,15 @@ func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), in // See https://datatracker.ietf.org/doc/html/rfc1123#page-75. return tcpResolver.Query(ctx, q) }) - return dns.NewStreamDialer(resolver, sd) + return resolver, nil } -func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - query := configURL.Opaque +func newDOHResolver(config url.URL, sd transport.StreamDialer) (dns.Resolver, error) { + query := config.Opaque values, err := url.ParseQuery(query) if err != nil { return nil, err } - sd, err := innerSD() - if err != nil { - return nil, err - } var name, address string for key, values := range values { @@ -119,6 +145,5 @@ func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), inn port = "443" } dohURL := url.URL{Scheme: "https", Host: net.JoinHostPort(name, port), Path: "/dns-query"} - resolver := dns.NewHTTPSResolver(sd, address, dohURL.String()) - return dns.NewStreamDialer(resolver, sd) + return dns.NewHTTPSResolver(sd, address, dohURL.String()), nil } diff --git a/x/configurl/doc.go b/x/configurl/doc.go index 70cfd6b4..dfee54c6 100644 --- a/x/configurl/doc.go +++ b/x/configurl/doc.go @@ -13,11 +13,11 @@ // limitations under the License. /* -Package config provides convenience functions to create dialer objects based on a text config. +Package configurl provides convenience functions to create network objects based on a text config. This is experimental and mostly for illustrative purposes at this point. -Configurable transports simplifies the way you create and manage transports. -With the config package, you can use [NewPacketDialer] and [NewStreamDialer] to create dialers using a simple text string. +Configurable strategies simplifies the way you create and manage strategies. +With the configurl package, you can use [ProviderContainer.NewPacketDialer], [ProviderContainer.NewStreamDialer] and [ProviderContainer.NewPacketListener] to create objects using a simple text string. Key Benefits: @@ -129,19 +129,23 @@ DPI Evasion - To add packet splitting to a Shadowsocks server for enhanced DPI e split:2|ss://[USERINFO]@[HOST]:[PORT] -Defining custom transport - You can define your custom transport by implementing and registering the [NewStreamDialerFunc] and [NewPacketDialerFunc] functions: +Defining custom strategies - You can define your custom strategy by implementing and registering [BuildFunc[ObjectType]] functions: - // create new config parser - // p := new(ConfigToDialer) + // Create new config parser. + // p := configurl.NewProviderContainer // or - p := NewDefaultConfigToDialer() - // register your custom dialer - p.RegisterPacketDialerWrapper("custom", wrapStreamDialerWithCustom) - p.RegisterStreamDialerWrapper("custom", wrapPacketDialerWithCustom) - // then use it - dialer, err := p.NewStreamDialer(innerDialer, "custom://config") - -where wrapStreamDialerWithCustom and wrapPacketDialerWithCustom implement [NewPacketDialerFunc] and [NewStreamDialerFunc]. + p := configurl.NewDefaultProviders() + // Register your custom dialer. + p.StreamDialers.RegisterType("custom", func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + // Build logic + // ... + }) + p.PacketDialers.RegisterType("custom", func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + // Build logic + // ... + }) + // Then use it + dialer, err := p.NewStreamDialer(context.Background(), "custom://config") [Onion Routing]: https://en.wikipedia.org/wiki/Onion_routing */ diff --git a/x/configurl/module.go b/x/configurl/module.go new file mode 100644 index 00000000..83e14b89 --- /dev/null +++ b/x/configurl/module.go @@ -0,0 +1,153 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "net/url" + "strings" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// ProviderContainer contains providers for the creation of network objects based on a config. The config is +// extensible by registering providers for different config subtypes. +type ProviderContainer struct { + StreamDialers ExtensibleProvider[transport.StreamDialer] + PacketDialers ExtensibleProvider[transport.PacketDialer] + PacketListeners ExtensibleProvider[transport.PacketListener] +} + +// NewProviderContainer creates a [ProviderContainer] with the base instances properly initialized. +func NewProviderContainer() *ProviderContainer { + return &ProviderContainer{ + StreamDialers: NewExtensibleProvider[transport.StreamDialer](&transport.TCPDialer{}), + PacketDialers: NewExtensibleProvider[transport.PacketDialer](&transport.UDPDialer{}), + PacketListeners: NewExtensibleProvider[transport.PacketListener](&transport.UDPListener{}), + } +} + +// RegisterDefaultProviders registers a set of default providers with the providers in [ProviderContainer]. +func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { + // Please keep the list in alphabetical order. + registerDO53StreamDialer(&c.StreamDialers, "do53", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + registerDOHStreamDialer(&c.StreamDialers, "doh", c.StreamDialers.NewInstance) + + registerOverrideStreamDialer(&c.StreamDialers, "override", c.StreamDialers.NewInstance) + registerOverridePacketDialer(&c.PacketDialers, "override", c.PacketDialers.NewInstance) + + registerSOCKS5StreamDialer(&c.StreamDialers, "socks5", c.StreamDialers.NewInstance) + registerSOCKS5PacketDialer(&c.PacketDialers, "socks5", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + registerSOCKS5PacketListener(&c.PacketListeners, "socks5", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + + registerSplitStreamDialer(&c.StreamDialers, "split", c.StreamDialers.NewInstance) + + registerShadowsocksStreamDialer(&c.StreamDialers, "ss", c.StreamDialers.NewInstance) + registerShadowsocksPacketDialer(&c.PacketDialers, "ss", c.PacketDialers.NewInstance) + registerShadowsocksPacketListener(&c.PacketListeners, "ss", c.PacketDialers.NewInstance) + + registerTLSStreamDialer(&c.StreamDialers, "tls", c.StreamDialers.NewInstance) + + registerTLSFragStreamDialer(&c.StreamDialers, "tlsfrag", c.StreamDialers.NewInstance) + + registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance) + registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance) + + return c +} + +// NewDefaultProviders creates a [ProviderContainer] with a set of default providers already registered. +func NewDefaultProviders() *ProviderContainer { + return RegisterDefaultProviders(NewProviderContainer()) +} + +// NewStreamDialer creates a [transport.StreamDialer] according to the config text. +func (p *ProviderContainer) NewStreamDialer(ctx context.Context, configText string) (transport.StreamDialer, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.StreamDialers.NewInstance(ctx, config) +} + +// NewPacketDialer creates a [transport.PacketDialer] according to the config text. +func (p *ProviderContainer) NewPacketDialer(ctx context.Context, configText string) (transport.PacketDialer, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.PacketDialers.NewInstance(ctx, config) +} + +// NewPacketListner creates a [transport.PacketListener] according to the config text. +func (p *ProviderContainer) NewPacketListener(ctx context.Context, configText string) (transport.PacketListener, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.PacketListeners.NewInstance(ctx, config) +} + +// SanitizeConfig removes sensitive information from the given config so it can be safely be used in logging and debugging. +func SanitizeConfig(configStr string) (string, error) { + config, err := ParseConfig(configStr) + if err != nil { + return "", err + } + + // Do nothing if the config is empty + if config == nil { + return "", nil + } + + var sanitized string + for config != nil { + var part string + scheme := strings.ToLower(config.URL.Scheme) + switch scheme { + case "ss": + part, err = sanitizeShadowsocksURL(config.URL) + if err != nil { + return "", err + } + case "socks5": + part, err = sanitizeSOCKS5URL(&config.URL) + if err != nil { + return "", err + } + case "override", "split", "tls", "tlsfrag": + // No sanitization needed + part = config.URL.String() + default: + part = scheme + "://UNKNOWN" + } + if sanitized == "" { + sanitized = part + } else { + sanitized = part + "|" + sanitized + } + config = config.BaseConfig + } + return sanitized, nil +} + +func sanitizeSOCKS5URL(u *url.URL) (string, error) { + const redactedPlaceholder = "REDACTED" + if u.User != nil { + u.User = url.User(redactedPlaceholder) + return u.String(), nil + } + return u.String(), nil +} diff --git a/x/configurl/config_test.go b/x/configurl/module_test.go similarity index 100% rename from x/configurl/config_test.go rename to x/configurl/module_test.go diff --git a/x/configurl/override.go b/x/configurl/override.go index 3f34ea63..743aae31 100644 --- a/x/configurl/override.go +++ b/x/configurl/override.go @@ -24,7 +24,47 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" ) -func newOverrideFromURL(configURL *url.URL) (func(string) (string, error), error) { +func registerOverrideStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + override, err := newOverrideFromURL(config.URL) + if err != nil { + return nil, err + } + return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + addr, err := override(addr) + if err != nil { + return nil, err + } + return sd.DialStream(ctx, addr) + }), nil + }) +} + +func registerOverridePacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + override, err := newOverrideFromURL(config.URL) + if err != nil { + return nil, err + } + return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + addr, err := override(addr) + if err != nil { + return nil, err + } + return pd.DialPacket(ctx, addr) + }), nil + }) +} + +func newOverrideFromURL(configURL url.URL) (func(string) (string, error), error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -65,39 +105,3 @@ func newOverrideFromURL(configURL *url.URL) (func(string) (string, error), error return net.JoinHostPort(host, port), nil }, nil } - -func wrapStreamDialerWithOverride(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - override, err := newOverrideFromURL(configURL) - if err != nil { - return nil, err - } - return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { - addr, err := override(addr) - if err != nil { - return nil, err - } - return sd.DialStream(ctx, addr) - }), nil -} - -func wrapPacketDialerWithOverride(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - pd, err := innerPD() - if err != nil { - return nil, err - } - override, err := newOverrideFromURL(configURL) - if err != nil { - return nil, err - } - return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { - addr, err := override(addr) - if err != nil { - return nil, err - } - return pd.DialPacket(ctx, addr) - }), nil -} diff --git a/x/configurl/override_test.go b/x/configurl/override_test.go index 0069de91..6a7e122e 100644 --- a/x/configurl/override_test.go +++ b/x/configurl/override_test.go @@ -25,7 +25,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Host Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=www.google.com") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("www.youtube.com:443") require.NoError(t, err) @@ -34,7 +34,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Port Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("8.8.8.8:53") require.NoError(t, err) @@ -43,7 +43,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Full Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=8.8.8.8&port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("dns.google:53") require.NoError(t, err) @@ -53,7 +53,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Host Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=www.google.com") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) _, err = override("foo bar") require.Error(t, err) @@ -61,7 +61,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Full Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=8.8.8.8&port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("foo bar") require.NoError(t, err) diff --git a/x/configurl/shadowsocks.go b/x/configurl/shadowsocks.go index 246460f3..594483e8 100644 --- a/x/configurl/shadowsocks.go +++ b/x/configurl/shadowsocks.go @@ -15,6 +15,7 @@ package configurl import ( + "context" "encoding/base64" "errors" "fmt" @@ -25,52 +26,61 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" ) -func wrapStreamDialerWithShadowsocks(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - endpoint := &transport.StreamDialerEndpoint{Dialer: sd, Address: config.serverAddress} - dialer, err := shadowsocks.NewStreamDialer(endpoint, config.cryptoKey) - if err != nil { - return nil, err - } - if len(config.prefix) > 0 { - dialer.SaltGenerator = shadowsocks.NewPrefixSaltGenerator(config.prefix) - } - return dialer, nil +func registerShadowsocksStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.StreamDialerEndpoint{Dialer: sd, Address: ssConfig.serverAddress} + dialer, err := shadowsocks.NewStreamDialer(endpoint, ssConfig.cryptoKey) + if err != nil { + return nil, err + } + if len(ssConfig.prefix) > 0 { + dialer.SaltGenerator = shadowsocks.NewPrefixSaltGenerator(ssConfig.prefix) + } + return dialer, nil + }) } -func wrapPacketDialerWithShadowsocks(_ func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - pd, err := innerPD() - if err != nil { - return nil, err - } - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: config.serverAddress} - listener, err := shadowsocks.NewPacketListener(endpoint, config.cryptoKey) - if err != nil { - return nil, err - } - dialer := transport.PacketListenerDialer{Listener: listener} - return dialer, nil +func registerShadowsocksPacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: ssConfig.serverAddress} + pl, err := shadowsocks.NewPacketListener(endpoint, ssConfig.cryptoKey) + if err != nil { + return nil, err + } + // TODO: support UDP prefix. + return transport.PacketListenerDialer{Listener: pl}, nil + }) } -func newShadowsocksPacketListenerFromURL(configURL *url.URL) (transport.PacketListener, error) { - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - // TODO: accept an inner dialer from the caller and pass it to UDPEndpoint - ep := &transport.UDPEndpoint{Address: config.serverAddress} - return shadowsocks.NewPacketListener(ep, config.cryptoKey) +func registerShadowsocksPacketListener(r TypeRegistry[transport.PacketListener], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketListener, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: ssConfig.serverAddress} + return shadowsocks.NewPacketListener(endpoint, ssConfig.cryptoKey) + }) } type shadowsocksConfig struct { @@ -79,7 +89,7 @@ type shadowsocksConfig struct { prefix []byte } -func parseShadowsocksURL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksURL(url url.URL) (*shadowsocksConfig, error) { // attempt to decode as SIP002 URI format and // fall back to legacy base64 format if decoding fails config, err := parseShadowsocksSIP002URL(url) @@ -91,7 +101,7 @@ func parseShadowsocksURL(url *url.URL) (*shadowsocksConfig, error) { // parseShadowsocksLegacyBase64URL parses URL based on legacy base64 format: // https://shadowsocks.org/doc/configs.html#uri-and-qr-code -func parseShadowsocksLegacyBase64URL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksLegacyBase64URL(url url.URL) (*shadowsocksConfig, error) { config := &shadowsocksConfig{} if url.Host == "" { return nil, errors.New("host not specified") @@ -138,7 +148,7 @@ func parseShadowsocksLegacyBase64URL(url *url.URL) (*shadowsocksConfig, error) { // parseShadowsocksSIP002URL parses URL based on SIP002 format: // https://shadowsocks.org/doc/sip002.html -func parseShadowsocksSIP002URL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksSIP002URL(url url.URL) (*shadowsocksConfig, error) { config := &shadowsocksConfig{} if url.Host == "" { return nil, errors.New("host not specified") @@ -188,7 +198,7 @@ func parseStringPrefix(utf8Str string) ([]byte, error) { return rawBytes, nil } -func sanitizeShadowsocksURL(u *url.URL) (string, error) { +func sanitizeShadowsocksURL(u url.URL) (string, error) { config, err := parseShadowsocksURL(u) if err != nil { return "", err diff --git a/x/configurl/shadowsocks_test.go b/x/configurl/shadowsocks_test.go index 093574a2..e24245c5 100644 --- a/x/configurl/shadowsocks_test.go +++ b/x/configurl/shadowsocks_test.go @@ -25,7 +25,7 @@ import ( func Test_sanitizeShadowsocksURL(t *testing.T) { ssURL, err := url.Parse("ss://YWVzLTEyOC1nY206dGVzdA@192.168.100.1:8888") require.NoError(t, err) - sanitized, err := sanitizeShadowsocksURL(ssURL) + sanitized, err := sanitizeShadowsocksURL(*ssURL) require.NoError(t, err) require.Equal(t, "ss://REDACTED@192.168.100.1:8888", sanitized) } @@ -33,116 +33,115 @@ func Test_sanitizeShadowsocksURL(t *testing.T) { func Test_sanitizeShadowsocksURL_withPrefix(t *testing.T) { ssURL, err := url.Parse("ss://YWVzLTEyOC1nY206dGVzdA@192.168.100.1:8888?prefix=foo") require.NoError(t, err) - sanitized, err := sanitizeShadowsocksURL(ssURL) + sanitized, err := sanitizeShadowsocksURL(*ssURL) require.NoError(t, err) require.Equal(t, "ss://REDACTED@192.168.100.1:8888?prefix=foo", sanitized) } func TestParseShadowsocksURLFullyEncoded(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksURLUserInfoEncoded(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567")) - urls, err := parseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksURLUserInfoLegacyEncoded(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte("aes-256-gcm:shadowsocks")) - urls, err := parseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestLegacyEncodedShadowsocksURL(t *testing.T) { configString := "ss://YWVzLTEyOC1nY206c2hhZG93c29ja3M=@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) } func TestParseShadowsocksURLNoEncoding(t *testing.T) { configString := "ss://aes-256-gcm:1234567@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) } func TestParseShadowsocksURLInvalidCipherInfoFails(t *testing.T) { configString := "ss://aes-256-gcm1234567@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksURL(urls[0]) + _, err = parseShadowsocksURL(config.URL) require.Error(t, err) } func TestParseShadowsocksURLUnsupportedCypherFails(t *testing.T) { configString := "ss://Y2hhY2hhMjAtaWV0Zi1wb2x5MTMwnTpLeTUyN2duU3FEVFB3R0JpQ1RxUnlT@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksURL(urls[0]) + _, err = parseShadowsocksURL(config.URL) require.Error(t, err) } func TestParseShadowsocksLegacyBase64URL(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksLegacyBase64URL(urls[0]) + ssConfig, err := parseShadowsocksLegacyBase64URL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksSIP002URLUnsuccessful(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksSIP002URL(urls[0]) - - require.Error(t, err) + _, err = parseShadowsocksSIP002URL(config.URL) + require.Error(t, err, "URL is %v", config.URL.String()) } diff --git a/x/configurl/socks5.go b/x/configurl/socks5.go index 8fc6e238..cbb8de51 100644 --- a/x/configurl/socks5.go +++ b/x/configurl/socks5.go @@ -15,46 +15,59 @@ package configurl import ( - "net/url" + "context" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/socks5" ) -func wrapStreamDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - endpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} - client, err := socks5.NewClient(&endpoint) - if err != nil { - return nil, err - } - userInfo := configURL.User - if userInfo != nil { - username := userInfo.Username() - password, _ := userInfo.Password() - err := client.SetCredentials([]byte(username), []byte(password)) +func registerSOCKS5StreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + return newSOCKS5Client(ctx, *config, newSD) + }) +} + +func registerSOCKS5PacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + client, err := newSOCKS5Client(ctx, *config, newSD) if err != nil { return nil, err } - } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + client.EnablePacket(pd) + return transport.PacketListenerDialer{Listener: client}, nil + }) +} - return client, nil +func registerSOCKS5PacketListener(r TypeRegistry[transport.PacketListener], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketListener, error) { + client, err := newSOCKS5Client(ctx, *config, newSD) + if err != nil { + return nil, err + } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + client.EnablePacket(pd) + return client, nil + }) } -func wrapPacketDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - sd, err := innerSD() +func newSOCKS5Client(ctx context.Context, config Config, newSD BuildFunc[transport.StreamDialer]) (*socks5.Client, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { return nil, err } - streamEndpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} - client, err := socks5.NewClient(&streamEndpoint) + endpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: config.URL.Host} + client, err := socks5.NewClient(&endpoint) if err != nil { return nil, err } - userInfo := configURL.User + userInfo := config.URL.User if userInfo != nil { username := userInfo.Username() password, _ := userInfo.Password() @@ -63,12 +76,5 @@ func wrapPacketDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), return nil, err } } - - pd, err := innerPD() - if err != nil { - return nil, err - } - client.EnablePacket(pd) - packetDialer := transport.PacketListenerDialer{Listener: client} - return packetDialer, nil + return client, nil } diff --git a/x/configurl/split.go b/x/configurl/split.go index feca4d75..e6d89b3a 100644 --- a/x/configurl/split.go +++ b/x/configurl/split.go @@ -15,23 +15,25 @@ package configurl import ( + "context" "fmt" - "net/url" "strconv" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/split" ) -func wrapStreamDialerWithSplit(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - prefixBytesStr := configURL.Opaque - prefixBytes, err := strconv.Atoi(prefixBytesStr) - if err != nil { - return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) - } - return split.NewStreamDialer(sd, int64(prefixBytes)) +func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + prefixBytesStr := config.URL.Opaque + prefixBytes, err := strconv.Atoi(prefixBytesStr) + if err != nil { + return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) + } + return split.NewStreamDialer(sd, int64(prefixBytes)) + }) } diff --git a/x/configurl/tls.go b/x/configurl/tls.go index 2b29a1c3..7517cc43 100644 --- a/x/configurl/tls.go +++ b/x/configurl/tls.go @@ -15,6 +15,7 @@ package configurl import ( + "context" "fmt" "net/url" "strings" @@ -23,7 +24,21 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/tls" ) -func parseOptions(configURL *url.URL) ([]tls.ClientOption, error) { +func registerTLSStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + options, err := parseOptions(config.URL) + if err != nil { + return nil, err + } + return tls.NewStreamDialer(sd, options...) + }) +} + +func parseOptions(configURL url.URL) ([]tls.ClientOption, error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -49,15 +64,3 @@ func parseOptions(configURL *url.URL) ([]tls.ClientOption, error) { } return options, nil } - -func wrapStreamDialerWithTLS(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - options, err := parseOptions(configURL) - if err != nil { - return nil, err - } - return tls.NewStreamDialer(sd, options...) -} diff --git a/x/configurl/tls_test.go b/x/configurl/tls_test.go index d9cdcc58..0aa59dc6 100644 --- a/x/configurl/tls_test.go +++ b/x/configurl/tls_test.go @@ -15,25 +15,16 @@ package configurl import ( - "net/url" "testing" - "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/tls" "github.com/stretchr/testify/require" ) -func TestTLS(t *testing.T) { - tlsURL, err := url.Parse("tls") - require.NoError(t, err) - _, err = wrapStreamDialerWithTLS(func() (transport.StreamDialer, error) { return &transport.TCPDialer{}, nil }, nil, tlsURL) - require.NoError(t, err) -} - func TestTLS_SNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=www.google.com") + config, err := ParseConfig("tls:sni=www.google.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -44,9 +35,9 @@ func TestTLS_SNI(t *testing.T) { } func TestTLS_NoSNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=") + config, err := ParseConfig("tls:sni=") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -57,16 +48,16 @@ func TestTLS_NoSNI(t *testing.T) { } func TestTLS_MultipleSNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=www.google.com&sni=second") + config, err := ParseConfig("tls:sni=www.google.com&sni=second") require.NoError(t, err) - _, err = parseOptions(tlsURL) + _, err = parseOptions(config.URL) require.Error(t, err) } func TestTLS_CertName(t *testing.T) { - tlsURL, err := url.Parse("tls:certname=www.google.com") + config, err := ParseConfig("tls:certname=www.google.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -77,9 +68,9 @@ func TestTLS_CertName(t *testing.T) { } func TestTLS_Combined(t *testing.T) { - tlsURL, err := url.Parse("tls:SNI=sni.example.com&CertName=certname.example.com") + config, err := ParseConfig("tls:SNI=sni.example.com&CertName=certname.example.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -90,8 +81,8 @@ func TestTLS_Combined(t *testing.T) { } func TestTLS_UnsupportedOption(t *testing.T) { - tlsURL, err := url.Parse("tls:unsupported") + config, err := ParseConfig("tls:unsupported") require.NoError(t, err) - _, err = parseOptions(tlsURL) + _, err = parseOptions(config.URL) require.Error(t, err) } diff --git a/x/configurl/tlsfrag.go b/x/configurl/tlsfrag.go new file mode 100644 index 00000000..4e2e0fe9 --- /dev/null +++ b/x/configurl/tlsfrag.go @@ -0,0 +1,39 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "strconv" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" +) + +func registerTLSFragStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + lenStr := config.URL.Opaque + fixedLen, err := strconv.Atoi(lenStr) + if err != nil { + return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) + } + return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) + }) +} diff --git a/x/configurl/websocket.go b/x/configurl/websocket.go index 49861d15..31ef95e4 100644 --- a/x/configurl/websocket.go +++ b/x/configurl/websocket.go @@ -31,7 +31,7 @@ type wsConfig struct { udpPath string } -func parseWSConfig(configURL *url.URL) (*wsConfig, error) { +func parseWSConfig(configURL url.URL) (*wsConfig, error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -71,66 +71,70 @@ func (c *wsToStreamConn) CloseWrite() error { return c.Close() } -func wrapStreamDialerWithWebSocket(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseWSConfig(configURL) - if err != nil { - return nil, err - } - if config.tcpPath == "" { - return nil, errors.New("must specify tcp_path") - } - return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: config.tcpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) +func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) + return nil, err } - baseConn, err := sd.DialStream(ctx, addr) + wsConfig, err := parseWSConfig(config.URL) if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + return nil, err } - wsConn, err := websocket.NewClient(wsCfg, baseConn) - if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + if wsConfig.tcpPath == "" { + return nil, errors.New("must specify tcp_path") } - return &wsToStreamConn{wsConn}, nil - }), nil + return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.tcpPath} + origin := url.URL{Scheme: "http", Host: addr} + wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) + if err != nil { + return nil, fmt.Errorf("failed to create websocket config: %w", err) + } + baseConn, err := sd.DialStream(ctx, addr) + if err != nil { + return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + } + wsConn, err := websocket.NewClient(wsCfg, baseConn) + if err != nil { + baseConn.Close() + return nil, fmt.Errorf("failed to create websocket client: %w", err) + } + return &wsToStreamConn{wsConn}, nil + }), nil + }) } -func wrapPacketDialerWithWebSocket(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseWSConfig(configURL) - if err != nil { - return nil, err - } - if config.udpPath == "" { - return nil, errors.New("must specify udp_path") - } - return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: config.udpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) +func registerWebsocketPacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) + return nil, err } - baseConn, err := sd.DialStream(ctx, addr) + wsConfig, err := parseWSConfig(config.URL) if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + return nil, err } - wsConn, err := websocket.NewClient(wsCfg, baseConn) - if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + if wsConfig.udpPath == "" { + return nil, errors.New("must specify udp_path") } - return wsConn, nil - }), nil + return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.udpPath} + origin := url.URL{Scheme: "http", Host: addr} + wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) + if err != nil { + return nil, fmt.Errorf("failed to create websocket config: %w", err) + } + baseConn, err := sd.DialStream(ctx, addr) + if err != nil { + return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + } + wsConn, err := websocket.NewClient(wsCfg, baseConn) + if err != nil { + baseConn.Close() + return nil, fmt.Errorf("failed to create websocket client: %w", err) + } + return wsConn, nil + }), nil + }) } diff --git a/x/examples/fetch-speed/main.go b/x/examples/fetch-speed/main.go index 7e8db297..f6b131a7 100644 --- a/x/examples/fetch-speed/main.go +++ b/x/examples/fetch-speed/main.go @@ -58,7 +58,7 @@ func main() { os.Exit(1) } - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v\n", err) } diff --git a/x/examples/fetch/main.go b/x/examples/fetch/main.go index 634385a5..3b7d2d02 100644 --- a/x/examples/fetch/main.go +++ b/x/examples/fetch/main.go @@ -84,7 +84,7 @@ func main() { os.Exit(1) } - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v\n", err) } diff --git a/x/examples/http2transport/main.go b/x/examples/http2transport/main.go index 72fb3dcf..a0489fae 100644 --- a/x/examples/http2transport/main.go +++ b/x/examples/http2transport/main.go @@ -34,7 +34,7 @@ func main() { urlProxyPrefixFlag := flag.String("urlProxyPrefix", "/proxy", "Path where to run the URL proxy. Set to empty (\"\") to disable it.") flag.Parse() - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v", err) diff --git a/x/examples/outline-cli/outline_device.go b/x/examples/outline-cli/outline_device.go index 858e2bfa..61333bf8 100644 --- a/x/examples/outline-cli/outline_device.go +++ b/x/examples/outline-cli/outline_device.go @@ -15,6 +15,7 @@ package main import ( + "context" "errors" "fmt" "net" @@ -39,7 +40,7 @@ type OutlineDevice struct { svrIP net.IP } -var configToDialer = configurl.NewDefaultConfigToDialer() +var configModule = configurl.NewDefaultProviders() func NewOutlineDevice(transportConfig string) (od *OutlineDevice, err error) { ip, err := resolveShadowsocksServerIPFromConfig(transportConfig) @@ -50,7 +51,7 @@ func NewOutlineDevice(transportConfig string) (od *OutlineDevice, err error) { svrIP: ip, } - if od.sd, err = configToDialer.NewStreamDialer(transportConfig); err != nil { + if od.sd, err = configModule.NewStreamDialer(context.TODO(), transportConfig); err != nil { return nil, fmt.Errorf("failed to create TCP dialer: %w", err) } if od.pp, err = newOutlinePacketProxy(transportConfig); err != nil { diff --git a/x/examples/outline-cli/outline_packet_proxy.go b/x/examples/outline-cli/outline_packet_proxy.go index fe202b8e..7c0b314b 100644 --- a/x/examples/outline-cli/outline_packet_proxy.go +++ b/x/examples/outline-cli/outline_packet_proxy.go @@ -36,7 +36,7 @@ type outlinePacketProxy struct { func newOutlinePacketProxy(transportConfig string) (opp *outlinePacketProxy, err error) { opp = &outlinePacketProxy{} - if opp.remotePl, err = configurl.NewPacketListener(transportConfig); err != nil { + if opp.remotePl, err = configurl.NewDefaultProviders().NewPacketListener(context.TODO(), transportConfig); err != nil { return nil, fmt.Errorf("failed to create UDP packet listener: %w", err) } if opp.remote, err = network.NewPacketProxyFromPacketListener(opp.remotePl); err != nil { diff --git a/x/examples/resolve/main.go b/x/examples/resolve/main.go index 2055b620..e55dd790 100644 --- a/x/examples/resolve/main.go +++ b/x/examples/resolve/main.go @@ -66,15 +66,15 @@ func main() { resolverAddr := *resolverFlag var resolver dns.Resolver - configToDialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() if *tcpFlag { - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } resolver = dns.NewTCPResolver(streamDialer, resolverAddr) } else { - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create packet dialer: %v", err) } diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index 15490355..59d40b28 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -86,12 +86,12 @@ func main() { log.Fatalf("Could not read config: %v", err) } - configToDialer := configurl.NewDefaultConfigToDialer() - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + providers := configurl.NewDefaultProviders() + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create packet dialer: %v", err) } - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } diff --git a/x/examples/test-connectivity/main.go b/x/examples/test-connectivity/main.go index 1378ee1b..7197b6f7 100644 --- a/x/examples/test-connectivity/main.go +++ b/x/examples/test-connectivity/main.go @@ -240,7 +240,7 @@ func main() { var mu sync.Mutex dnsReports := make([]dnsReport, 0) tcpReports := make([]tcpReport, 0) - configToDialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() onDNS := func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo) { dnsStart := time.Now() return func(di httptrace.DNSDoneInfo) { @@ -260,7 +260,7 @@ func main() { mu.Unlock() } } - configToDialer.BaseStreamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + providers.StreamDialers.BaseInstance = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { hostname, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -284,13 +284,13 @@ func main() { } return newTCPTraceDialer(onDNS, onDial).DialStream(ctx, addr) }) - configToDialer.BasePacketDialer = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + providers.PacketDialers.BaseInstance = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { return newUDPTraceDialer(onDNS).DialPacket(ctx, addr) }) switch proto { case "tcp": - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { slog.Error("Failed to create StreamDialer", "error", err) os.Exit(1) @@ -298,7 +298,7 @@ func main() { resolver = dns.NewTCPResolver(streamDialer, resolverAddress) case "udp": - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { slog.Error("Failed to create PacketDialer", "error", err) os.Exit(1) diff --git a/x/examples/ws2endpoint/main.go b/x/examples/ws2endpoint/main.go index 5fac6aec..e09865e7 100644 --- a/x/examples/ws2endpoint/main.go +++ b/x/examples/ws2endpoint/main.go @@ -60,10 +60,10 @@ func main() { defer listener.Close() log.Printf("Proxy listening on %v\n", listener.Addr().String()) - config2Dialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() mux := http.NewServeMux() if *tcpPathFlag != "" { - dialer, err := config2Dialer.NewStreamDialer(*transportFlag) + dialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } @@ -90,7 +90,7 @@ func main() { mux.Handle(*tcpPathFlag, http.StripPrefix(*tcpPathFlag, handler)) } if *udpPathFlag != "" { - dialer, err := config2Dialer.NewPacketDialer(*transportFlag) + dialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } diff --git a/x/go.mod b/x/go.mod index a1dfe72a..7ad5e856 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.21 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 + github.com/Jigsaw-Code/outline-sdk v0.0.17 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index 033308fb..67b2fc27 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 h1:XNSV0dGW48J8DmdmCnk/txGHf9glAPqa6Xme/rFWn7c= -github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57/go.mod h1:e1oQZbSdLJBBuHgfeQsgEkvkuyIePPwstUeZRGq0KO8= +github.com/Jigsaw-Code/outline-sdk v0.0.17 h1:xkGsp3fs+5EbIZEeCEPI1rl0oyCrOLuO7YSK0gB75HE= +github.com/Jigsaw-Code/outline-sdk v0.0.17/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag= @@ -72,8 +72,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439 h1:T6zlOdzrYuHf6HUKujm9bzkzbZ5Iv/xf6rs8BHZDpoI= github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -99,9 +99,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= diff --git a/x/httpproxy/connect_handler.go b/x/httpproxy/connect_handler.go index 69512a7e..0922cce2 100644 --- a/x/httpproxy/connect_handler.go +++ b/x/httpproxy/connect_handler.go @@ -51,8 +51,8 @@ func (d *sanitizeErrorDialer) DialStream(ctx context.Context, addr string) (tran } type connectHandler struct { - dialer *sanitizeErrorDialer - dialerConfig *configurl.ConfigToDialer + dialer *sanitizeErrorDialer + providers *configurl.ProviderContainer } var _ http.Handler = (*connectHandler)(nil) @@ -77,7 +77,7 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http // Dial the target. transportConfig := proxyReq.Header.Get("Transport") - dialer, err := h.dialerConfig.NewStreamDialer(transportConfig) + dialer, err := h.providers.NewStreamDialer(proxyReq.Context(), transportConfig) if err != nil { // Because we sanitize the base dialer error, it's safe to return error details here. http.Error(proxyResp, fmt.Sprintf("Invalid config in Transport header: %v", err), http.StatusBadRequest) @@ -149,7 +149,7 @@ func NewConnectHandler(dialer transport.StreamDialer) http.Handler { // of the base dialer (e.g. access key credentials) to the user. sd := &sanitizeErrorDialer{dialer} // TODO(fortuna): Inject the config parser - dialerConfig := configurl.NewDefaultConfigToDialer() - dialerConfig.BaseStreamDialer = sd - return &connectHandler{sd, dialerConfig} + providers := configurl.NewDefaultProviders() + providers.StreamDialers.BaseInstance = sd + return &connectHandler{sd, providers} } diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index 19ea7563..5b832b7b 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -152,12 +152,12 @@ type StreamDialer struct { transport.StreamDialer } -var configToDialer = configurl.NewDefaultConfigToDialer() +var configModule = configurl.NewDefaultProviders() // NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. // The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. func NewStreamDialerFromConfig(transportConfig string) (*StreamDialer, error) { - dialer, err := configToDialer.NewStreamDialer(transportConfig) + dialer, err := configModule.NewStreamDialer(context.Background(), transportConfig) if err != nil { return nil, err } diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 092aac1c..54c089ef 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -231,8 +231,8 @@ func (f *StrategyFinder) findTLS(ctx context.Context, testDomains []string, base if len(tlsConfig) == 0 { return nil, errors.New("config for TLS is empty. Please specify at least one transport") } - var configToDialer = configurl.NewDefaultConfigToDialer() - configToDialer.BaseStreamDialer = baseDialer + var configModule = configurl.NewDefaultProviders() + configModule.StreamDialers.BaseInstance = baseDialer ctx, searchDone := context.WithCancel(ctx) defer searchDone() @@ -242,7 +242,7 @@ func (f *StrategyFinder) findTLS(ctx context.Context, testDomains []string, base Config string } result, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (*SearchResult, error) { - tlsDialer, err := configToDialer.NewStreamDialer(transportCfg) + tlsDialer, err := configModule.NewStreamDialer(ctx, transportCfg) if err != nil { return nil, fmt.Errorf("WrapStreamDialer failed: %w", err) }