From 64afbfd7e859cd7cd0bb6c8dd30baa917bc71d5a Mon Sep 17 00:00:00 2001 From: Petr Zhizhin Date: Tue, 12 Nov 2024 22:30:00 +0100 Subject: [PATCH] Add waitstream strategy --- x/configurl/module.go | 2 ++ x/configurl/wait_stream.go | 32 +++++++++++++++++++ x/wait_stream/stream_dialer.go | 58 ++++++++++++++++++++++++++++++++++ x/wait_stream/writer.go | 49 ++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+) create mode 100644 x/configurl/wait_stream.go create mode 100644 x/wait_stream/stream_dialer.go create mode 100644 x/wait_stream/writer.go diff --git a/x/configurl/module.go b/x/configurl/module.go index 83e14b89..3e11f10c 100644 --- a/x/configurl/module.go +++ b/x/configurl/module.go @@ -65,6 +65,8 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer { registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance) registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance) + registerWaitStreamDialer(&c.StreamDialers, "waitstream", c.StreamDialers.NewInstance) + return c } diff --git a/x/configurl/wait_stream.go b/x/configurl/wait_stream.go new file mode 100644 index 00000000..337f6b05 --- /dev/null +++ b/x/configurl/wait_stream.go @@ -0,0 +1,32 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package configurl + +import ( + "context" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/wait_stream" +) + +func registerWaitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { + r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { + sd, err := newSD(ctx, config.BaseConfig) + if err != nil { + return nil, err + } + return wait_stream.NewStreamDialer(sd) + }) +} diff --git a/x/wait_stream/stream_dialer.go b/x/wait_stream/stream_dialer.go new file mode 100644 index 00000000..e2ffff98 --- /dev/null +++ b/x/wait_stream/stream_dialer.go @@ -0,0 +1,58 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wait_stream + +import ( + "context" + "errors" + "net" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/sockopt" +) + +type waitStreamDialer struct { + dialer transport.StreamDialer +} + +var _ transport.StreamDialer = (*waitStreamDialer)(nil) + +func NewStreamDialer(dialer transport.StreamDialer) (transport.StreamDialer, error) { + if dialer == nil { + return nil, errors.New("argument dialer must not be nil") + } + return &waitStreamDialer{dialer: dialer}, nil +} + +func (d *waitStreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.DialStream(ctx, remoteAddr) + if err != nil { + return nil, err + } + + tcpInnerConn, ok := innerConn.(*net.TCPConn) + if !ok { + return nil, errors.New("wait_stream strategy: expected base dialer to return TCPConn") + } + + tcpOptions, err := sockopt.NewTCPOptions(tcpInnerConn) + if err != nil { + return nil, err + } + + dw := NewWriter(innerConn, tcpOptions) + + return transport.WrapConn(innerConn, innerConn, dw), nil +} diff --git a/x/wait_stream/writer.go b/x/wait_stream/writer.go new file mode 100644 index 00000000..103a3a64 --- /dev/null +++ b/x/wait_stream/writer.go @@ -0,0 +1,49 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wait_stream + +import ( + "errors" + "fmt" + "io" + + "github.com/Jigsaw-Code/outline-sdk/x/sockopt" +) + +type waitStreamWriter struct { + conn io.Writer + tcpOptions sockopt.TCPOptions +} + +var _ io.Writer = (*waitStreamWriter)(nil) + +func NewWriter(conn io.Writer, tcpOptions sockopt.TCPOptions) io.Writer { + return &waitStreamWriter{ + conn: conn, + tcpOptions: tcpOptions, + } +} + +func (w *waitStreamWriter) Write(data []byte) (written int, err error) { + written, err = w.conn.Write(data) + + // This may not be implemented, so it's best effort really. + waitUntilBytesAreSentErr := w.tcpOptions.WaitUntilBytesAreSent() + if waitUntilBytesAreSentErr != nil && !errors.Is(waitUntilBytesAreSentErr, errors.ErrUnsupported) { + return written, fmt.Errorf("error when waiting for stream to send all bytes: %w", waitUntilBytesAreSentErr) + } + + return +}