Skip to content

Commit

Permalink
Add waitstream strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterZhizhin committed Nov 12, 2024
1 parent e3d398e commit 3f82afb
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 0 deletions.
2 changes: 2 additions & 0 deletions x/configurl/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer {
registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance)
registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance)

registerWaitStreamDialer(&c.StreamDialers, "waitstream", c.StreamDialers.NewInstance)

return c
}

Expand Down
32 changes: 32 additions & 0 deletions x/configurl/wait_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package configurl

import (
"context"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/x/wait_stream"
)

func registerWaitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) {
r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) {
sd, err := newSD(ctx, config.BaseConfig)
if err != nil {
return nil, err
}
return wait_stream.NewStreamDialer(sd)
})
}
58 changes: 58 additions & 0 deletions x/wait_stream/stream_dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package wait_stream

import (
"context"
"errors"
"net"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/x/sockopt"
)

type waitStreamDialer struct {
dialer transport.StreamDialer
}

var _ transport.StreamDialer = (*waitStreamDialer)(nil)

func NewStreamDialer(dialer transport.StreamDialer) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &waitStreamDialer{dialer: dialer}, nil
}

func (d *waitStreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
innerConn, err := d.dialer.DialStream(ctx, remoteAddr)
if err != nil {
return nil, err
}

tcpInnerConn, ok := innerConn.(*net.TCPConn)
if !ok {
return nil, errors.New("wait_stream strategy: expected base dialer to return TCPConn")
}

tcpOptions, err := sockopt.NewTCPOptions(tcpInnerConn)
if err != nil {
return nil, err
}

dw := NewWriter(innerConn, tcpOptions)

return transport.WrapConn(innerConn, innerConn, dw), nil
}
49 changes: 49 additions & 0 deletions x/wait_stream/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package wait_stream

import (
"errors"
"fmt"
"io"

"github.com/Jigsaw-Code/outline-sdk/x/sockopt"
)

type waitStreamWriter struct {
conn io.Writer
tcpOptions sockopt.TCPOptions
}

var _ io.Writer = (*waitStreamWriter)(nil)

func NewWriter(conn io.Writer, tcpOptions sockopt.TCPOptions) io.Writer {
return &waitStreamWriter{
conn: conn,
tcpOptions: tcpOptions,
}
}

func (w *waitStreamWriter) Write(data []byte) (written int, err error) {
written, err = w.conn.Write(data)

// This may not be implemented, so it's best effort really.
waitUntilBytesAreSentErr := w.tcpOptions.WaitUntilBytesAreSent()
if waitUntilBytesAreSentErr != nil && !errors.Is(waitUntilBytesAreSentErr, errors.ErrUnsupported) {
return written, fmt.Errorf("error when waiting for stream to send all bytes: %w", waitUntilBytesAreSentErr)
}

return
}

0 comments on commit 3f82afb

Please sign in to comment.