From c4b04616cc017b5ae8be7412d802b40a965ef827 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 1 Nov 2024 17:36:09 -0400 Subject: [PATCH 01/21] feat(x): make config support PacketListeners and make dependencies explicit and decoupled (#304) --- x/configurl/config.go | 268 ++++-------------- x/configurl/dns.go | 63 ++-- x/configurl/doc.go | 32 ++- x/configurl/module.go | 153 ++++++++++ .../{config_test.go => module_test.go} | 0 x/configurl/override.go | 78 ++--- x/configurl/override_test.go | 10 +- x/configurl/shadowsocks.go | 102 ++++--- x/configurl/shadowsocks_test.go | 81 +++--- x/configurl/socks5.go | 68 +++-- x/configurl/split.go | 26 +- x/configurl/tls.go | 29 +- x/configurl/tls_test.go | 33 +-- x/configurl/tlsfrag.go | 39 +++ x/configurl/websocket.go | 106 +++---- x/examples/fetch-speed/main.go | 2 +- x/examples/fetch/main.go | 2 +- x/examples/http2transport/main.go | 2 +- x/examples/outline-cli/outline_device.go | 5 +- .../outline-cli/outline_packet_proxy.go | 2 +- x/examples/resolve/main.go | 6 +- x/examples/smart-proxy/main.go | 6 +- x/examples/test-connectivity/main.go | 10 +- x/examples/ws2endpoint/main.go | 6 +- x/go.mod | 2 +- x/go.sum | 11 +- x/httpproxy/connect_handler.go | 12 +- x/mobileproxy/mobileproxy.go | 4 +- x/smart/stream_dialer.go | 6 +- 29 files changed, 622 insertions(+), 542 deletions(-) create mode 100644 x/configurl/module.go rename x/configurl/{config_test.go => module_test.go} (100%) create mode 100644 x/configurl/tlsfrag.go diff --git a/x/configurl/config.go b/x/configurl/config.go index 0c3c0b97..e338c950 100644 --- a/x/configurl/config.go +++ b/x/configurl/config.go @@ -15,108 +15,84 @@ package configurl import ( + "context" "errors" "fmt" "net/url" - "strconv" "strings" - - "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" ) -// ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is -// extensible by registering wrappers for config subtypes. -type ConfigToDialer struct { - // Base StreamDialer to create direct stream connections. If you need direct stream connections, this must not be nil. - BaseStreamDialer transport.StreamDialer - // Base PacketDialer to create direct packet connections. If you need direct packet connections, this must not be nil. - BasePacketDialer transport.PacketDialer - sdBuilders map[string]NewStreamDialerFunc - pdBuilders map[string]NewPacketDialerFunc +// Config is a pre-parsed generic config created from pipe-separated URLs. +type Config struct { + URL url.URL + BaseConfig *Config } -// NewStreamDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed. -type NewStreamDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) - -// NewPacketDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed. -type NewPacketDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.PacketDialer, error) - -// NewDefaultConfigToDialer creates a [ConfigToDialer] with a set of default wrappers already registered. -func NewDefaultConfigToDialer() *ConfigToDialer { - p := new(ConfigToDialer) - p.BaseStreamDialer = &transport.TCPDialer{} - p.BasePacketDialer = &transport.UDPDialer{} - - // Please keep the list in alphabetical order. - p.RegisterStreamDialerType("do53", wrapStreamDialerWithDO53) - - p.RegisterStreamDialerType("doh", wrapStreamDialerWithDOH) - - p.RegisterStreamDialerType("override", wrapStreamDialerWithOverride) - p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride) - - p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5) - p.RegisterPacketDialerType("socks5", wrapPacketDialerWithSOCKS5) - - p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit) +// BuildFunc is a function that creates an instance of ObjectType given a [Config]. +type BuildFunc[ObjectType any] func(ctx context.Context, config *Config) (ObjectType, error) - p.RegisterStreamDialerType("ss", wrapStreamDialerWithShadowsocks) - p.RegisterPacketDialerType("ss", wrapPacketDialerWithShadowsocks) - - p.RegisterStreamDialerType("tls", wrapStreamDialerWithTLS) +// TypeRegistry registers config types. +type TypeRegistry[ObjectType any] interface { + RegisterType(subtype string, newInstance BuildFunc[ObjectType]) +} - p.RegisterStreamDialerType("tlsfrag", func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - lenStr := wrapConfig.Opaque - fixedLen, err := strconv.Atoi(lenStr) - if err != nil { - return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) - } - return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) - }) +// ExtensibleProvider creates instances of ObjectType in a way that can be extended via its [TypeRegistry] interface. +type ExtensibleProvider[ObjectType comparable] struct { + // Instance to return when config is nil. + BaseInstance ObjectType + builders map[string]BuildFunc[ObjectType] +} - p.RegisterStreamDialerType("ws", wrapStreamDialerWithWebSocket) - p.RegisterPacketDialerType("ws", wrapPacketDialerWithWebSocket) +var ( + _ BuildFunc[any] = (*ExtensibleProvider[any])(nil).NewInstance + _ TypeRegistry[any] = (*ExtensibleProvider[any])(nil) +) - return p +// NewExtensibleProvider creates an [ExtensibleProvider] with the given base instance. +func NewExtensibleProvider[ObjectType comparable](baseInstance ObjectType) ExtensibleProvider[ObjectType] { + return ExtensibleProvider[ObjectType]{ + BaseInstance: baseInstance, + builders: make(map[string]BuildFunc[ObjectType]), + } } -// RegisterStreamDialerType will register a wrapper for stream dialers under the given subtype. -func (p *ConfigToDialer) RegisterStreamDialerType(subtype string, newDialer NewStreamDialerFunc) error { - if p.sdBuilders == nil { - p.sdBuilders = make(map[string]NewStreamDialerFunc) +func (p *ExtensibleProvider[ObjectType]) ensureBuildersMap() map[string]BuildFunc[ObjectType] { + if p.builders == nil { + p.builders = make(map[string]BuildFunc[ObjectType]) } + return p.builders +} - if _, found := p.sdBuilders[subtype]; found { - return fmt.Errorf("config parser %v for StreamDialer added twice", subtype) - } - p.sdBuilders[subtype] = newDialer - return nil +// RegisterType will register a factory for the given subtype. +func (p *ExtensibleProvider[ObjectType]) RegisterType(subtype string, newInstance BuildFunc[ObjectType]) { + p.ensureBuildersMap()[subtype] = newInstance } -// RegisterPacketDialerType will register a wrapper for packet dialers under the given subtype. -func (p *ConfigToDialer) RegisterPacketDialerType(subtype string, newDialer NewPacketDialerFunc) error { - if p.pdBuilders == nil { - p.pdBuilders = make(map[string]NewPacketDialerFunc) +// NewInstance creates a new instance of ObjectType according to the config. +func (p *ExtensibleProvider[ObjectType]) NewInstance(ctx context.Context, config *Config) (ObjectType, error) { + var zero ObjectType + if config == nil { + if p.BaseInstance == zero { + return zero, errors.New("base instance is not configured") + } + return p.BaseInstance, nil } - if _, found := p.pdBuilders[subtype]; found { - return fmt.Errorf("config parser %v for StreamDialer added twice", subtype) + newInstance, ok := p.ensureBuildersMap()[config.URL.Scheme] + if !ok { + return zero, fmt.Errorf("config type '%v' is not registered", config.URL.Scheme) } - p.pdBuilders[subtype] = newDialer - return nil + return newInstance(ctx, config) } -func parseConfig(configText string) ([]*url.URL, error) { +// ParseConfig will parse a config given as a string and return the structured [Config]. +func ParseConfig(configText string) (*Config, error) { parts := strings.Split(strings.TrimSpace(configText), "|") if len(parts) == 1 && parts[0] == "" { - return []*url.URL{}, nil + return nil, nil } - urls := make([]*url.URL, 0, len(parts)) + + var config *Config = nil for _, part := range parts { part = strings.TrimSpace(part) if part == "" { @@ -130,143 +106,7 @@ func parseConfig(configText string) ([]*url.URL, error) { if err != nil { return nil, fmt.Errorf("part is not a valid URL: %w", err) } - urls = append(urls, url) - } - return urls, nil -} - -// NewStreamDialer creates a [Dialer] according to transportConfig, using dialer as the -// base [Dialer]. The given dialer must not be nil. -func (p *ConfigToDialer) NewStreamDialer(transportConfig string) (transport.StreamDialer, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - return p.newStreamDialer(parts) -} - -// NewPacketDialer creates a [Dialer] according to transportConfig, using dialer as the -// base [Dialer]. The given dialer must not be nil. -func (p *ConfigToDialer) NewPacketDialer(transportConfig string) (transport.PacketDialer, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - return p.newPacketDialer(parts) -} - -func (p *ConfigToDialer) newStreamDialer(configParts []*url.URL) (transport.StreamDialer, error) { - if len(configParts) == 0 { - if p.BaseStreamDialer == nil { - return nil, fmt.Errorf("base StreamDialer must not be nil") - } - return p.BaseStreamDialer, nil - } - thisURL := configParts[len(configParts)-1] - innerConfig := configParts[:len(configParts)-1] - newDialer, ok := p.sdBuilders[thisURL.Scheme] - if !ok { - return nil, fmt.Errorf("config scheme '%v' is not supported for Stream Dialers", thisURL.Scheme) - } - newSD := func() (transport.StreamDialer, error) { - return p.newStreamDialer(innerConfig) - } - newPD := func() (transport.PacketDialer, error) { - return p.newPacketDialer(innerConfig) - } - return newDialer(newSD, newPD, thisURL) -} - -func (p *ConfigToDialer) newPacketDialer(configParts []*url.URL) (transport.PacketDialer, error) { - if len(configParts) == 0 { - if p.BasePacketDialer == nil { - return nil, fmt.Errorf("base PacketDialer must not be nil") - } - return p.BasePacketDialer, nil - } - thisURL := configParts[len(configParts)-1] - innerConfig := configParts[:len(configParts)-1] - newDialer, ok := p.pdBuilders[thisURL.Scheme] - if !ok { - return nil, fmt.Errorf("config scheme '%v' is not supported for Packet Dialers", thisURL.Scheme) - } - newSD := func() (transport.StreamDialer, error) { - return p.newStreamDialer(innerConfig) - } - newPD := func() (transport.PacketDialer, error) { - return p.newPacketDialer(innerConfig) - } - return newDialer(newSD, newPD, thisURL) -} - -// NewpacketListener creates a new [transport.PacketListener] according to the given config, -// the config must contain only one "ss://" segment. -// TODO: make NewPacketListener configurable. -func NewPacketListener(transportConfig string) (transport.PacketListener, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return nil, err - } - if len(parts) == 0 { - return nil, errors.New("config is required") - } - if len(parts) > 1 { - return nil, errors.New("multi-part config is not supported") - } - - url := parts[0] - // Please keep scheme list sorted. - switch strings.ToLower(url.Scheme) { - case "ss": - // TODO: support nested dialer, the last part must be "ss://" - return newShadowsocksPacketListenerFromURL(url) - default: - return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme) - } -} - -func SanitizeConfig(transportConfig string) (string, error) { - parts, err := parseConfig(transportConfig) - if err != nil { - return "", err - } - - // Do nothing if the config is empty - if len(parts) == 0 { - return "", nil - } - - // Iterate through each part - textParts := make([]string, len(parts)) - for i, u := range parts { - scheme := strings.ToLower(u.Scheme) - switch scheme { - case "ss": - textParts[i], err = sanitizeShadowsocksURL(u) - if err != nil { - return "", err - } - case "socks5": - textParts[i], err = sanitizeSocks5URL(u) - if err != nil { - return "", err - } - case "override", "split", "tls", "tlsfrag": - // No sanitization needed - textParts[i] = u.String() - default: - textParts[i] = scheme + "://UNKNOWN" - } - } - // Join the parts back into a string - return strings.Join(textParts, "|"), nil -} - -func sanitizeSocks5URL(u *url.URL) (string, error) { - const redactedPlaceholder = "REDACTED" - if u.User != nil { - u.User = url.User(redactedPlaceholder) - return u.String(), nil + config = &Config{URL: *url, BaseConfig: config} } - return u.String(), nil + return config, nil } diff --git a/x/configurl/dns.go b/x/configurl/dns.go index 205dc8fd..2e41ff59 100644 --- a/x/configurl/dns.go +++ b/x/configurl/dns.go @@ -27,16 +27,46 @@ import ( "golang.org/x/net/dns/dnsmessage" ) -func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - pd, err := innerPD() - if err != nil { - return nil, err - } - query := configURL.Opaque +func registerDO53StreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + if config == nil { + return nil, fmt.Errorf("emtpy do53 config") + } + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + resolver, err := newDO53Resolver(config.URL, sd, pd) + if err != nil { + return nil, err + } + return dns.NewStreamDialer(resolver, sd) + }) +} + +func registerDOHStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + if config == nil { + return nil, fmt.Errorf("emtpy doh config") + } + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + resolver, err := newDOHResolver(config.URL, sd) + if err != nil { + return nil, err + } + return dns.NewStreamDialer(resolver, sd) + }) +} + +func newDO53Resolver(config url.URL, sd transport.StreamDialer, pd transport.PacketDialer) (dns.Resolver, error) { + query := config.Opaque values, err := url.ParseQuery(query) if err != nil { return nil, err @@ -75,19 +105,15 @@ func wrapStreamDialerWithDO53(innerSD func() (transport.StreamDialer, error), in // See https://datatracker.ietf.org/doc/html/rfc1123#page-75. return tcpResolver.Query(ctx, q) }) - return dns.NewStreamDialer(resolver, sd) + return resolver, nil } -func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - query := configURL.Opaque +func newDOHResolver(config url.URL, sd transport.StreamDialer) (dns.Resolver, error) { + query := config.Opaque values, err := url.ParseQuery(query) if err != nil { return nil, err } - sd, err := innerSD() - if err != nil { - return nil, err - } var name, address string for key, values := range values { @@ -119,6 +145,5 @@ func wrapStreamDialerWithDOH(innerSD func() (transport.StreamDialer, error), inn port = "443" } dohURL := url.URL{Scheme: "https", Host: net.JoinHostPort(name, port), Path: "/dns-query"} - resolver := dns.NewHTTPSResolver(sd, address, dohURL.String()) - return dns.NewStreamDialer(resolver, sd) + return dns.NewHTTPSResolver(sd, address, dohURL.String()), nil } diff --git a/x/configurl/doc.go b/x/configurl/doc.go index 70cfd6b4..dfee54c6 100644 --- a/x/configurl/doc.go +++ b/x/configurl/doc.go @@ -13,11 +13,11 @@ // limitations under the License. /* -Package config provides convenience functions to create dialer objects based on a text config. +Package configurl provides convenience functions to create network objects based on a text config. This is experimental and mostly for illustrative purposes at this point. -Configurable transports simplifies the way you create and manage transports. -With the config package, you can use [NewPacketDialer] and [NewStreamDialer] to create dialers using a simple text string. +Configurable strategies simplifies the way you create and manage strategies. +With the configurl package, you can use [ProviderContainer.NewPacketDialer], [ProviderContainer.NewStreamDialer] and [ProviderContainer.NewPacketListener] to create objects using a simple text string. Key Benefits: @@ -129,19 +129,23 @@ DPI Evasion - To add packet splitting to a Shadowsocks server for enhanced DPI e split:2|ss://[USERINFO]@[HOST]:[PORT] -Defining custom transport - You can define your custom transport by implementing and registering the [NewStreamDialerFunc] and [NewPacketDialerFunc] functions: +Defining custom strategies - You can define your custom strategy by implementing and registering [BuildFunc[ObjectType]] functions: - // create new config parser - // p := new(ConfigToDialer) + // Create new config parser. + // p := configurl.NewProviderContainer // or - p := NewDefaultConfigToDialer() - // register your custom dialer - p.RegisterPacketDialerWrapper("custom", wrapStreamDialerWithCustom) - p.RegisterStreamDialerWrapper("custom", wrapPacketDialerWithCustom) - // then use it - dialer, err := p.NewStreamDialer(innerDialer, "custom://config") - -where wrapStreamDialerWithCustom and wrapPacketDialerWithCustom implement [NewPacketDialerFunc] and [NewStreamDialerFunc]. + p := configurl.NewDefaultProviders() + // Register your custom dialer. + p.StreamDialers.RegisterType("custom", func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + // Build logic + // ... + }) + p.PacketDialers.RegisterType("custom", func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + // Build logic + // ... + }) + // Then use it + dialer, err := p.NewStreamDialer(context.Background(), "custom://config") [Onion Routing]: https://en.wikipedia.org/wiki/Onion_routing */ diff --git a/x/configurl/module.go b/x/configurl/module.go new file mode 100644 index 00000000..83e14b89 --- /dev/null +++ b/x/configurl/module.go @@ -0,0 +1,153 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "net/url" + "strings" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// ProviderContainer contains providers for the creation of network objects based on a config. The config is +// extensible by registering providers for different config subtypes. +type ProviderContainer struct { + StreamDialers ExtensibleProvider[transport.StreamDialer] + PacketDialers ExtensibleProvider[transport.PacketDialer] + PacketListeners ExtensibleProvider[transport.PacketListener] +} + +// NewProviderContainer creates a [ProviderContainer] with the base instances properly initialized. +func NewProviderContainer() *ProviderContainer { + return &ProviderContainer{ + StreamDialers: NewExtensibleProvider[transport.StreamDialer](&transport.TCPDialer{}), + PacketDialers: NewExtensibleProvider[transport.PacketDialer](&transport.UDPDialer{}), + PacketListeners: NewExtensibleProvider[transport.PacketListener](&transport.UDPListener{}), + } +} + +// RegisterDefaultProviders registers a set of default providers with the providers in [ProviderContainer]. +func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { + // Please keep the list in alphabetical order. + registerDO53StreamDialer(&c.StreamDialers, "do53", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + registerDOHStreamDialer(&c.StreamDialers, "doh", c.StreamDialers.NewInstance) + + registerOverrideStreamDialer(&c.StreamDialers, "override", c.StreamDialers.NewInstance) + registerOverridePacketDialer(&c.PacketDialers, "override", c.PacketDialers.NewInstance) + + registerSOCKS5StreamDialer(&c.StreamDialers, "socks5", c.StreamDialers.NewInstance) + registerSOCKS5PacketDialer(&c.PacketDialers, "socks5", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + registerSOCKS5PacketListener(&c.PacketListeners, "socks5", c.StreamDialers.NewInstance, c.PacketDialers.NewInstance) + + registerSplitStreamDialer(&c.StreamDialers, "split", c.StreamDialers.NewInstance) + + registerShadowsocksStreamDialer(&c.StreamDialers, "ss", c.StreamDialers.NewInstance) + registerShadowsocksPacketDialer(&c.PacketDialers, "ss", c.PacketDialers.NewInstance) + registerShadowsocksPacketListener(&c.PacketListeners, "ss", c.PacketDialers.NewInstance) + + registerTLSStreamDialer(&c.StreamDialers, "tls", c.StreamDialers.NewInstance) + + registerTLSFragStreamDialer(&c.StreamDialers, "tlsfrag", c.StreamDialers.NewInstance) + + registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance) + registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance) + + return c +} + +// NewDefaultProviders creates a [ProviderContainer] with a set of default providers already registered. +func NewDefaultProviders() *ProviderContainer { + return RegisterDefaultProviders(NewProviderContainer()) +} + +// NewStreamDialer creates a [transport.StreamDialer] according to the config text. +func (p *ProviderContainer) NewStreamDialer(ctx context.Context, configText string) (transport.StreamDialer, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.StreamDialers.NewInstance(ctx, config) +} + +// NewPacketDialer creates a [transport.PacketDialer] according to the config text. +func (p *ProviderContainer) NewPacketDialer(ctx context.Context, configText string) (transport.PacketDialer, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.PacketDialers.NewInstance(ctx, config) +} + +// NewPacketListner creates a [transport.PacketListener] according to the config text. +func (p *ProviderContainer) NewPacketListener(ctx context.Context, configText string) (transport.PacketListener, error) { + config, err := ParseConfig(configText) + if err != nil { + return nil, err + } + return p.PacketListeners.NewInstance(ctx, config) +} + +// SanitizeConfig removes sensitive information from the given config so it can be safely be used in logging and debugging. +func SanitizeConfig(configStr string) (string, error) { + config, err := ParseConfig(configStr) + if err != nil { + return "", err + } + + // Do nothing if the config is empty + if config == nil { + return "", nil + } + + var sanitized string + for config != nil { + var part string + scheme := strings.ToLower(config.URL.Scheme) + switch scheme { + case "ss": + part, err = sanitizeShadowsocksURL(config.URL) + if err != nil { + return "", err + } + case "socks5": + part, err = sanitizeSOCKS5URL(&config.URL) + if err != nil { + return "", err + } + case "override", "split", "tls", "tlsfrag": + // No sanitization needed + part = config.URL.String() + default: + part = scheme + "://UNKNOWN" + } + if sanitized == "" { + sanitized = part + } else { + sanitized = part + "|" + sanitized + } + config = config.BaseConfig + } + return sanitized, nil +} + +func sanitizeSOCKS5URL(u *url.URL) (string, error) { + const redactedPlaceholder = "REDACTED" + if u.User != nil { + u.User = url.User(redactedPlaceholder) + return u.String(), nil + } + return u.String(), nil +} diff --git a/x/configurl/config_test.go b/x/configurl/module_test.go similarity index 100% rename from x/configurl/config_test.go rename to x/configurl/module_test.go diff --git a/x/configurl/override.go b/x/configurl/override.go index 3f34ea63..743aae31 100644 --- a/x/configurl/override.go +++ b/x/configurl/override.go @@ -24,7 +24,47 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" ) -func newOverrideFromURL(configURL *url.URL) (func(string) (string, error), error) { +func registerOverrideStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + override, err := newOverrideFromURL(config.URL) + if err != nil { + return nil, err + } + return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + addr, err := override(addr) + if err != nil { + return nil, err + } + return sd.DialStream(ctx, addr) + }), nil + }) +} + +func registerOverridePacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + override, err := newOverrideFromURL(config.URL) + if err != nil { + return nil, err + } + return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + addr, err := override(addr) + if err != nil { + return nil, err + } + return pd.DialPacket(ctx, addr) + }), nil + }) +} + +func newOverrideFromURL(configURL url.URL) (func(string) (string, error), error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -65,39 +105,3 @@ func newOverrideFromURL(configURL *url.URL) (func(string) (string, error), error return net.JoinHostPort(host, port), nil }, nil } - -func wrapStreamDialerWithOverride(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - override, err := newOverrideFromURL(configURL) - if err != nil { - return nil, err - } - return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { - addr, err := override(addr) - if err != nil { - return nil, err - } - return sd.DialStream(ctx, addr) - }), nil -} - -func wrapPacketDialerWithOverride(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - pd, err := innerPD() - if err != nil { - return nil, err - } - override, err := newOverrideFromURL(configURL) - if err != nil { - return nil, err - } - return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { - addr, err := override(addr) - if err != nil { - return nil, err - } - return pd.DialPacket(ctx, addr) - }), nil -} diff --git a/x/configurl/override_test.go b/x/configurl/override_test.go index 0069de91..6a7e122e 100644 --- a/x/configurl/override_test.go +++ b/x/configurl/override_test.go @@ -25,7 +25,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Host Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=www.google.com") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("www.youtube.com:443") require.NoError(t, err) @@ -34,7 +34,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Port Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("8.8.8.8:53") require.NoError(t, err) @@ -43,7 +43,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Full Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=8.8.8.8&port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("dns.google:53") require.NoError(t, err) @@ -53,7 +53,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Host Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=www.google.com") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) _, err = override("foo bar") require.Error(t, err) @@ -61,7 +61,7 @@ func Test_newOverrideFromURL(t *testing.T) { t.Run("Full Override", func(t *testing.T) { cfgUrl, err := url.Parse("override:host=8.8.8.8&port=853") require.NoError(t, err) - override, err := newOverrideFromURL(cfgUrl) + override, err := newOverrideFromURL(*cfgUrl) require.NoError(t, err) addr, err := override("foo bar") require.NoError(t, err) diff --git a/x/configurl/shadowsocks.go b/x/configurl/shadowsocks.go index 246460f3..594483e8 100644 --- a/x/configurl/shadowsocks.go +++ b/x/configurl/shadowsocks.go @@ -15,6 +15,7 @@ package configurl import ( + "context" "encoding/base64" "errors" "fmt" @@ -25,52 +26,61 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" ) -func wrapStreamDialerWithShadowsocks(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - endpoint := &transport.StreamDialerEndpoint{Dialer: sd, Address: config.serverAddress} - dialer, err := shadowsocks.NewStreamDialer(endpoint, config.cryptoKey) - if err != nil { - return nil, err - } - if len(config.prefix) > 0 { - dialer.SaltGenerator = shadowsocks.NewPrefixSaltGenerator(config.prefix) - } - return dialer, nil +func registerShadowsocksStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.StreamDialerEndpoint{Dialer: sd, Address: ssConfig.serverAddress} + dialer, err := shadowsocks.NewStreamDialer(endpoint, ssConfig.cryptoKey) + if err != nil { + return nil, err + } + if len(ssConfig.prefix) > 0 { + dialer.SaltGenerator = shadowsocks.NewPrefixSaltGenerator(ssConfig.prefix) + } + return dialer, nil + }) } -func wrapPacketDialerWithShadowsocks(_ func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - pd, err := innerPD() - if err != nil { - return nil, err - } - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: config.serverAddress} - listener, err := shadowsocks.NewPacketListener(endpoint, config.cryptoKey) - if err != nil { - return nil, err - } - dialer := transport.PacketListenerDialer{Listener: listener} - return dialer, nil +func registerShadowsocksPacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: ssConfig.serverAddress} + pl, err := shadowsocks.NewPacketListener(endpoint, ssConfig.cryptoKey) + if err != nil { + return nil, err + } + // TODO: support UDP prefix. + return transport.PacketListenerDialer{Listener: pl}, nil + }) } -func newShadowsocksPacketListenerFromURL(configURL *url.URL) (transport.PacketListener, error) { - config, err := parseShadowsocksURL(configURL) - if err != nil { - return nil, err - } - // TODO: accept an inner dialer from the caller and pass it to UDPEndpoint - ep := &transport.UDPEndpoint{Address: config.serverAddress} - return shadowsocks.NewPacketListener(ep, config.cryptoKey) +func registerShadowsocksPacketListener(r TypeRegistry[transport.PacketListener], typeID string, newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketListener, error) { + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + ssConfig, err := parseShadowsocksURL(config.URL) + if err != nil { + return nil, err + } + endpoint := &transport.PacketDialerEndpoint{Dialer: pd, Address: ssConfig.serverAddress} + return shadowsocks.NewPacketListener(endpoint, ssConfig.cryptoKey) + }) } type shadowsocksConfig struct { @@ -79,7 +89,7 @@ type shadowsocksConfig struct { prefix []byte } -func parseShadowsocksURL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksURL(url url.URL) (*shadowsocksConfig, error) { // attempt to decode as SIP002 URI format and // fall back to legacy base64 format if decoding fails config, err := parseShadowsocksSIP002URL(url) @@ -91,7 +101,7 @@ func parseShadowsocksURL(url *url.URL) (*shadowsocksConfig, error) { // parseShadowsocksLegacyBase64URL parses URL based on legacy base64 format: // https://shadowsocks.org/doc/configs.html#uri-and-qr-code -func parseShadowsocksLegacyBase64URL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksLegacyBase64URL(url url.URL) (*shadowsocksConfig, error) { config := &shadowsocksConfig{} if url.Host == "" { return nil, errors.New("host not specified") @@ -138,7 +148,7 @@ func parseShadowsocksLegacyBase64URL(url *url.URL) (*shadowsocksConfig, error) { // parseShadowsocksSIP002URL parses URL based on SIP002 format: // https://shadowsocks.org/doc/sip002.html -func parseShadowsocksSIP002URL(url *url.URL) (*shadowsocksConfig, error) { +func parseShadowsocksSIP002URL(url url.URL) (*shadowsocksConfig, error) { config := &shadowsocksConfig{} if url.Host == "" { return nil, errors.New("host not specified") @@ -188,7 +198,7 @@ func parseStringPrefix(utf8Str string) ([]byte, error) { return rawBytes, nil } -func sanitizeShadowsocksURL(u *url.URL) (string, error) { +func sanitizeShadowsocksURL(u url.URL) (string, error) { config, err := parseShadowsocksURL(u) if err != nil { return "", err diff --git a/x/configurl/shadowsocks_test.go b/x/configurl/shadowsocks_test.go index 093574a2..e24245c5 100644 --- a/x/configurl/shadowsocks_test.go +++ b/x/configurl/shadowsocks_test.go @@ -25,7 +25,7 @@ import ( func Test_sanitizeShadowsocksURL(t *testing.T) { ssURL, err := url.Parse("ss://YWVzLTEyOC1nY206dGVzdA@192.168.100.1:8888") require.NoError(t, err) - sanitized, err := sanitizeShadowsocksURL(ssURL) + sanitized, err := sanitizeShadowsocksURL(*ssURL) require.NoError(t, err) require.Equal(t, "ss://REDACTED@192.168.100.1:8888", sanitized) } @@ -33,116 +33,115 @@ func Test_sanitizeShadowsocksURL(t *testing.T) { func Test_sanitizeShadowsocksURL_withPrefix(t *testing.T) { ssURL, err := url.Parse("ss://YWVzLTEyOC1nY206dGVzdA@192.168.100.1:8888?prefix=foo") require.NoError(t, err) - sanitized, err := sanitizeShadowsocksURL(ssURL) + sanitized, err := sanitizeShadowsocksURL(*ssURL) require.NoError(t, err) require.Equal(t, "ss://REDACTED@192.168.100.1:8888?prefix=foo", sanitized) } func TestParseShadowsocksURLFullyEncoded(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksURLUserInfoEncoded(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567")) - urls, err := parseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksURLUserInfoLegacyEncoded(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte("aes-256-gcm:shadowsocks")) - urls, err := parseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "@example.com:1234?prefix=HTTP%2F1.1%20" + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestLegacyEncodedShadowsocksURL(t *testing.T) { configString := "ss://YWVzLTEyOC1nY206c2hhZG93c29ja3M=@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) } func TestParseShadowsocksURLNoEncoding(t *testing.T) { configString := "ss://aes-256-gcm:1234567@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksURL(urls[0]) + ssConfig, err := parseShadowsocksURL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) } func TestParseShadowsocksURLInvalidCipherInfoFails(t *testing.T) { configString := "ss://aes-256-gcm1234567@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksURL(urls[0]) + _, err = parseShadowsocksURL(config.URL) require.Error(t, err) } func TestParseShadowsocksURLUnsupportedCypherFails(t *testing.T) { configString := "ss://Y2hhY2hhMjAtaWV0Zi1wb2x5MTMwnTpLeTUyN2duU3FEVFB3R0JpQ1RxUnlT@example.com:1234" - urls, err := parseConfig(configString) + config, err := ParseConfig(configString) require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksURL(urls[0]) + _, err = parseShadowsocksURL(config.URL) require.Error(t, err) } func TestParseShadowsocksLegacyBase64URL(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - config, err := parseShadowsocksLegacyBase64URL(urls[0]) + ssConfig, err := parseShadowsocksLegacyBase64URL(config.URL) require.NoError(t, err) - require.Equal(t, "example.com:1234", config.serverAddress) - require.Equal(t, "HTTP/1.1 ", string(config.prefix)) + require.Equal(t, "example.com:1234", ssConfig.serverAddress) + require.Equal(t, "HTTP/1.1 ", string(ssConfig.prefix)) } func TestParseShadowsocksSIP002URLUnsuccessful(t *testing.T) { encoded := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString([]byte("aes-256-gcm:1234567@example.com:1234?prefix=HTTP%2F1.1%20")) - urls, err := parseConfig("ss://" + string(encoded) + "#outline-123") + config, err := ParseConfig("ss://" + string(encoded) + "#outline-123") require.NoError(t, err) - require.Equal(t, 1, len(urls)) + require.Nil(t, config.BaseConfig) - _, err = parseShadowsocksSIP002URL(urls[0]) - - require.Error(t, err) + _, err = parseShadowsocksSIP002URL(config.URL) + require.Error(t, err, "URL is %v", config.URL.String()) } diff --git a/x/configurl/socks5.go b/x/configurl/socks5.go index 8fc6e238..cbb8de51 100644 --- a/x/configurl/socks5.go +++ b/x/configurl/socks5.go @@ -15,46 +15,59 @@ package configurl import ( - "net/url" + "context" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/socks5" ) -func wrapStreamDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - endpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} - client, err := socks5.NewClient(&endpoint) - if err != nil { - return nil, err - } - userInfo := configURL.User - if userInfo != nil { - username := userInfo.Username() - password, _ := userInfo.Password() - err := client.SetCredentials([]byte(username), []byte(password)) +func registerSOCKS5StreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + return newSOCKS5Client(ctx, *config, newSD) + }) +} + +func registerSOCKS5PacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + client, err := newSOCKS5Client(ctx, *config, newSD) if err != nil { return nil, err } - } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + client.EnablePacket(pd) + return transport.PacketListenerDialer{Listener: client}, nil + }) +} - return client, nil +func registerSOCKS5PacketListener(r TypeRegistry[transport.PacketListener], typeID string, newSD BuildFunc[transport.StreamDialer], newPD BuildFunc[transport.PacketDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketListener, error) { + client, err := newSOCKS5Client(ctx, *config, newSD) + if err != nil { + return nil, err + } + pd, err := newPD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + client.EnablePacket(pd) + return client, nil + }) } -func wrapPacketDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - sd, err := innerSD() +func newSOCKS5Client(ctx context.Context, config Config, newSD BuildFunc[transport.StreamDialer]) (*socks5.Client, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { return nil, err } - streamEndpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} - client, err := socks5.NewClient(&streamEndpoint) + endpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: config.URL.Host} + client, err := socks5.NewClient(&endpoint) if err != nil { return nil, err } - userInfo := configURL.User + userInfo := config.URL.User if userInfo != nil { username := userInfo.Username() password, _ := userInfo.Password() @@ -63,12 +76,5 @@ func wrapPacketDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), return nil, err } } - - pd, err := innerPD() - if err != nil { - return nil, err - } - client.EnablePacket(pd) - packetDialer := transport.PacketListenerDialer{Listener: client} - return packetDialer, nil + return client, nil } diff --git a/x/configurl/split.go b/x/configurl/split.go index feca4d75..e6d89b3a 100644 --- a/x/configurl/split.go +++ b/x/configurl/split.go @@ -15,23 +15,25 @@ package configurl import ( + "context" "fmt" - "net/url" "strconv" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/split" ) -func wrapStreamDialerWithSplit(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - prefixBytesStr := configURL.Opaque - prefixBytes, err := strconv.Atoi(prefixBytesStr) - if err != nil { - return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) - } - return split.NewStreamDialer(sd, int64(prefixBytes)) +func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + prefixBytesStr := config.URL.Opaque + prefixBytes, err := strconv.Atoi(prefixBytesStr) + if err != nil { + return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) + } + return split.NewStreamDialer(sd, int64(prefixBytes)) + }) } diff --git a/x/configurl/tls.go b/x/configurl/tls.go index 2b29a1c3..7517cc43 100644 --- a/x/configurl/tls.go +++ b/x/configurl/tls.go @@ -15,6 +15,7 @@ package configurl import ( + "context" "fmt" "net/url" "strings" @@ -23,7 +24,21 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/tls" ) -func parseOptions(configURL *url.URL) ([]tls.ClientOption, error) { +func registerTLSStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + options, err := parseOptions(config.URL) + if err != nil { + return nil, err + } + return tls.NewStreamDialer(sd, options...) + }) +} + +func parseOptions(configURL url.URL) ([]tls.ClientOption, error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -49,15 +64,3 @@ func parseOptions(configURL *url.URL) ([]tls.ClientOption, error) { } return options, nil } - -func wrapStreamDialerWithTLS(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - options, err := parseOptions(configURL) - if err != nil { - return nil, err - } - return tls.NewStreamDialer(sd, options...) -} diff --git a/x/configurl/tls_test.go b/x/configurl/tls_test.go index d9cdcc58..0aa59dc6 100644 --- a/x/configurl/tls_test.go +++ b/x/configurl/tls_test.go @@ -15,25 +15,16 @@ package configurl import ( - "net/url" "testing" - "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/tls" "github.com/stretchr/testify/require" ) -func TestTLS(t *testing.T) { - tlsURL, err := url.Parse("tls") - require.NoError(t, err) - _, err = wrapStreamDialerWithTLS(func() (transport.StreamDialer, error) { return &transport.TCPDialer{}, nil }, nil, tlsURL) - require.NoError(t, err) -} - func TestTLS_SNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=www.google.com") + config, err := ParseConfig("tls:sni=www.google.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -44,9 +35,9 @@ func TestTLS_SNI(t *testing.T) { } func TestTLS_NoSNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=") + config, err := ParseConfig("tls:sni=") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -57,16 +48,16 @@ func TestTLS_NoSNI(t *testing.T) { } func TestTLS_MultipleSNI(t *testing.T) { - tlsURL, err := url.Parse("tls:sni=www.google.com&sni=second") + config, err := ParseConfig("tls:sni=www.google.com&sni=second") require.NoError(t, err) - _, err = parseOptions(tlsURL) + _, err = parseOptions(config.URL) require.Error(t, err) } func TestTLS_CertName(t *testing.T) { - tlsURL, err := url.Parse("tls:certname=www.google.com") + config, err := ParseConfig("tls:certname=www.google.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -77,9 +68,9 @@ func TestTLS_CertName(t *testing.T) { } func TestTLS_Combined(t *testing.T) { - tlsURL, err := url.Parse("tls:SNI=sni.example.com&CertName=certname.example.com") + config, err := ParseConfig("tls:SNI=sni.example.com&CertName=certname.example.com") require.NoError(t, err) - options, err := parseOptions(tlsURL) + options, err := parseOptions(config.URL) require.NoError(t, err) cfg := tls.ClientConfig{ServerName: "host", CertificateName: "host"} for _, option := range options { @@ -90,8 +81,8 @@ func TestTLS_Combined(t *testing.T) { } func TestTLS_UnsupportedOption(t *testing.T) { - tlsURL, err := url.Parse("tls:unsupported") + config, err := ParseConfig("tls:unsupported") require.NoError(t, err) - _, err = parseOptions(tlsURL) + _, err = parseOptions(config.URL) require.Error(t, err) } diff --git a/x/configurl/tlsfrag.go b/x/configurl/tlsfrag.go new file mode 100644 index 00000000..4e2e0fe9 --- /dev/null +++ b/x/configurl/tlsfrag.go @@ -0,0 +1,39 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "strconv" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" +) + +func registerTLSFragStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + lenStr := config.URL.Opaque + fixedLen, err := strconv.Atoi(lenStr) + if err != nil { + return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) + } + return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) + }) +} diff --git a/x/configurl/websocket.go b/x/configurl/websocket.go index 49861d15..31ef95e4 100644 --- a/x/configurl/websocket.go +++ b/x/configurl/websocket.go @@ -31,7 +31,7 @@ type wsConfig struct { udpPath string } -func parseWSConfig(configURL *url.URL) (*wsConfig, error) { +func parseWSConfig(configURL url.URL) (*wsConfig, error) { query := configURL.Opaque values, err := url.ParseQuery(query) if err != nil { @@ -71,66 +71,70 @@ func (c *wsToStreamConn) CloseWrite() error { return c.Close() } -func wrapStreamDialerWithWebSocket(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.StreamDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseWSConfig(configURL) - if err != nil { - return nil, err - } - if config.tcpPath == "" { - return nil, errors.New("must specify tcp_path") - } - return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: config.tcpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) +func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) + return nil, err } - baseConn, err := sd.DialStream(ctx, addr) + wsConfig, err := parseWSConfig(config.URL) if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + return nil, err } - wsConn, err := websocket.NewClient(wsCfg, baseConn) - if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + if wsConfig.tcpPath == "" { + return nil, errors.New("must specify tcp_path") } - return &wsToStreamConn{wsConn}, nil - }), nil + return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.tcpPath} + origin := url.URL{Scheme: "http", Host: addr} + wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) + if err != nil { + return nil, fmt.Errorf("failed to create websocket config: %w", err) + } + baseConn, err := sd.DialStream(ctx, addr) + if err != nil { + return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + } + wsConn, err := websocket.NewClient(wsCfg, baseConn) + if err != nil { + baseConn.Close() + return nil, fmt.Errorf("failed to create websocket client: %w", err) + } + return &wsToStreamConn{wsConn}, nil + }), nil + }) } -func wrapPacketDialerWithWebSocket(innerSD func() (transport.StreamDialer, error), _ func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { - sd, err := innerSD() - if err != nil { - return nil, err - } - config, err := parseWSConfig(configURL) - if err != nil { - return nil, err - } - if config.udpPath == "" { - return nil, errors.New("must specify udp_path") - } - return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: config.udpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) +func registerWebsocketPacketDialer(r TypeRegistry[transport.PacketDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) + return nil, err } - baseConn, err := sd.DialStream(ctx, addr) + wsConfig, err := parseWSConfig(config.URL) if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + return nil, err } - wsConn, err := websocket.NewClient(wsCfg, baseConn) - if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + if wsConfig.udpPath == "" { + return nil, errors.New("must specify udp_path") } - return wsConn, nil - }), nil + return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.udpPath} + origin := url.URL{Scheme: "http", Host: addr} + wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) + if err != nil { + return nil, fmt.Errorf("failed to create websocket config: %w", err) + } + baseConn, err := sd.DialStream(ctx, addr) + if err != nil { + return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) + } + wsConn, err := websocket.NewClient(wsCfg, baseConn) + if err != nil { + baseConn.Close() + return nil, fmt.Errorf("failed to create websocket client: %w", err) + } + return wsConn, nil + }), nil + }) } diff --git a/x/examples/fetch-speed/main.go b/x/examples/fetch-speed/main.go index 7e8db297..f6b131a7 100644 --- a/x/examples/fetch-speed/main.go +++ b/x/examples/fetch-speed/main.go @@ -58,7 +58,7 @@ func main() { os.Exit(1) } - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v\n", err) } diff --git a/x/examples/fetch/main.go b/x/examples/fetch/main.go index 634385a5..3b7d2d02 100644 --- a/x/examples/fetch/main.go +++ b/x/examples/fetch/main.go @@ -84,7 +84,7 @@ func main() { os.Exit(1) } - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v\n", err) } diff --git a/x/examples/http2transport/main.go b/x/examples/http2transport/main.go index 72fb3dcf..a0489fae 100644 --- a/x/examples/http2transport/main.go +++ b/x/examples/http2transport/main.go @@ -34,7 +34,7 @@ func main() { urlProxyPrefixFlag := flag.String("urlProxyPrefix", "/proxy", "Path where to run the URL proxy. Set to empty (\"\") to disable it.") flag.Parse() - dialer, err := configurl.NewDefaultConfigToDialer().NewStreamDialer(*transportFlag) + dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create dialer: %v", err) diff --git a/x/examples/outline-cli/outline_device.go b/x/examples/outline-cli/outline_device.go index 858e2bfa..61333bf8 100644 --- a/x/examples/outline-cli/outline_device.go +++ b/x/examples/outline-cli/outline_device.go @@ -15,6 +15,7 @@ package main import ( + "context" "errors" "fmt" "net" @@ -39,7 +40,7 @@ type OutlineDevice struct { svrIP net.IP } -var configToDialer = configurl.NewDefaultConfigToDialer() +var configModule = configurl.NewDefaultProviders() func NewOutlineDevice(transportConfig string) (od *OutlineDevice, err error) { ip, err := resolveShadowsocksServerIPFromConfig(transportConfig) @@ -50,7 +51,7 @@ func NewOutlineDevice(transportConfig string) (od *OutlineDevice, err error) { svrIP: ip, } - if od.sd, err = configToDialer.NewStreamDialer(transportConfig); err != nil { + if od.sd, err = configModule.NewStreamDialer(context.TODO(), transportConfig); err != nil { return nil, fmt.Errorf("failed to create TCP dialer: %w", err) } if od.pp, err = newOutlinePacketProxy(transportConfig); err != nil { diff --git a/x/examples/outline-cli/outline_packet_proxy.go b/x/examples/outline-cli/outline_packet_proxy.go index fe202b8e..7c0b314b 100644 --- a/x/examples/outline-cli/outline_packet_proxy.go +++ b/x/examples/outline-cli/outline_packet_proxy.go @@ -36,7 +36,7 @@ type outlinePacketProxy struct { func newOutlinePacketProxy(transportConfig string) (opp *outlinePacketProxy, err error) { opp = &outlinePacketProxy{} - if opp.remotePl, err = configurl.NewPacketListener(transportConfig); err != nil { + if opp.remotePl, err = configurl.NewDefaultProviders().NewPacketListener(context.TODO(), transportConfig); err != nil { return nil, fmt.Errorf("failed to create UDP packet listener: %w", err) } if opp.remote, err = network.NewPacketProxyFromPacketListener(opp.remotePl); err != nil { diff --git a/x/examples/resolve/main.go b/x/examples/resolve/main.go index 2055b620..e55dd790 100644 --- a/x/examples/resolve/main.go +++ b/x/examples/resolve/main.go @@ -66,15 +66,15 @@ func main() { resolverAddr := *resolverFlag var resolver dns.Resolver - configToDialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() if *tcpFlag { - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } resolver = dns.NewTCPResolver(streamDialer, resolverAddr) } else { - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create packet dialer: %v", err) } diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index 15490355..59d40b28 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -86,12 +86,12 @@ func main() { log.Fatalf("Could not read config: %v", err) } - configToDialer := configurl.NewDefaultConfigToDialer() - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + providers := configurl.NewDefaultProviders() + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create packet dialer: %v", err) } - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } diff --git a/x/examples/test-connectivity/main.go b/x/examples/test-connectivity/main.go index 1378ee1b..7197b6f7 100644 --- a/x/examples/test-connectivity/main.go +++ b/x/examples/test-connectivity/main.go @@ -240,7 +240,7 @@ func main() { var mu sync.Mutex dnsReports := make([]dnsReport, 0) tcpReports := make([]tcpReport, 0) - configToDialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() onDNS := func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo) { dnsStart := time.Now() return func(di httptrace.DNSDoneInfo) { @@ -260,7 +260,7 @@ func main() { mu.Unlock() } } - configToDialer.BaseStreamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + providers.StreamDialers.BaseInstance = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { hostname, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -284,13 +284,13 @@ func main() { } return newTCPTraceDialer(onDNS, onDial).DialStream(ctx, addr) }) - configToDialer.BasePacketDialer = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + providers.PacketDialers.BaseInstance = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { return newUDPTraceDialer(onDNS).DialPacket(ctx, addr) }) switch proto { case "tcp": - streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) + streamDialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { slog.Error("Failed to create StreamDialer", "error", err) os.Exit(1) @@ -298,7 +298,7 @@ func main() { resolver = dns.NewTCPResolver(streamDialer, resolverAddress) case "udp": - packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) + packetDialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { slog.Error("Failed to create PacketDialer", "error", err) os.Exit(1) diff --git a/x/examples/ws2endpoint/main.go b/x/examples/ws2endpoint/main.go index 5fac6aec..e09865e7 100644 --- a/x/examples/ws2endpoint/main.go +++ b/x/examples/ws2endpoint/main.go @@ -60,10 +60,10 @@ func main() { defer listener.Close() log.Printf("Proxy listening on %v\n", listener.Addr().String()) - config2Dialer := configurl.NewDefaultConfigToDialer() + providers := configurl.NewDefaultProviders() mux := http.NewServeMux() if *tcpPathFlag != "" { - dialer, err := config2Dialer.NewStreamDialer(*transportFlag) + dialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } @@ -90,7 +90,7 @@ func main() { mux.Handle(*tcpPathFlag, http.StripPrefix(*tcpPathFlag, handler)) } if *udpPathFlag != "" { - dialer, err := config2Dialer.NewPacketDialer(*transportFlag) + dialer, err := providers.NewPacketDialer(context.Background(), *transportFlag) if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } diff --git a/x/go.mod b/x/go.mod index a1dfe72a..7ad5e856 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.21 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 + github.com/Jigsaw-Code/outline-sdk v0.0.17 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index 033308fb..67b2fc27 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 h1:XNSV0dGW48J8DmdmCnk/txGHf9glAPqa6Xme/rFWn7c= -github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57/go.mod h1:e1oQZbSdLJBBuHgfeQsgEkvkuyIePPwstUeZRGq0KO8= +github.com/Jigsaw-Code/outline-sdk v0.0.17 h1:xkGsp3fs+5EbIZEeCEPI1rl0oyCrOLuO7YSK0gB75HE= +github.com/Jigsaw-Code/outline-sdk v0.0.17/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag= @@ -72,8 +72,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439 h1:T6zlOdzrYuHf6HUKujm9bzkzbZ5Iv/xf6rs8BHZDpoI= github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -99,9 +99,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= diff --git a/x/httpproxy/connect_handler.go b/x/httpproxy/connect_handler.go index 69512a7e..0922cce2 100644 --- a/x/httpproxy/connect_handler.go +++ b/x/httpproxy/connect_handler.go @@ -51,8 +51,8 @@ func (d *sanitizeErrorDialer) DialStream(ctx context.Context, addr string) (tran } type connectHandler struct { - dialer *sanitizeErrorDialer - dialerConfig *configurl.ConfigToDialer + dialer *sanitizeErrorDialer + providers *configurl.ProviderContainer } var _ http.Handler = (*connectHandler)(nil) @@ -77,7 +77,7 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http // Dial the target. transportConfig := proxyReq.Header.Get("Transport") - dialer, err := h.dialerConfig.NewStreamDialer(transportConfig) + dialer, err := h.providers.NewStreamDialer(proxyReq.Context(), transportConfig) if err != nil { // Because we sanitize the base dialer error, it's safe to return error details here. http.Error(proxyResp, fmt.Sprintf("Invalid config in Transport header: %v", err), http.StatusBadRequest) @@ -149,7 +149,7 @@ func NewConnectHandler(dialer transport.StreamDialer) http.Handler { // of the base dialer (e.g. access key credentials) to the user. sd := &sanitizeErrorDialer{dialer} // TODO(fortuna): Inject the config parser - dialerConfig := configurl.NewDefaultConfigToDialer() - dialerConfig.BaseStreamDialer = sd - return &connectHandler{sd, dialerConfig} + providers := configurl.NewDefaultProviders() + providers.StreamDialers.BaseInstance = sd + return &connectHandler{sd, providers} } diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index 19ea7563..5b832b7b 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -152,12 +152,12 @@ type StreamDialer struct { transport.StreamDialer } -var configToDialer = configurl.NewDefaultConfigToDialer() +var configModule = configurl.NewDefaultProviders() // NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. // The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. func NewStreamDialerFromConfig(transportConfig string) (*StreamDialer, error) { - dialer, err := configToDialer.NewStreamDialer(transportConfig) + dialer, err := configModule.NewStreamDialer(context.Background(), transportConfig) if err != nil { return nil, err } diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 092aac1c..54c089ef 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -231,8 +231,8 @@ func (f *StrategyFinder) findTLS(ctx context.Context, testDomains []string, base if len(tlsConfig) == 0 { return nil, errors.New("config for TLS is empty. Please specify at least one transport") } - var configToDialer = configurl.NewDefaultConfigToDialer() - configToDialer.BaseStreamDialer = baseDialer + var configModule = configurl.NewDefaultProviders() + configModule.StreamDialers.BaseInstance = baseDialer ctx, searchDone := context.WithCancel(ctx) defer searchDone() @@ -242,7 +242,7 @@ func (f *StrategyFinder) findTLS(ctx context.Context, testDomains []string, base Config string } result, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (*SearchResult, error) { - tlsDialer, err := configToDialer.NewStreamDialer(transportCfg) + tlsDialer, err := configModule.NewStreamDialer(ctx, transportCfg) if err != nil { return nil, fmt.Errorf("WrapStreamDialer failed: %w", err) } From 82b33bc00a29c9007a3b4b6cb629d7bfe66ddfa7 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 1 Nov 2024 18:42:00 -0400 Subject: [PATCH 02/21] feat(x): add h3/QUIC support to the fetch tool (#305) --- x/configurl/doc.go | 2 +- x/examples/fetch/main.go | 146 +++++++++++++++++++++++++++++++-------- x/go.mod | 23 +++--- x/go.sum | 44 ++++++------ 4 files changed, 153 insertions(+), 62 deletions(-) diff --git a/x/configurl/doc.go b/x/configurl/doc.go index dfee54c6..056c7a23 100644 --- a/x/configurl/doc.go +++ b/x/configurl/doc.go @@ -132,7 +132,7 @@ DPI Evasion - To add packet splitting to a Shadowsocks server for enhanced DPI e Defining custom strategies - You can define your custom strategy by implementing and registering [BuildFunc[ObjectType]] functions: // Create new config parser. - // p := configurl.NewProviderContainer + // p := configurl.NewProviderContainer() // or p := configurl.NewDefaultProviders() // Register your custom dialer. diff --git a/x/examples/fetch/main.go b/x/examples/fetch/main.go index 3b7d2d02..cc945779 100644 --- a/x/examples/fetch/main.go +++ b/x/examples/fetch/main.go @@ -17,10 +17,11 @@ package main import ( "bufio" "context" + "crypto/tls" "flag" "fmt" "io" - "log" + "log/slog" "net" "net/http" "net/textproto" @@ -30,10 +31,12 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/x/configurl" + "github.com/lmittmann/tint" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + "golang.org/x/term" ) -var debugLog log.Logger = *log.New(io.Discard, "", 0) - type stringArrayFlagValue []string func (v *stringArrayFlagValue) String() string { @@ -52,8 +55,24 @@ func init() { } } +func overrideAddress(original string, newHost string, newPort string) (string, error) { + host, port, err := net.SplitHostPort(original) + if err != nil { + return "", fmt.Errorf("invalid address: %w", err) + } + if newHost != "" { + host = newHost + } + if newPort != "" { + port = newPort + } + return net.JoinHostPort(host, port), nil +} + func main() { verboseFlag := flag.Bool("v", false, "Enable debug output") + tlsKeyLogFlag := flag.String("tls-key-log", "", "Filename to write the TLS key log to allow for decryption on Wireshark") + protoFlag := flag.String("proto", "h1", "HTTP version to use (h1, h2, h3)") transportFlag := flag.String("transport", "", "Transport config") addressFlag := flag.String("address", "", "Address to connect to. If empty, use the URL authority") methodFlag := flag.String("method", "GET", "The HTTP method to use") @@ -63,9 +82,15 @@ func main() { flag.Parse() + logLevel := slog.LevelInfo if *verboseFlag { - debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) + logLevel = slog.LevelDebug } + slog.SetDefault(slog.New(tint.NewHandler( + os.Stderr, + &tint.Options{NoColor: !term.IsTerminal(int(os.Stderr.Fd())), Level: logLevel}, + ))) + var overrideHost, overridePort string if *addressFlag != "" { var err error @@ -79,47 +104,104 @@ func main() { url := flag.Arg(0) if url == "" { - log.Println("Need to pass the URL to fetch in the command-line") + slog.Error("Need to pass the URL to fetch in the command-line") flag.Usage() os.Exit(1) } - dialer, err := configurl.NewDefaultProviders().NewStreamDialer(context.Background(), *transportFlag) - if err != nil { - log.Fatalf("Could not create dialer: %v\n", err) + httpClient := &http.Client{ + Timeout: time.Duration(*timeoutSecFlag) * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + var tlsConfig tls.Config + if *tlsKeyLogFlag != "" { + f, err := os.Create(*tlsKeyLogFlag) + if err != nil { + slog.Error("Failed to creare TLS key log file", "error", err) + os.Exit(1) + } + defer f.Close() + tlsConfig.KeyLogWriter = f } - dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) + providers := configurl.NewDefaultProviders() + if *protoFlag == "h1" || *protoFlag == "h2" { + dialer, err := providers.NewStreamDialer(context.Background(), *transportFlag) if err != nil { - return nil, fmt.Errorf("invalid address: %w", err) + slog.Error("Could not create dialer", "error", err) + os.Exit(1) } - if overrideHost != "" { - host = overrideHost + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + addressToDial, err := overrideAddress(addr, overrideHost, overridePort) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + if !strings.HasPrefix(network, "tcp") { + return nil, fmt.Errorf("protocol not supported: %v", network) + } + return dialer.DialStream(ctx, addressToDial) } - if overridePort != "" { - port = overridePort + if *protoFlag == "h1" { + tlsConfig.NextProtos = []string{"http/1.1"} + httpClient.Transport = &http.Transport{ + DialContext: dialContext, + TLSClientConfig: &tlsConfig, + } + } else if *protoFlag == "h2" { + tlsConfig.NextProtos = []string{"h2"} + httpClient.Transport = &http.Transport{ + DialContext: dialContext, + TLSClientConfig: &tlsConfig, + ForceAttemptHTTP2: true, + } } - if !strings.HasPrefix(network, "tcp") { - return nil, fmt.Errorf("protocol not supported: %v", network) + } else if *protoFlag == "h3" { + listener, err := providers.NewPacketListener(context.Background(), *transportFlag) + if err != nil { + slog.Error("Could not create listener", "error", err) + os.Exit(1) } - return dialer.DialStream(ctx, net.JoinHostPort(host, port)) - } - httpClient := &http.Client{ - Transport: &http.Transport{DialContext: dialContext}, - Timeout: time.Duration(*timeoutSecFlag) * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, + conn, err := listener.ListenPacket(context.Background()) + if err != nil { + slog.Error("Could not create PacketConn", "error", err) + os.Exit(1) + } + tr := &quic.Transport{ + Conn: conn, + } + defer tr.Close() + httpClient.Transport = &http3.Transport{ + TLSClientConfig: &tlsConfig, + Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + addressToDial, err := overrideAddress(addr, overrideHost, overridePort) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + udpAddr, err := net.ResolveUDPAddr("udp", addressToDial) + if err != nil { + return nil, err + } + return tr.DialEarly(ctx, udpAddr, tlsConf, quicConf) + }, + Logger: slog.Default(), + } + } else { + slog.Error("Invalid HTTP protocol", "proto", *protoFlag) + os.Exit(1) } req, err := http.NewRequest(*methodFlag, url, nil) if err != nil { - log.Fatalln("Failed to create request:", err) + slog.Error("Failed to create request", "error", err) + os.Exit(1) } headerText := strings.Join(headersFlag, "\r\n") + "\r\n\r\n" h, err := textproto.NewReader(bufio.NewReader(strings.NewReader(headerText))).ReadMIMEHeader() if err != nil { - log.Fatalf("invalid header line: %v", err) + slog.Error("Invalid header line", "error", err) + os.Exit(1) } for name, values := range h { for _, value := range values { @@ -128,19 +210,23 @@ func main() { } resp, err := httpClient.Do(req) if err != nil { - log.Fatalf("HTTP request failed: %v\n", err) + slog.Error("HTTP request failed", "error", err) + os.Exit(1) } defer resp.Body.Close() if *verboseFlag { + slog.Info("HTTP Proto", "version", resp.Proto) + slog.Info("HTTP Status", "status", resp.Status) for k, v := range resp.Header { - debugLog.Printf("%v: %v", k, v) + slog.Debug("Header", "key", k, "value", v) } } _, err = io.Copy(os.Stdout, resp.Body) fmt.Println() if err != nil { - log.Fatalf("Read of page body failed: %v\n", err) + slog.Error("Read of page body failed", "error", err) + os.Exit(1) } } diff --git a/x/go.mod b/x/go.mod index 7ad5e856..ef66a9d5 100644 --- a/x/go.mod +++ b/x/go.mod @@ -1,6 +1,6 @@ module github.com/Jigsaw-Code/outline-sdk/x -go 1.21 +go 1.22 require ( github.com/Jigsaw-Code/outline-sdk v0.0.17 @@ -8,13 +8,14 @@ require ( // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 github.com/lmittmann/tint v1.0.5 + github.com/quic-go/quic-go v0.48.1 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.1.0 golang.org/x/mobile v0.0.0-20240520174638-fa72addaaa1b - golang.org/x/net v0.25.0 - golang.org/x/sys v0.20.0 - golang.org/x/term v0.20.0 + golang.org/x/net v0.28.0 + golang.org/x/sys v0.23.0 + golang.org/x/term v0.23.0 ) require ( @@ -57,7 +58,7 @@ require ( github.com/pion/transport/v2 v2.2.3 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/quic-go/qpack v0.4.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/refraction-networking/conjure v0.7.11-0.20240130155008-c8df96195ab2 // indirect github.com/refraction-networking/ed25519 v0.1.2 // indirect github.com/refraction-networking/gotapdance v1.7.10 // indirect @@ -71,12 +72,12 @@ require ( github.com/wader/filtertransport v0.0.0-20200316221534-bdd9e61eee78 // indirect gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.5.0 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.15.0 // indirect - golang.org/x/tools v0.21.0 // indirect - google.golang.org/protobuf v1.31.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/x/go.sum b/x/go.sum index 67b2fc27..70bbc73a 100644 --- a/x/go.sum +++ b/x/go.sum @@ -146,8 +146,10 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= -github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= +github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/refraction-networking/conjure v0.7.11-0.20240130155008-c8df96195ab2 h1:m2ZH6WV69otVmBpWbk8et3MypHFsjcYXTNrknQKS/PY= github.com/refraction-networking/conjure v0.7.11-0.20240130155008-c8df96195ab2/go.mod h1:7KuAtYfSL0K0WpCScjN9YKiOZ4AQ/8IzSjUtVwWbSv8= github.com/refraction-networking/ed25519 v0.1.2 h1:08kJZUkAlY7a7cZGosl1teGytV+QEoNxPO7NnRvAB+g= @@ -203,10 +205,10 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= -golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mobile v0.0.0-20240520174638-fa72addaaa1b h1:WX7nnnLfCEXg+FmdYZPai2XuP3VqCP1HZVMST0n9DF0= golang.org/x/mobile v0.0.0-20240520174638-fa72addaaa1b/go.mod h1:EiXZlVfUTaAyySFVJb9rsODuiO+WXu8HrUuySb7nYFw= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -222,14 +224,14 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -245,8 +247,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -254,30 +256,32 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= From c2399f29a3c5d1843155fd327afcc8296478b60e Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 1 Nov 2024 20:45:38 -0400 Subject: [PATCH 03/21] Proper cleanup --- x/examples/fetch/main.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/x/examples/fetch/main.go b/x/examples/fetch/main.go index cc945779..955b867a 100644 --- a/x/examples/fetch/main.go +++ b/x/examples/fetch/main.go @@ -115,6 +115,7 @@ func main() { return http.ErrUseLastResponse }, } + defer httpClient.CloseIdleConnections() var tlsConfig tls.Config if *tlsKeyLogFlag != "" { @@ -168,11 +169,11 @@ func main() { slog.Error("Could not create PacketConn", "error", err) os.Exit(1) } - tr := &quic.Transport{ + quicTransport := &quic.Transport{ Conn: conn, } - defer tr.Close() - httpClient.Transport = &http3.Transport{ + defer quicTransport.Close() + httpTransport := &http3.Transport{ TLSClientConfig: &tlsConfig, Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { addressToDial, err := overrideAddress(addr, overrideHost, overridePort) @@ -183,10 +184,12 @@ func main() { if err != nil { return nil, err } - return tr.DialEarly(ctx, udpAddr, tlsConf, quicConf) + return quicTransport.DialEarly(ctx, udpAddr, tlsConf, quicConf) }, Logger: slog.Default(), } + defer httpTransport.Close() + httpClient.Transport = httpTransport } else { slog.Error("Invalid HTTP protocol", "proto", *protoFlag) os.Exit(1) From b6e887a86731814a9c5bbec652caa17bc51bee71 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 4 Nov 2024 19:00:06 -0500 Subject: [PATCH 04/21] Advanced split --- transport/split/stream_dialer.go | 9 +++-- transport/split/writer.go | 69 ++++++++++++++++++++++++++------ transport/split/writer_test.go | 24 +++++++++++ 3 files changed, 86 insertions(+), 16 deletions(-) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index c6b59d1a..2c123060 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -21,16 +21,19 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" ) +// splitDialer is a [transport.StreamDialer] that implements the split strategy. +// Use [NewStreamDialer] to create new instances. type splitDialer struct { dialer transport.StreamDialer splitPoint int64 + options []Option } var _ transport.StreamDialer = (*splitDialer)(nil) // NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes -// using [SplitWriter]. -func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transport.StreamDialer, error) { +// using [splitWriter]. If "repeatsNumber" is not 0, will split that many times, skipping "skipBytes" in between packets. +func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") } @@ -43,5 +46,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp if err != nil { return nil, err } - return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil } diff --git a/transport/split/writer.go b/transport/split/writer.go index b1b3e140..bbe4ee81 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,9 +18,15 @@ import ( "io" ) +type repeatedSplit struct { + count int + bytes int64 +} + type splitWriter struct { - writer io.Writer - prefixBytes int64 + writer io.Writer + nextSplitBytes int64 + remainingSplits []repeatedSplit } var _ io.Writer = (*splitWriter)(nil) @@ -36,32 +42,69 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) // A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes. // For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789]. // If the input writer is a [io.ReaderFrom], the output writer will be too. -func NewWriter(writer io.Writer, prefixBytes int64) io.Writer { - sw := &splitWriter{writer, prefixBytes} - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} +// It's possible to enable multiple splits with the [EnableRepeatSplit] option. +// In that cases, splits will happen at positions prefixBytes + i * skipBytes, for 0 <= i < count. +// This means that after the initial split, count splits will happen every skipBytes bytes. +// Example: +// prefixBytes = 1 +// count = 2 +// skipBytes = 6 +// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become +// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] +func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { + sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}} + for _, option := range options { + option(sw) + } + if len(sw.remainingSplits) == 0 { + // TODO(fortuna): Support ReaderFrom for repeat split. + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} + } } return sw } +type Option func(w *splitWriter) + +// AddSplitSequence will add count splits, each of skipBytes length. +func AddSplitSequence(count int, skipBytes int64) Option { + return func(w *splitWriter) { + if count > 0 { + w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + } + } +} + func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { - reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source) + reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source) written, err := w.rf.ReadFrom(reader) - w.prefixBytes -= written + w.nextSplitBytes -= written return written, err } func (w *splitWriter) Write(data []byte) (written int, err error) { - if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) { - written, err = w.writer.Write(data[:w.prefixBytes]) - w.prefixBytes -= int64(written) + for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { + dataToSend := data[:w.nextSplitBytes] + n, err := w.writer.Write(dataToSend) + written += n + w.nextSplitBytes -= int64(n) if err != nil { return written, err } - data = data[written:] + data = data[n:] + + // Split done. Update nextSplitBytes. + if len(w.remainingSplits) > 0 { + w.nextSplitBytes = w.remainingSplits[0].bytes + w.remainingSplits[0].count -= 1 + if w.remainingSplits[0].count == 0 { + w.remainingSplits = w.remainingSplits[1:] + } + } } n, err := w.writer.Write(data) written += n - w.prefixBytes -= int64(n) + w.nextSplitBytes -= int64(n) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index 025402fa..276fe7bf 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -84,6 +84,30 @@ func TestWrite_Compound(t *testing.T) { require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes) } +func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5)) + n, err := splitWriter.Write([]byte("RequestRequestRequest.")) + require.NoError(t, err) + require.Equal(t, 7*3+1, n) + require.Equal(t, [][]byte{ + []byte("R"), // prefix + []byte("eques"), // split 1 + []byte("tRequ"), // split 2 + []byte("estRe"), // split 3 + []byte("quest."), // tail + }, innerWriter.writes) +} + +func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("R"), []byte("equest")}, innerWriter.writes) +} + // collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice. type collectReader struct { io.Reader From 230c2fe3aaf8a9e8ff225f15e39442cf7bad9d52 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 4 Nov 2024 19:05:13 -0500 Subject: [PATCH 05/21] Update comment --- transport/split/writer.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transport/split/writer.go b/transport/split/writer.go index bbe4ee81..38f991c8 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -42,13 +42,9 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) // A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes. // For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789]. // If the input writer is a [io.ReaderFrom], the output writer will be too. -// It's possible to enable multiple splits with the [EnableRepeatSplit] option. -// In that cases, splits will happen at positions prefixBytes + i * skipBytes, for 0 <= i < count. -// This means that after the initial split, count splits will happen every skipBytes bytes. +// It's possible to enable multiple splits with the [AddSplitSequence] option, which adds count splits every skipBytes bytes. // Example: -// prefixBytes = 1 -// count = 2 -// skipBytes = 6 +// prefixBytes = 1, AddSplitSequence(count=2, bytes=6) // Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become // [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { From 81f3f2ac46d3c175be1b44b69839b0c4529f4cfb Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 4 Nov 2024 19:09:56 -0500 Subject: [PATCH 06/21] Comment --- transport/split/stream_dialer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index 2c123060..a8f9ada6 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -32,7 +32,7 @@ type splitDialer struct { var _ transport.StreamDialer = (*splitDialer)(nil) // NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes -// using [splitWriter]. If "repeatsNumber" is not 0, will split that many times, skipping "skipBytes" in between packets. +// using the split writer. You can specify multiple sequences with the [AddSplitSequence] option. func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") From e283a16d7dbfb45be0a74523eb8d24c7c6677f00 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 15:57:24 -0500 Subject: [PATCH 07/21] Fix --- transport/split/writer.go | 60 ++++++++++++++++++++++++---------- transport/split/writer_test.go | 36 ++++++++++++++++++++ 2 files changed, 79 insertions(+), 17 deletions(-) diff --git a/transport/split/writer.go b/transport/split/writer.go index 38f991c8..671b9763 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,14 +18,17 @@ import ( "io" ) +// repeatedSplit represents a split sequence of count blocks with bytes length. type repeatedSplit struct { count int bytes int64 } type splitWriter struct { - writer io.Writer - nextSplitBytes int64 + writer io.Writer + // Bytes until the next split. This must always be > 0, unless splits are done. + nextSplitBytes int64 + // Remaining split sequences. All entries here must have count > 0 && bytes > 0. remainingSplits []repeatedSplit } @@ -48,7 +51,8 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) // Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become // [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { - sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}} + sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}} + sw.addSplitSequence(1, prefixBytes) for _, option := range options { option(sw) } @@ -61,14 +65,25 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer return sw } +func (w *splitWriter) addSplitSequence(count int, skipBytes int64) { + if count == 0 || skipBytes == 0 { + return + } + if w.nextSplitBytes == 0 { + w.nextSplitBytes = skipBytes + count-- + } + if count > 0 { + w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + } +} + type Option func(w *splitWriter) // AddSplitSequence will add count splits, each of skipBytes length. func AddSplitSequence(count int, skipBytes int64) Option { return func(w *splitWriter) { - if count > 0 { - w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) - } + w.addSplitSequence(count, skipBytes) } } @@ -79,28 +94,39 @@ func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { return written, err } +func (w *splitWriter) advance(n int) { + if w.nextSplitBytes == 0 { + // Done with splits: return. + return + } + w.nextSplitBytes -= int64(n) + if w.nextSplitBytes > 0 { + return + } + // Split done, set next split. + if len(w.remainingSplits) == 0 { + return + } + w.nextSplitBytes = w.remainingSplits[0].bytes + w.remainingSplits[0].count -= 1 + if w.remainingSplits[0].count == 0 { + w.remainingSplits = w.remainingSplits[1:] + } +} + func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) written += n - w.nextSplitBytes -= int64(n) + w.advance(n) if err != nil { return written, err } data = data[n:] - - // Split done. Update nextSplitBytes. - if len(w.remainingSplits) > 0 { - w.nextSplitBytes = w.remainingSplits[0].bytes - w.remainingSplits[0].count -= 1 - if w.remainingSplits[0].count == 0 { - w.remainingSplits = w.remainingSplits[1:] - } - } } n, err := w.writer.Write(data) written += n - w.nextSplitBytes -= int64(n) + w.advance(n) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index 276fe7bf..b20f8222 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -45,6 +45,42 @@ func TestWrite_Split(t *testing.T) { require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes) } +func TestWrite_SplitZero(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestWrite_SplitZeroLong(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestWrite_SplitZeroPrefix(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Re"), []byte("qu"), []byte("es"), []byte("t")}, innerWriter.writes) +} + +func TestWrite_SplitMulti(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + n, err := splitWriter.Write([]byte("RequestRequestRequest")) + require.NoError(t, err) + require.Equal(t, 21, n) + require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, innerWriter.writes) +} + func TestWrite_ShortWrite(t *testing.T) { var innerWriter collectWrites splitWriter := NewWriter(&innerWriter, 10) From d2a4cf9f515ca798980417af8701fe27137fa41c Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 16:11:27 -0500 Subject: [PATCH 08/21] Support ReaderFrom --- transport/split/writer.go | 33 ++++++++++++++++++++++----------- transport/split/writer_test.go | 12 ++++++++++++ 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/transport/split/writer.go b/transport/split/writer.go index 671b9763..a4db713c 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -56,11 +56,8 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer for _, option := range options { option(sw) } - if len(sw.remainingSplits) == 0 { - // TODO(fortuna): Support ReaderFrom for repeat split. - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} - } + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} } return sw } @@ -88,13 +85,27 @@ func AddSplitSequence(count int, skipBytes int64) Option { } func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { - reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source) - written, err := w.rf.ReadFrom(reader) - w.nextSplitBytes -= written + var written int64 + for w.nextSplitBytes > 0 { + expectedBytes := w.nextSplitBytes + n, err := w.rf.ReadFrom(io.LimitReader(source, expectedBytes)) + written += n + w.advance(n) + if err != nil { + return written, err + } + if n < expectedBytes { + // Source is done before the split happened. Return. + return written, err + } + } + n, err := w.rf.ReadFrom(source) + written += n + w.advance(n) return written, err } -func (w *splitWriter) advance(n int) { +func (w *splitWriter) advance(n int64) { if w.nextSplitBytes == 0 { // Done with splits: return. return @@ -119,7 +130,7 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) written += n - w.advance(n) + w.advance(int64(n)) if err != nil { return written, err } @@ -127,6 +138,6 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { } n, err := w.writer.Write(data) written += n - w.advance(n) + w.advance(int64(n)) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index b20f8222..d995417b 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -178,6 +178,18 @@ func TestReadFrom(t *testing.T) { require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads) } +func TestReadFrom_Multi(t *testing.T) { + splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + rf, ok := splitWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("RequestRequestRequest"))} + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(21), n) + require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, cr.reads) +} + func TestReadFrom_ShortRead(t *testing.T) { splitWriter := NewWriter(&bytes.Buffer{}, 10) rf, ok := splitWriter.(io.ReaderFrom) From 5adbe7183dca91594790257d60fde0643704c022 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 18:05:08 -0500 Subject: [PATCH 09/21] Introduce SplitIterator --- transport/split/stream_dialer.go | 11 ++-- transport/split/writer.go | 102 +++++++++++++++---------------- transport/split/writer_test.go | 30 ++++----- 3 files changed, 71 insertions(+), 72 deletions(-) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index a8f9ada6..59326a15 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -24,20 +24,19 @@ import ( // splitDialer is a [transport.StreamDialer] that implements the split strategy. // Use [NewStreamDialer] to create new instances. type splitDialer struct { - dialer transport.StreamDialer - splitPoint int64 - options []Option + dialer transport.StreamDialer + nextSplit func() int64 } var _ transport.StreamDialer = (*splitDialer)(nil) // NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes // using the split writer. You can specify multiple sequences with the [AddSplitSequence] option. -func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) { +func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, nextSplit func() int64) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") } - return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil + return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil } // DialStream implements [transport.StreamDialer].DialStream. @@ -46,5 +45,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp if err != nil { return nil, err } - return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.nextSplit)), nil } diff --git a/transport/split/writer.go b/transport/split/writer.go index a4db713c..21d2bf66 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,18 +18,11 @@ import ( "io" ) -// repeatedSplit represents a split sequence of count blocks with bytes length. -type repeatedSplit struct { - count int - bytes int64 -} - type splitWriter struct { writer io.Writer // Bytes until the next split. This must always be > 0, unless splits are done. - nextSplitBytes int64 - // Remaining split sequences. All entries here must have count > 0 && bytes > 0. - remainingSplits []repeatedSplit + nextSplitBytes int64 + nextSegmentLength func() int64 } var _ io.Writer = (*splitWriter)(nil) @@ -41,49 +34,62 @@ type splitWriterReaderFrom struct { var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) -// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes. -// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes. -// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789]. -// If the input writer is a [io.ReaderFrom], the output writer will be too. -// It's possible to enable multiple splits with the [AddSplitSequence] option, which adds count splits every skipBytes bytes. -// Example: -// prefixBytes = 1, AddSplitSequence(count=2, bytes=6) -// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become -// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] -func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { - sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}} - sw.addSplitSequence(1, prefixBytes) - for _, option := range options { - option(sw) - } - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} +// Split Iterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do. +type SplitIterator func() int64 + +// NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero. +// This is helpful for when you want to split the stream once in a fixed position. +func NewFixedSplitIterator(n int64) SplitIterator { + return func() int64 { + next := n + n = 0 + return next } - return sw } -func (w *splitWriter) addSplitSequence(count int, skipBytes int64) { - if count == 0 || skipBytes == 0 { - return - } - if w.nextSplitBytes == 0 { - w.nextSplitBytes = skipBytes - count-- +// RepeatedSplit represents a split sequence of count segments with bytes length. +type RepeatedSplit struct { + Count int + Bytes int64 +} + +// NewRepeatedSplitIterator is a helper function that returns a [SplitIterator] that returns split points according to splits. +// The splits input represents pairs of (count, bytes), meaning a sequence of count splits with bytes length. +// This is helpful for when you want to split the stream repeatedly at different positions and lengths. +func NewRepeatedSplitIterator(splits ...RepeatedSplit) SplitIterator { + // Make sure we don't edit the original slice. + cleanSplits := make([]RepeatedSplit, 0, len(splits)) + // Remove no-op splits. + for _, split := range splits { + if split.Count > 0 && split.Bytes > 0 { + cleanSplits = append(cleanSplits, split) + } } - if count > 0 { - w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + return func() int64 { + if len(cleanSplits) == 0 { + return 0 + } + next := cleanSplits[0].Bytes + cleanSplits[0].Count -= 1 + if cleanSplits[0].Count == 0 { + cleanSplits = cleanSplits[1:] + } + return next } } -type Option func(w *splitWriter) - -// AddSplitSequence will add count splits, each of skipBytes length. -func AddSplitSequence(count int, skipBytes int64) Option { - return func(w *splitWriter) { - w.addSplitSequence(count, skipBytes) +// NewWriter creates a split Writer that calls the nextSegmentLength [SplitIterator] to determine the number bytes until the next split +// point until it returns zero. +func NewWriter(writer io.Writer, nextSegmentLength SplitIterator) io.Writer { + sw := &splitWriter{writer: writer, nextSegmentLength: nextSegmentLength} + sw.nextSplitBytes = nextSegmentLength() + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} } + return sw } +// ReadFrom implements io.ReaderFrom. func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { var written int64 for w.nextSplitBytes > 0 { @@ -114,17 +120,11 @@ func (w *splitWriter) advance(n int64) { if w.nextSplitBytes > 0 { return } - // Split done, set next split. - if len(w.remainingSplits) == 0 { - return - } - w.nextSplitBytes = w.remainingSplits[0].bytes - w.remainingSplits[0].count -= 1 - if w.remainingSplits[0].count == 0 { - w.remainingSplits = w.remainingSplits[1:] - } + // Split done, set up the next split. + w.nextSplitBytes = w.nextSegmentLength() } +// Write implements io.Writer. func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index d995417b..fc936891 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -38,7 +38,7 @@ func (w *collectWrites) Write(data []byte) (int, error) { func TestWrite_Split(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 3) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(3)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -47,7 +47,7 @@ func TestWrite_Split(t *testing.T) { func TestWrite_SplitZero(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{0, 1}, RepeatedSplit{10, 0}, RepeatedSplit{0, 2})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -56,7 +56,7 @@ func TestWrite_SplitZero(t *testing.T) { func TestWrite_SplitZeroLong(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{1_000_000_000_000_000_000, 0})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -65,7 +65,7 @@ func TestWrite_SplitZeroLong(t *testing.T) { func TestWrite_SplitZeroPrefix(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{3, 2})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -74,7 +74,7 @@ func TestWrite_SplitZeroPrefix(t *testing.T) { func TestWrite_SplitMulti(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3})) n, err := splitWriter.Write([]byte("RequestRequestRequest")) require.NoError(t, err) require.Equal(t, 21, n) @@ -83,7 +83,7 @@ func TestWrite_SplitMulti(t *testing.T) { func TestWrite_ShortWrite(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 10) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(10)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -92,7 +92,7 @@ func TestWrite_ShortWrite(t *testing.T) { func TestWrite_Zero(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(0)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -101,7 +101,7 @@ func TestWrite_Zero(t *testing.T) { func TestWrite_NeedsTwoWrites(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 5) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(5)) n, err := splitWriter.Write([]byte("Re")) require.NoError(t, err) require.Equal(t, 2, n) @@ -113,7 +113,7 @@ func TestWrite_NeedsTwoWrites(t *testing.T) { func TestWrite_Compound(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(NewWriter(&innerWriter, 4), 1) + splitWriter := NewWriter(NewWriter(&innerWriter, NewFixedSplitIterator(4)), NewFixedSplitIterator(1)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -122,7 +122,7 @@ func TestWrite_Compound(t *testing.T) { func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 5})) n, err := splitWriter.Write([]byte("RequestRequestRequest.")) require.NoError(t, err) require.Equal(t, 7*3+1, n) @@ -137,7 +137,7 @@ func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{0, 3})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -161,7 +161,7 @@ func (r *collectReader) Read(buf []byte) (int, error) { } func TestReadFrom(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 3) + splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(3)) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) @@ -179,7 +179,7 @@ func TestReadFrom(t *testing.T) { } func TestReadFrom_Multi(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + splitWriter := NewWriter(&bytes.Buffer{}, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3})) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) @@ -191,7 +191,7 @@ func TestReadFrom_Multi(t *testing.T) { } func TestReadFrom_ShortRead(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 10) + splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(10)) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))} @@ -210,7 +210,7 @@ func TestReadFrom_ShortRead(t *testing.T) { func BenchmarkReadFrom(b *testing.B) { for n := 0; n < b.N; n++ { reader := bytes.NewReader(make([]byte, n)) - writer := NewWriter(io.Discard, 10) + writer := NewWriter(io.Discard, NewFixedSplitIterator(10)) rf, ok := writer.(io.ReaderFrom) require.True(b, ok) _, err := rf.ReadFrom(reader) From 4801a034fbafdd6d725292b6fcb958ea08617b46 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 18:06:28 -0500 Subject: [PATCH 10/21] Fix Dialer --- transport/split/stream_dialer.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index 59326a15..84d374fa 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -25,14 +25,13 @@ import ( // Use [NewStreamDialer] to create new instances. type splitDialer struct { dialer transport.StreamDialer - nextSplit func() int64 + nextSplit SplitIterator } var _ transport.StreamDialer = (*splitDialer)(nil) -// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes -// using the split writer. You can specify multiple sequences with the [AddSplitSequence] option. -func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, nextSplit func() int64) (transport.StreamDialer, error) { +// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream according to nextSplit. +func NewStreamDialer(dialer transport.StreamDialer, nextSplit SplitIterator) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") } From cb7d702eeefd13d10e629d597bc2b0ddb78b9b00 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 18:55:04 -0500 Subject: [PATCH 11/21] Add sleep --- transport/split/writer.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transport/split/writer.go b/transport/split/writer.go index 21d2bf66..50ca8e92 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -16,6 +16,7 @@ package split import ( "io" + "time" ) type splitWriter struct { @@ -129,6 +130,8 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) + // Sleep to ensure bytes are properly split. + time.Sleep(100 * time.Microsecond) written += n w.advance(int64(n)) if err != nil { From 637628ff48dee1e122902b7f5d8a4132b60d852e Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 19:00:28 -0500 Subject: [PATCH 12/21] Update config --- x/configurl/split.go | 2 +- x/go.mod | 2 +- x/go.sum | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x/configurl/split.go b/x/configurl/split.go index e6d89b3a..48242010 100644 --- a/x/configurl/split.go +++ b/x/configurl/split.go @@ -34,6 +34,6 @@ func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID st if err != nil { return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) } - return split.NewStreamDialer(sd, int64(prefixBytes)) + return split.NewStreamDialer(sd, split.NewFixedSplitIterator(int64(prefixBytes))) }) } diff --git a/x/go.mod b/x/go.mod index ef66a9d5..f5f452c5 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.22 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.17 + github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index 70bbc73a..b8c6f4f3 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.17 h1:xkGsp3fs+5EbIZEeCEPI1rl0oyCrOLuO7YSK0gB75HE= -github.com/Jigsaw-Code/outline-sdk v0.0.17/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf h1:nJFLb0ukT/K4zyQod4vXLqimWq9ekzgC7i5Q+J3x6Rg= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag= From 3354513b85f8b18369d6157c46f6c12c41200c3a Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 19:33:01 -0500 Subject: [PATCH 13/21] Revert sleep --- transport/split/writer.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/transport/split/writer.go b/transport/split/writer.go index 50ca8e92..21d2bf66 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -16,7 +16,6 @@ package split import ( "io" - "time" ) type splitWriter struct { @@ -130,8 +129,6 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) - // Sleep to ensure bytes are properly split. - time.Sleep(100 * time.Microsecond) written += n w.advance(int64(n)) if err != nil { From 404aa6554da4bd5f760c8dbea8e0a416a018342c Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 5 Nov 2024 19:34:19 -0500 Subject: [PATCH 14/21] Update config --- x/configurl/doc.go | 4 ++-- x/configurl/split.go | 33 ++++++++++++++++++++++++++++----- x/go.mod | 2 +- x/go.sum | 4 ++-- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/x/configurl/doc.go b/x/configurl/doc.go index 056c7a23..9b2191d2 100644 --- a/x/configurl/doc.go +++ b/x/configurl/doc.go @@ -91,9 +91,9 @@ These strategies manipulate packets to bypass SNI-based blocking. Stream split transport (streams only, package [github.com/Jigsaw-Code/outline-sdk/transport/split]) -It takes the length of the prefix. The stream will be split when PREFIX_LENGTH bytes are first written. +It takes a list of count*length pairs meaning splitting the sequence in count segments of the given length. If you omit "[COUNT]*", it's assumed to be 1. - split:[PREFIX_LENGTH] + split:[COUNT1]*[LENGTH1],[COUNT2]*[LENGTH2],... TLS fragmentation (streams only, package [github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag]). diff --git a/x/configurl/split.go b/x/configurl/split.go index 48242010..33737fbe 100644 --- a/x/configurl/split.go +++ b/x/configurl/split.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "strconv" + "strings" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/split" @@ -29,11 +30,33 @@ func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID st if err != nil { return nil, err } - prefixBytesStr := config.URL.Opaque - prefixBytes, err := strconv.Atoi(prefixBytesStr) - if err != nil { - return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split: format", prefixBytesStr) + configText := config.URL.Opaque + splits := make([]split.RepeatedSplit, 0) + for _, part := range strings.Split(configText, ",") { + var count int + var bytes int64 + subparts := strings.Split(strings.TrimSpace(part), "*") + switch len(subparts) { + case 1: + count = 1 + bytes, err = strconv.ParseInt(subparts[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("bytes is not a number: %v", subparts[0]) + } + case 2: + count, err = strconv.Atoi(subparts[0]) + if err != nil { + return nil, fmt.Errorf("count is not a number: %v", subparts[0]) + } + bytes, err = strconv.ParseInt(subparts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("bytes is not a number: %v", subparts[0]) + } + default: + return nil, fmt.Errorf("split format must be a comma-separated list of '[$COUNT*]$BYTES' (e.g. '100,5*2'). Got %v", part) + } + splits = append(splits, split.RepeatedSplit{Count: count, Bytes: bytes}) } - return split.NewStreamDialer(sd, split.NewFixedSplitIterator(int64(prefixBytes))) + return split.NewStreamDialer(sd, split.NewRepeatedSplitIterator(splits...)) }) } diff --git a/x/go.mod b/x/go.mod index f5f452c5..ba02180e 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.22 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf + github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index b8c6f4f3..46047f1c 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf h1:nJFLb0ukT/K4zyQod4vXLqimWq9ekzgC7i5Q+J3x6Rg= -github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241105230628-4801a034fbaf/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8 h1:3bdohee2xFJmsXYdAB/p1zZRdrKMQlp1kiyNJgCE1OY= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag= From 74e5fe2a5a81f147d239ad659b41015ae78c81e3 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 6 Nov 2024 16:56:08 -0500 Subject: [PATCH 15/21] feat: create opts package --- x/sockopt/sockopt.go | 85 +++++++++++++++++++++++++++++++++++++++ x/sockopt/sockopt_test.go | 54 +++++++++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 x/sockopt/sockopt.go create mode 100644 x/sockopt/sockopt_test.go diff --git a/x/sockopt/sockopt.go b/x/sockopt/sockopt.go new file mode 100644 index 00000000..15d79382 --- /dev/null +++ b/x/sockopt/sockopt.go @@ -0,0 +1,85 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sockopt provides cross-platform ways to interact with socket options. +package sockopt + +import ( + "fmt" + "net" + "net/netip" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// HasHopLimit enables manipulation of the hop limit option. +type HasHopLimit interface { + // HopLimit returns the hop limit field value for outgoing packets. + HopLimit() (int, error) + // SetHopLimit sets the hop limit field value for future outgoing packets. + SetHopLimit(hoplim int) error +} + +// hopLimitOption implements HasHopLimit. +type hopLimitOption struct { + hopLimit func() (int, error) + setHopLimit func(hoplim int) error +} + +func (o *hopLimitOption) HopLimit() (int, error) { + return o.hopLimit() +} + +func (o *hopLimitOption) SetHopLimit(hoplim int) error { + return o.setHopLimit(hoplim) +} + +var _ HasHopLimit = (*hopLimitOption)(nil) + +// TCPOptions represents options for TCP connections. +type TCPOptions interface { + HasHopLimit +} + +type tcpOptions struct { + hopLimitOption +} + +var _ TCPOptions = (*tcpOptions)(nil) + +// NewTCPOptions creates a [TCPOptions] for the given [net.TCPConn]. +func NewTCPOptions(conn *net.TCPConn) (TCPOptions, error) { + addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()) + if err != nil { + return nil, fmt.Errorf("could not parse remote addr: %w", err) + } + + opts := &tcpOptions{} + + switch { + case addr.Addr().Is4(): + conn := ipv4.NewConn(conn) + opts.hopLimitOption.hopLimit = conn.TTL + opts.hopLimitOption.setHopLimit = conn.SetTTL + case addr.Addr().Is6(): + conn := ipv6.NewConn(conn) + opts.hopLimitOption.hopLimit = conn.HopLimit + opts.hopLimitOption.setHopLimit = conn.SetHopLimit + default: + return nil, fmt.Errorf("unknown remote addr type (%v)", addr.Addr().String()) + } + + return opts, nil +} diff --git a/x/sockopt/sockopt_test.go b/x/sockopt/sockopt_test.go new file mode 100644 index 00000000..74b34341 --- /dev/null +++ b/x/sockopt/sockopt_test.go @@ -0,0 +1,54 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sockopt + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTCPOptions(t *testing.T) { + type Params struct { + Net string + Addr string + } + for _, params := range []Params{{Net: "tcp4", Addr: "127.0.0.1:0"}, {Net: "tcp6", Addr: "[::1]:0"}} { + l, err := net.Listen(params.Net, params.Addr) + require.NoError(t, err) + defer l.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(t, err) + tcpConn, ok := conn.(*net.TCPConn) + require.True(t, ok) + + opts, err := NewTCPOptions(tcpConn) + require.NoError(t, err) + + require.NoError(t, opts.SetHopLimit(1)) + + hoplim, err := opts.HopLimit() + require.NoError(t, err) + require.Equal(t, 1, hoplim) + + require.NoError(t, opts.SetHopLimit(20)) + + hoplim, err = opts.HopLimit() + require.NoError(t, err) + require.Equal(t, 20, hoplim) + } +} From bbf6f92e17e7de7bf1b1a516f83cb7f61d4febcf Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 6 Nov 2024 17:13:29 -0500 Subject: [PATCH 16/21] Make hoplimit option reusable for UDP --- x/sockopt/sockopt.go | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/x/sockopt/sockopt.go b/x/sockopt/sockopt.go index 15d79382..94ca4037 100644 --- a/x/sockopt/sockopt.go +++ b/x/sockopt/sockopt.go @@ -59,27 +59,33 @@ type tcpOptions struct { var _ TCPOptions = (*tcpOptions)(nil) -// NewTCPOptions creates a [TCPOptions] for the given [net.TCPConn]. -func NewTCPOptions(conn *net.TCPConn) (TCPOptions, error) { - addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()) +// newHopLimit creates a hopLimitOption from a [net.Conn]. Works for both TCP or UDP. +func newHopLimit(conn net.Conn) (*hopLimitOption, error) { + addr, err := netip.ParseAddrPort(conn.LocalAddr().String()) if err != nil { - return nil, fmt.Errorf("could not parse remote addr: %w", err) + return nil, err } - - opts := &tcpOptions{} - + opt := &hopLimitOption{} switch { case addr.Addr().Is4(): conn := ipv4.NewConn(conn) - opts.hopLimitOption.hopLimit = conn.TTL - opts.hopLimitOption.setHopLimit = conn.SetTTL + opt.hopLimit = conn.TTL + opt.setHopLimit = conn.SetTTL case addr.Addr().Is6(): conn := ipv6.NewConn(conn) - opts.hopLimitOption.hopLimit = conn.HopLimit - opts.hopLimitOption.setHopLimit = conn.SetHopLimit + opt.hopLimit = conn.HopLimit + opt.setHopLimit = conn.SetHopLimit default: - return nil, fmt.Errorf("unknown remote addr type (%v)", addr.Addr().String()) + return nil, fmt.Errorf("address is not IPv4 or IPv6 (%v)", addr.Addr().String()) } + return opt, nil +} - return opts, nil +// NewTCPOptions creates a [TCPOptions] for the given [net.TCPConn]. +func NewTCPOptions(conn *net.TCPConn) (TCPOptions, error) { + hopLimit, err := newHopLimit(conn) + if err != nil { + return nil, err + } + return &tcpOptions{hopLimitOption: *hopLimit}, nil } From b3bdc9ef7bebbfa47590579b572f64e744004b0f Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 6 Nov 2024 18:34:02 -0500 Subject: [PATCH 17/21] Apply suggestions from code review Co-authored-by: J. Yi <93548144+jyyi1@users.noreply.github.com> --- transport/split/writer.go | 2 +- x/configurl/split.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transport/split/writer.go b/transport/split/writer.go index 21d2bf66..352dc0db 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -34,7 +34,7 @@ type splitWriterReaderFrom struct { var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) -// Split Iterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do. +// SplitIterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do. type SplitIterator func() int64 // NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero. diff --git a/x/configurl/split.go b/x/configurl/split.go index 33737fbe..6c01aec8 100644 --- a/x/configurl/split.go +++ b/x/configurl/split.go @@ -50,7 +50,7 @@ func registerSplitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID st } bytes, err = strconv.ParseInt(subparts[1], 10, 64) if err != nil { - return nil, fmt.Errorf("bytes is not a number: %v", subparts[0]) + return nil, fmt.Errorf("bytes is not a number: %v", subparts[1]) } default: return nil, fmt.Errorf("split format must be a comma-separated list of '[$COUNT*]$BYTES' (e.g. '100,5*2'). Got %v", part) From faffebb12629a216e4f30f9b4558b372fdb7fc27 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 6 Nov 2024 18:37:08 -0500 Subject: [PATCH 18/21] Check for nil --- transport/split/stream_dialer.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index 84d374fa..19aa696c 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -35,6 +35,9 @@ func NewStreamDialer(dialer transport.StreamDialer, nextSplit SplitIterator) (tr if dialer == nil { return nil, errors.New("argument dialer must not be nil") } + if nextSplit == nil { + return nil, errors.New("argument nextSplit must not be nil") + } return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil } From cae6dd4a8559a9ec40c8d7b2d3d996b103724d24 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 6 Nov 2024 18:38:06 -0500 Subject: [PATCH 19/21] Update mod --- x/go.mod | 2 +- x/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x/go.mod b/x/go.mod index ba02180e..b432aec4 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.22 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8 + github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106233708-faffebb12629 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index 46047f1c..a7238e9a 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8 h1:3bdohee2xFJmsXYdAB/p1zZRdrKMQlp1kiyNJgCE1OY= -github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106003301-3354513b85f8/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106233708-faffebb12629 h1:sHi1X4vwtNNBUDCbxynGXe7cM/inwTbavowHziaxlbk= +github.com/Jigsaw-Code/outline-sdk v0.0.18-0.20241106233708-faffebb12629/go.mod h1:CFDKyGZA4zatKE4vMLe8TyQpZCyINOeRFbMAmYHxodw= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag= From b7aaba911d58d8c695eef9fd416ebc2b7a3d799a Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 7 Nov 2024 12:17:35 -0500 Subject: [PATCH 20/21] Rename --- x/sockopt/sockopt.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x/sockopt/sockopt.go b/x/sockopt/sockopt.go index 94ca4037..44cb575e 100644 --- a/x/sockopt/sockopt.go +++ b/x/sockopt/sockopt.go @@ -68,13 +68,13 @@ func newHopLimit(conn net.Conn) (*hopLimitOption, error) { opt := &hopLimitOption{} switch { case addr.Addr().Is4(): - conn := ipv4.NewConn(conn) - opt.hopLimit = conn.TTL - opt.setHopLimit = conn.SetTTL + ipConn := ipv4.NewConn(conn) + opt.hopLimit = ipConn.TTL + opt.setHopLimit = ipConn.SetTTL case addr.Addr().Is6(): - conn := ipv6.NewConn(conn) - opt.hopLimit = conn.HopLimit - opt.setHopLimit = conn.SetHopLimit + ipConn := ipv6.NewConn(conn) + opt.hopLimit = ipConn.HopLimit + opt.setHopLimit = ipConn.SetHopLimit default: return nil, fmt.Errorf("address is not IPv4 or IPv6 (%v)", addr.Addr().String()) } From 851a01a4d78e65a453d461b7b8d7d250ca0bb18b Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 7 Nov 2024 17:26:20 -0500 Subject: [PATCH 21/21] fix: disable revocation test --- transport/tls/stream_dialer_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transport/tls/stream_dialer_test.go b/transport/tls/stream_dialer_test.go index 42781439..3f6980f9 100644 --- a/transport/tls/stream_dialer_test.go +++ b/transport/tls/stream_dialer_test.go @@ -17,7 +17,6 @@ package tls import ( "context" "crypto/x509" - "runtime" "testing" "github.com/Jigsaw-Code/outline-sdk/transport" @@ -55,9 +54,10 @@ func TestExpired(t *testing.T) { } func TestRevoked(t *testing.T) { - if runtime.GOOS == "linux" || runtime.GOOS == "windows" { - t.Skip("Certificate revocation list is not up-to-date in Linux and Windows") - } + t.Skip("Certificate revocation list is not working") + + // TODO(fortuna): implement proper revocation test. + // See https://www.cossacklabs.com/blog/tls-validation-implementing-ocsp-and-crl-in-go/ sd, err := NewStreamDialer(&transport.TCPDialer{}) require.NoError(t, err)