diff --git a/x/configurl/config.go b/x/configurl/config.go index d074b088..2c427f8d 100644 --- a/x/configurl/config.go +++ b/x/configurl/config.go @@ -19,11 +19,9 @@ import ( "errors" "fmt" "net/url" - "strconv" "strings" "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" ) // ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is @@ -144,21 +142,10 @@ func NewDefaultConfigToDialer() *ConfigToDialer { registerTLSStreamDialer(p, "tls", p.NewStreamDialerFromConfig) - p.RegisterStreamDialerType("tlsfrag", func(ctx context.Context, config *Config) (transport.StreamDialer, error) { - sd, err := p.NewStreamDialerFromConfig(ctx, config.BaseConfig) - if err != nil { - return nil, err - } - lenStr := config.URL.Opaque - fixedLen, err := strconv.Atoi(lenStr) - if err != nil { - return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) - } - return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) - }) + registerTLSFragStreamDialer(p, "tlsfrag", p.NewStreamDialerFromConfig) - p.RegisterStreamDialerType("ws", newWebsocketStreamDialerFactory(p.NewStreamDialerFromConfig)) - p.RegisterPacketDialerType("ws", newWebsocketPacketDialerFactory(p.NewStreamDialerFromConfig)) + registerWebsocketStreamDialer(p, "ws", p.NewStreamDialerFromConfig) + registerWebsocketPacketDialer(p, "ws", p.NewStreamDialerFromConfig) return p } diff --git a/x/configurl/tlsfrag.go b/x/configurl/tlsfrag.go new file mode 100644 index 00000000..687fbfc2 --- /dev/null +++ b/x/configurl/tlsfrag.go @@ -0,0 +1,39 @@ +// Copyright 2023 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "strconv" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag" +) + +func registerTLSFragStreamDialer(r StreamDialerRegistry, typeID string, newSD NewStreamDialerFunc) { + r.RegisterStreamDialerType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + lenStr := config.URL.Opaque + fixedLen, err := strconv.Atoi(lenStr) + if err != nil { + return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag: format", lenStr) + } + return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen) + }) +} diff --git a/x/configurl/websocket.go b/x/configurl/websocket.go index 659fba15..99e64130 100644 --- a/x/configurl/websocket.go +++ b/x/configurl/websocket.go @@ -71,8 +71,8 @@ func (c *wsToStreamConn) CloseWrite() error { return c.Close() } -func newWebsocketStreamDialerFactory(newSD NewStreamDialerFunc) NewStreamDialerFunc { - return func(ctx context.Context, config *Config) (transport.StreamDialer, error) { +func registerWebsocketStreamDialer(r StreamDialerRegistry, typeID string, newSD NewStreamDialerFunc) { + r.RegisterStreamDialerType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { sd, err := newSD(ctx, config.BaseConfig) if err != nil { return nil, err @@ -102,11 +102,11 @@ func newWebsocketStreamDialerFactory(newSD NewStreamDialerFunc) NewStreamDialerF } return &wsToStreamConn{wsConn}, nil }), nil - } + }) } -func newWebsocketPacketDialerFactory(newSD NewStreamDialerFunc) NewPacketDialerFunc { - return func(ctx context.Context, config *Config) (transport.PacketDialer, error) { +func registerWebsocketPacketDialer(r PacketDialerRegistry, typeID string, newSD NewStreamDialerFunc) { + r.RegisterPacketDialerType(typeID, func(ctx context.Context, config *Config) (transport.PacketDialer, error) { sd, err := newSD(ctx, config.BaseConfig) if err != nil { return nil, err @@ -136,5 +136,5 @@ func newWebsocketPacketDialerFactory(newSD NewStreamDialerFunc) NewPacketDialerF } return wsConn, nil }), nil - } + }) }