diff --git a/x/configurl/module.go b/x/configurl/module.go index 83e14b89..3e11f10c 100644 --- a/x/configurl/module.go +++ b/x/configurl/module.go @@ -65,6 +65,8 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance) registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance) + registerWaitStreamDialer(&c.StreamDialers, "waitstream", c.StreamDialers.NewInstance) + return c } diff --git a/x/configurl/wait_stream.go b/x/configurl/wait_stream.go new file mode 100644 index 00000000..f6e1f4bf --- /dev/null +++ b/x/configurl/wait_stream.go @@ -0,0 +1,62 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/wait_stream" +) + +func registerWaitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + + queryUrlParameters, err := url.ParseQuery(config.URL.Opaque) + if err != nil { + return nil, fmt.Errorf("waitstream: failed to parse URL parameters: %w", err) + } + + resultStreamDialer, err := wait_stream.NewStreamDialer(sd) + if err != nil { + return nil, err + } + + if queryUrlParameters.Has("timeout") { + timeout, err := time.ParseDuration(queryUrlParameters.Get("timeout")) + if err != nil { + return nil, fmt.Errorf("waitstream: failed to parse timeout parameter: %w", err) + } + resultStreamDialer.SetWaitingTimeout(timeout) + } + + if queryUrlParameters.Has("delay") { + delay, err := time.ParseDuration(queryUrlParameters.Get("delay")) + if err != nil { + return nil, fmt.Errorf("waitstream: failed to parse delay parameter: %w", err) + } + resultStreamDialer.SetWaitingDelay(delay) + } + + return resultStreamDialer, err + }) +} diff --git a/x/wait_stream/is_sending_bytes_linux.go b/x/wait_stream/is_sending_bytes_linux.go new file mode 100644 index 00000000..751a7045 --- /dev/null +++ b/x/wait_stream/is_sending_bytes_linux.go @@ -0,0 +1,15 @@ +//go:build linux + +package wait_stream + +import ( + "golang.org/x/sys/unix" +) + +func isSocketFdSendingBytes(fd int) (bool, error) { + tcpInfo, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) + if err != nil { + return false, err + } + return tcpInfo.Notsent_bytes != 0, nil +} diff --git a/x/wait_stream/is_sending_bytes_not_implemented.go b/x/wait_stream/is_sending_bytes_not_implemented.go new file mode 100644 index 00000000..2c89203a --- /dev/null +++ b/x/wait_stream/is_sending_bytes_not_implemented.go @@ -0,0 +1,12 @@ +//go:build !linux + +package wait_stream + +import ( + "errors" + "fmt" +) + +func isSocketFdSendingBytes(_ int) (bool, error) { + return false, fmt.Errorf("%w: checking if socket is sending bytes is not implemented on this platform", errors.ErrUnsupported) +} diff --git a/x/wait_stream/stream_dialer.go b/x/wait_stream/stream_dialer.go new file mode 100644 index 00000000..4a1f8bc1 --- /dev/null +++ b/x/wait_stream/stream_dialer.go @@ -0,0 +1,76 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wait_stream + +import ( + "context" + "errors" + "net" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type WaitStreamDialer struct { + dialer transport.StreamDialer + + // Stop waiting on a packet after this timeout + waitingTimeout time.Duration + // Check if socket is sending bytes that often + waitingDelay time.Duration +} + +var _ transport.StreamDialer = (*WaitStreamDialer)(nil) + +// byeDPI uses a default delay of 500ms with 1ms sleep +// We might reconsider the defaults later, if needed. +// https://github.com/hufrea/byedpi/blob/main/desync.c#L90 +var defaultTimeout = time.Millisecond * 10 +var defaultDelay = time.Microsecond * 1 + +func NewStreamDialer(dialer transport.StreamDialer) (*WaitStreamDialer, error) { + if dialer == nil { + return nil, errors.New("argument dialer must not be nil") + } + return &WaitStreamDialer{ + dialer: dialer, + waitingTimeout: defaultTimeout, + waitingDelay: defaultDelay, + }, nil +} + +func (d *WaitStreamDialer) SetWaitingTimeout(timeout time.Duration) { + d.waitingTimeout = timeout +} + +func (d *WaitStreamDialer) SetWaitingDelay(timeout time.Duration) { + d.waitingDelay = timeout +} + +func (d *WaitStreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.DialStream(ctx, remoteAddr) + if err != nil { + return nil, err + } + + tcpInnerConn, ok := innerConn.(*net.TCPConn) + if !ok { + return nil, errors.New("wait_stream strategy: expected base dialer to return TCPConn") + } + + dw := NewWriter(tcpInnerConn, d.waitingTimeout, d.waitingDelay) + + return transport.WrapConn(innerConn, innerConn, dw), nil +} diff --git a/x/wait_stream/writer.go b/x/wait_stream/writer.go new file mode 100644 index 00000000..ab59e361 --- /dev/null +++ b/x/wait_stream/writer.go @@ -0,0 +1,79 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wait_stream + +import ( + "errors" + "fmt" + "io" + "net" + "time" +) + +type waitStreamWriter struct { + conn *net.TCPConn + + waitingTimeout time.Duration + waitingDelay time.Duration +} + +var _ io.Writer = (*waitStreamWriter)(nil) + +func NewWriter(conn *net.TCPConn, waitingTimeout time.Duration, waitingDelay time.Duration) io.Writer { + return &waitStreamWriter{ + conn: conn, + waitingTimeout: waitingTimeout, + waitingDelay: waitingDelay, + } +} + +func isConnectionSendingBytes(conn *net.TCPConn) (result bool, err error) { + syscallConn, err := conn.SyscallConn() + if err != nil { + return false, err + } + syscallConn.Control(func(fd uintptr) { + result, err = isSocketFdSendingBytes(int(fd)) + }) + return +} + +func waitUntilBytesAreSent(conn *net.TCPConn, waitingTimeout time.Duration, waitingDelay time.Duration) error { + startTime := time.Now() + for time.Since(startTime) < waitingTimeout { + isSendingBytes, err := isConnectionSendingBytes(conn) + if err != nil { + return err + } + if !isSendingBytes { + return nil + } + + time.Sleep(waitingDelay) + } + // not sure about the right behaviour here: fail or give up waiting? + // giving up feels safer, and matches byeDPI behavior + return nil +} + +func (w *waitStreamWriter) Write(data []byte) (written int, err error) { + // This may not be implemented, so it's best effort really. + waitUntilBytesAreSentErr := waitUntilBytesAreSent(w.conn, w.waitingTimeout, w.waitingDelay) + if waitUntilBytesAreSentErr != nil && !errors.Is(waitUntilBytesAreSentErr, errors.ErrUnsupported) { + return 0, fmt.Errorf("error when waiting for stream to send all bytes: %w", waitUntilBytesAreSentErr) + } + + return w.conn.Write(data) +}