Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try ReaderFrom #131

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions network/lwip2transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import (
"context"
"io"
"net"

"github.com/Jigsaw-Code/outline-sdk/transport"
Expand All @@ -41,7 +40,7 @@
return err
}
// TODO: Request upstream to make `conn` a `core.TCPConn` so we can avoid this type assertion.
go relay(conn.(lwip.TCPConn), proxyConn)

Check failure on line 43 in network/lwip2transport/tcp.go

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest)

cannot use conn.(lwip.TCPConn) (comma, ok expression of type core.TCPConn) as transport.StreamConn value in argument to relay: core.TCPConn does not implement transport.StreamConn (missing method ReadFrom)
return nil
}

Expand All @@ -52,7 +51,7 @@
//
// rightConn's read end and leftConn's write end will be closed after copyOneWay returns.
func copyOneWay(leftConn, rightConn transport.StreamConn) (int64, error) {
n, err := io.Copy(leftConn, rightConn)
n, err := leftConn.ReadFrom(rightConn)
// Send FIN to indicate EOF
leftConn.CloseWrite()
// Release reader resources
Expand Down
42 changes: 42 additions & 0 deletions transport/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2019 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
"bytes"
"io"
"sync"
)

type writerAdaptor struct {
// The underlying io.ReaderFrom.
rf io.ReaderFrom
// Coverts []byte to io.Reader.
r bytes.Reader
// Sequences the writes.
mu sync.Mutex
}

func (w *writerAdaptor) Write(buf []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
w.r.Reset(buf)
n, err := w.rf.ReadFrom(&w.r)
return int(n), err
}

func AsWriter(rf io.ReaderFrom) io.Writer {
return &writerAdaptor{rf: rf}
}
6 changes: 4 additions & 2 deletions transport/shadowsocks/compatibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package shadowsocks

import (
"bytes"
"io"
"net"
"sync"
Expand All @@ -36,11 +37,12 @@ func TestCompatibility(t *testing.T) {
wait.Add(1)
key, err := NewEncryptionKey(cipherName, secret)
require.NoError(t, err, "NewCipher failed: %v", err)
ssWriter := NewWriter(left, key)
// TODO: wrap left in a ReaderFrom adaptor.
ssWriter := NewReaderFrom(bytes.NewReader(left), key)
go func() {
defer wait.Done()
var err error
ssWriter.Write([]byte(fromLeft))
ssWriter.ReadFrom(bytes.NewReader([]byte(fromLeft)))

ssReader := NewReader(left, key)
receivedByLeft := make([]byte, len(fromRight))
Expand Down
20 changes: 5 additions & 15 deletions transport/shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package shadowsocks

import (
"bytes"
"crypto/cipher"
"encoding/binary"
"fmt"
Expand Down Expand Up @@ -44,11 +43,9 @@
mu sync.Mutex
// Indicates that a concurrent flush is currently allowed.
needFlush bool
writer io.Writer
rf io.ReaderFrom
key *EncryptionKey
saltGenerator SaltGenerator
// Wrapper for input that arrives as a slice.
byteWrapper bytes.Reader
// Number of plaintext bytes that are currently buffered.
pending int
// These are populated by init():
Expand All @@ -59,14 +56,13 @@
}

var (
_ io.Writer = (*Writer)(nil)
_ io.ReaderFrom = (*Writer)(nil)
)

// NewWriter creates a [Writer] that encrypts the given [io.Writer] using
// NewReaderFrom creates a [Writer] that encrypts the given [io.Writer] using
// the shadowsocks protocol with the given encryption key.
func NewWriter(writer io.Writer, key *EncryptionKey) *Writer {
return &Writer{writer: writer, key: key, saltGenerator: RandomSaltGenerator}
func NewReaderFrom(rf io.ReaderFrom, key *EncryptionKey) *Writer {
return &Writer{rf: rf, key: key, saltGenerator: RandomSaltGenerator}
}

// SetSaltGenerator sets the salt generator to be used. Must be called before the first write.
Expand Down Expand Up @@ -107,12 +103,6 @@
return len(out)
}

func (sw *Writer) Write(p []byte) (int, error) {
sw.byteWrapper.Reset(p)
n, err := sw.ReadFrom(&sw.byteWrapper)
return int(n), err
}

// LazyWrite queues p to be written, but doesn't send it until Flush() is
// called, a non-lazy write is made, or the buffer is filled.
func (sw *Writer) LazyWrite(p []byte) (int, error) {
Expand Down Expand Up @@ -252,7 +242,7 @@
binary.BigEndian.PutUint16(sizeBuf, uint16(sw.pending))
sizeBlockSize := sw.encryptBlock(sizeBuf)
payloadSize := sw.encryptBlock(payloadBuf[:sw.pending])
_, err := sw.writer.Write(sw.buf[start : saltSize+sizeBlockSize+payloadSize])
_, err := sw.rf.Write(sw.buf[start : saltSize+sizeBlockSize+payloadSize])

Check failure on line 245 in transport/shadowsocks/stream.go

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest)

sw.rf.Write undefined (type io.ReaderFrom has no field or method Write)
sw.pending = 0
return err
}
Expand Down
2 changes: 1 addition & 1 deletion transport/shadowsocks/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (c *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.S
if err != nil {
return nil, err
}
ssw := NewWriter(proxyConn, c.key)
ssw := NewReaderFrom(proxyConn, c.key)
if c.SaltGenerator != nil {
ssw.SetSaltGenerator(c.SaltGenerator)
}
Expand Down
8 changes: 4 additions & 4 deletions transport/shadowsocks/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestStreamDialer_TCPPrefix(t *testing.T) {
if err != nil {
t.Fatalf("StreamDialer.Dial failed: %v", err)
}
conn.Write(nil)
conn.ReadFrom(nil)
conn.Close()
running.Wait()
}
Expand Down Expand Up @@ -212,8 +212,8 @@ func startShadowsocksTCPEchoProxy(key *EncryptionKey, expectedTgtAddr string, t
defer running.Done()
defer clientConn.Close()
ssr := NewReader(clientConn, key)
ssw := NewWriter(clientConn, key)
ssClientConn := transport.WrapConn(clientConn, ssr, ssw)
ssrf := NewReaderFrom(clientConn, key)
ssClientConn := transport.WrapConn(clientConn, ssr, ssrf)

tgtAddr, err := socks.ReadAddr(ssClientConn)
if err != nil {
Expand All @@ -222,7 +222,7 @@ func startShadowsocksTCPEchoProxy(key *EncryptionKey, expectedTgtAddr string, t
if tgtAddr.String() != expectedTgtAddr {
t.Fatalf("Expected target address '%v'. Got '%v'", expectedTgtAddr, tgtAddr)
}
io.Copy(ssw, ssr)
ssrf.ReadFrom(ssr)
}()
}
}()
Expand Down
37 changes: 19 additions & 18 deletions transport/shadowsocks/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ func TestEndToEnd(t *testing.T) {
require.NoError(t, err)

connReader, connWriter := io.Pipe()
writer := NewWriter(connWriter, key)
// TODO: ReaderFrom to Writer adapter
readerFrom := NewReaderFrom(connWriter, key)
reader := NewReader(connReader, key)
expected := "Test"
wg := sync.WaitGroup{}
Expand All @@ -175,7 +176,7 @@ func TestEndToEnd(t *testing.T) {
defer connWriter.Close()
wg.Add(1)
defer wg.Done()
_, writeErr = writer.Write([]byte(expected))
_, writeErr = readerFrom.ReadFrom(bytes.NewReader([]byte(expected)))
}()
var output bytes.Buffer
_, readErr := reader.WriteTo(&output)
Expand All @@ -195,24 +196,24 @@ func TestLazyWriteFlush(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
readerFrom := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
n, err := readerFrom.LazyWrite(header)
require.NoError(t, err, "LazyWrite failed: %v", err)
require.Equal(t, len(header), n, "Wrong write size")
require.Equal(t, 0, buf.Len(), "LazyWrite isn't lazy")

err = writer.Flush()
err = readerFrom.Flush()
require.NoError(t, err, "Flush failed: %v", err)

len1 := buf.Len()
require.Greater(t, len1, len(header), "Not enough bytes flushed")

// Check that normal writes now work
body := []byte{5, 6, 7}
n, err = writer.Write(body)
n64, err := readerFrom.ReadFrom(bytes.NewReader(body))
require.NoError(t, err, "Write failed: %v", err)
require.Equal(t, len(body), n, "Wrong write size")
require.Equal(t, int64(len(body)), n64, "Wrong write size")
require.Greater(t, buf.Len(), len1, "No write observed")

// Verify content arrives in two blocks
Expand All @@ -233,9 +234,9 @@ func TestLazyWriteConcat(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
readerFrom := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
n, err := readerFrom.LazyWrite(header)
if n != len(header) {
t.Errorf("Wrong write size: %d", n)
}
Expand All @@ -248,8 +249,8 @@ func TestLazyWriteConcat(t *testing.T) {

// Write additional data and flush the header.
body := []byte{5, 6, 7}
n, err = writer.Write(body)
if n != len(body) {
n64, err := readerFrom.ReadFrom(bytes.NewReader(body))
if int(n64) != len(body) {
t.Errorf("Wrong write size: %d", n)
}
if err != nil {
Expand All @@ -261,7 +262,7 @@ func TestLazyWriteConcat(t *testing.T) {
}

// Flush after write should have no effect
if err = writer.Flush(); err != nil {
if err = readerFrom.Flush(); err != nil {
t.Errorf("Flush failed: %v", err)
}
if buf.Len() != len1 {
Expand All @@ -288,7 +289,7 @@ func TestLazyWriteOversize(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
writer := NewReaderFrom(buf, key)
N := 25000 // More than one block, less than two.
data := make([]byte, N)
for i := range data {
Expand Down Expand Up @@ -329,7 +330,7 @@ func TestLazyWriteConcurrentFlush(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
writer := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
require.NoError(t, err, "LazyWrite failed: %v", err)
Expand Down Expand Up @@ -382,8 +383,8 @@ func TestLazyWriteConcurrentFlush(t *testing.T) {

type nullIO struct{}

func (n *nullIO) Write(b []byte) (int, error) {
return len(b), nil
func (n *nullIO) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(io.Discard, r)
}

func (r *nullIO) Read(b []byte) (int, error) {
Expand All @@ -397,11 +398,11 @@ func BenchmarkWriter(b *testing.B) {

key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(b, err)
writer := NewWriter(new(nullIO), key)
readerFrom := NewReaderFrom(new(nullIO), key)

start := time.Now()
b.StartTimer()
io.CopyN(writer, new(nullIO), int64(b.N))
readerFrom.ReadFrom(&io.LimitedReader{R: new(nullIO), N: int64(b.N)})
b.StopTimer()
elapsed := time.Since(start)

Expand Down
30 changes: 15 additions & 15 deletions transport/socks5/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package socks5
import (
"errors"
"fmt"
"io"
"net"
"strconv"
)
Expand Down Expand Up @@ -72,16 +73,16 @@ const (
addrTypeIPv6 = 0x04
)

// appendSOCKS5Address adds the address to buffer b in SOCKS5 format,
// writeSOCKS5Address adds the address to buffer b in SOCKS5 format,
// as specified in https://datatracker.ietf.org/doc/html/rfc1928#section-4
func appendSOCKS5Address(b []byte, address string) ([]byte, error) {
func writeSOCKS5Address(w io.Writer, address string) error {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return nil, err
return err
}
portNum, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
return err
}
// The SOCKS address format is as follows:
// +------+----------+----------+
Expand All @@ -92,23 +93,22 @@ func appendSOCKS5Address(b []byte, address string) ([]byte, error) {
// See https://datatracker.ietf.org/doc/html/rfc1928#section-5 for DST.ADDR details.
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, addrTypeIPv4)
b = append(b, ip4...)
w.Write([]byte{addrTypeIPv4})
w.Write(ip4)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, addrTypeIPv6)
b = append(b, ip6...)
w.Write([]byte{addrTypeIPv6})
w.Write(ip6)
} else {
// This should never happen.
return nil, errors.New("IP address not IPv4 or IPv6")
return errors.New("IP address not IPv4 or IPv6")
}
} else {
if len(host) > 255 {
return nil, fmt.Errorf("domain name length = %v is over 255", len(host))
return fmt.Errorf("domain name length = %v is over 255", len(host))
}
b = append(b, addrTypeDomainName)
b = append(b, byte(len(host)))
b = append(b, host...)
w.Write([]byte{addrTypeDomainName, byte(len(host))})
w.Write([]byte(host))
}
b = append(b, byte(portNum>>8), byte(portNum))
return b, nil
w.Write([]byte{byte(portNum >> 8), byte(portNum)})
return nil
}
Loading
Loading