Skip to content

Commit

Permalink
Try ReaderFrom
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 2, 2023
1 parent c103980 commit 8257d95
Show file tree
Hide file tree
Showing 20 changed files with 444 additions and 277 deletions.
3 changes: 1 addition & 2 deletions network/lwip2transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package lwip2transport

import (
"context"
"io"
"net"

"github.com/Jigsaw-Code/outline-sdk/transport"
Expand Down Expand Up @@ -52,7 +51,7 @@ func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error {
//
// rightConn's read end and leftConn's write end will be closed after copyOneWay returns.
func copyOneWay(leftConn, rightConn transport.StreamConn) (int64, error) {
n, err := io.Copy(leftConn, rightConn)
n, err := leftConn.ReadFrom(rightConn)
// Send FIN to indicate EOF
leftConn.CloseWrite()
// Release reader resources
Expand Down
42 changes: 42 additions & 0 deletions transport/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2019 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
"bytes"
"io"
"sync"
)

type writerAdaptor struct {
// The underlying io.ReaderFrom.
rf io.ReaderFrom
// Coverts []byte to io.Reader.
r bytes.Reader
// Sequences the writes.
mu sync.Mutex
}

func (w *writerAdaptor) Write(buf []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
w.r.Reset(buf)
n, err := w.rf.ReadFrom(&w.r)
return int(n), err
}

func AsWriter(rf io.ReaderFrom) io.Writer {
return &writerAdaptor{rf: rf}
}
6 changes: 4 additions & 2 deletions transport/shadowsocks/compatibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package shadowsocks

import (
"bytes"
"io"
"net"
"sync"
Expand All @@ -36,11 +37,12 @@ func TestCompatibility(t *testing.T) {
wait.Add(1)
key, err := NewEncryptionKey(cipherName, secret)
require.NoError(t, err, "NewCipher failed: %v", err)
ssWriter := NewWriter(left, key)
// TODO: wrap left in a ReaderFrom adaptor.
ssWriter := NewReaderFrom(bytes.NewReader(left), key)
go func() {
defer wait.Done()
var err error
ssWriter.Write([]byte(fromLeft))
ssWriter.ReadFrom(bytes.NewReader([]byte(fromLeft)))

ssReader := NewReader(left, key)
receivedByLeft := make([]byte, len(fromRight))
Expand Down
20 changes: 5 additions & 15 deletions transport/shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package shadowsocks

import (
"bytes"
"crypto/cipher"
"encoding/binary"
"fmt"
Expand Down Expand Up @@ -44,11 +43,9 @@ type Writer struct {
mu sync.Mutex
// Indicates that a concurrent flush is currently allowed.
needFlush bool
writer io.Writer
rf io.ReaderFrom
key *EncryptionKey
saltGenerator SaltGenerator
// Wrapper for input that arrives as a slice.
byteWrapper bytes.Reader
// Number of plaintext bytes that are currently buffered.
pending int
// These are populated by init():
Expand All @@ -59,14 +56,13 @@ type Writer struct {
}

var (
_ io.Writer = (*Writer)(nil)
_ io.ReaderFrom = (*Writer)(nil)
)

// NewWriter creates a [Writer] that encrypts the given [io.Writer] using
// NewReaderFrom creates a [Writer] that encrypts the given [io.Writer] using
// the shadowsocks protocol with the given encryption key.
func NewWriter(writer io.Writer, key *EncryptionKey) *Writer {
return &Writer{writer: writer, key: key, saltGenerator: RandomSaltGenerator}
func NewReaderFrom(rf io.ReaderFrom, key *EncryptionKey) *Writer {
return &Writer{rf: rf, key: key, saltGenerator: RandomSaltGenerator}
}

// SetSaltGenerator sets the salt generator to be used. Must be called before the first write.
Expand Down Expand Up @@ -107,12 +103,6 @@ func (sw *Writer) encryptBlock(plaintext []byte) int {
return len(out)
}

func (sw *Writer) Write(p []byte) (int, error) {
sw.byteWrapper.Reset(p)
n, err := sw.ReadFrom(&sw.byteWrapper)
return int(n), err
}

// LazyWrite queues p to be written, but doesn't send it until Flush() is
// called, a non-lazy write is made, or the buffer is filled.
func (sw *Writer) LazyWrite(p []byte) (int, error) {
Expand Down Expand Up @@ -252,7 +242,7 @@ func (sw *Writer) flush() error {
binary.BigEndian.PutUint16(sizeBuf, uint16(sw.pending))
sizeBlockSize := sw.encryptBlock(sizeBuf)
payloadSize := sw.encryptBlock(payloadBuf[:sw.pending])
_, err := sw.writer.Write(sw.buf[start : saltSize+sizeBlockSize+payloadSize])
_, err := sw.rf.Write(sw.buf[start : saltSize+sizeBlockSize+payloadSize])

Check failure on line 245 in transport/shadowsocks/stream.go

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest)

sw.rf.Write undefined (type io.ReaderFrom has no field or method Write)
sw.pending = 0
return err
}
Expand Down
2 changes: 1 addition & 1 deletion transport/shadowsocks/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (c *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.S
if err != nil {
return nil, err
}
ssw := NewWriter(proxyConn, c.key)
ssw := NewReaderFrom(proxyConn, c.key)
if c.SaltGenerator != nil {
ssw.SetSaltGenerator(c.SaltGenerator)
}
Expand Down
8 changes: 4 additions & 4 deletions transport/shadowsocks/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestStreamDialer_TCPPrefix(t *testing.T) {
if err != nil {
t.Fatalf("StreamDialer.Dial failed: %v", err)
}
conn.Write(nil)
conn.ReadFrom(nil)
conn.Close()
running.Wait()
}
Expand Down Expand Up @@ -212,8 +212,8 @@ func startShadowsocksTCPEchoProxy(key *EncryptionKey, expectedTgtAddr string, t
defer running.Done()
defer clientConn.Close()
ssr := NewReader(clientConn, key)
ssw := NewWriter(clientConn, key)
ssClientConn := transport.WrapConn(clientConn, ssr, ssw)
ssrf := NewReaderFrom(clientConn, key)
ssClientConn := transport.WrapConn(clientConn, ssr, ssrf)

tgtAddr, err := socks.ReadAddr(ssClientConn)
if err != nil {
Expand All @@ -222,7 +222,7 @@ func startShadowsocksTCPEchoProxy(key *EncryptionKey, expectedTgtAddr string, t
if tgtAddr.String() != expectedTgtAddr {
t.Fatalf("Expected target address '%v'. Got '%v'", expectedTgtAddr, tgtAddr)
}
io.Copy(ssw, ssr)
ssrf.ReadFrom(ssr)
}()
}
}()
Expand Down
37 changes: 19 additions & 18 deletions transport/shadowsocks/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ func TestEndToEnd(t *testing.T) {
require.NoError(t, err)

connReader, connWriter := io.Pipe()
writer := NewWriter(connWriter, key)
// TODO: ReaderFrom to Writer adapter
readerFrom := NewReaderFrom(connWriter, key)
reader := NewReader(connReader, key)
expected := "Test"
wg := sync.WaitGroup{}
Expand All @@ -175,7 +176,7 @@ func TestEndToEnd(t *testing.T) {
defer connWriter.Close()
wg.Add(1)
defer wg.Done()
_, writeErr = writer.Write([]byte(expected))
_, writeErr = readerFrom.ReadFrom(bytes.NewReader([]byte(expected)))
}()
var output bytes.Buffer
_, readErr := reader.WriteTo(&output)
Expand All @@ -195,24 +196,24 @@ func TestLazyWriteFlush(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
readerFrom := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
n, err := readerFrom.LazyWrite(header)
require.NoError(t, err, "LazyWrite failed: %v", err)
require.Equal(t, len(header), n, "Wrong write size")
require.Equal(t, 0, buf.Len(), "LazyWrite isn't lazy")

err = writer.Flush()
err = readerFrom.Flush()
require.NoError(t, err, "Flush failed: %v", err)

len1 := buf.Len()
require.Greater(t, len1, len(header), "Not enough bytes flushed")

// Check that normal writes now work
body := []byte{5, 6, 7}
n, err = writer.Write(body)
n64, err := readerFrom.ReadFrom(bytes.NewReader(body))
require.NoError(t, err, "Write failed: %v", err)
require.Equal(t, len(body), n, "Wrong write size")
require.Equal(t, int64(len(body)), n64, "Wrong write size")
require.Greater(t, buf.Len(), len1, "No write observed")

// Verify content arrives in two blocks
Expand All @@ -233,9 +234,9 @@ func TestLazyWriteConcat(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
readerFrom := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
n, err := readerFrom.LazyWrite(header)
if n != len(header) {
t.Errorf("Wrong write size: %d", n)
}
Expand All @@ -248,8 +249,8 @@ func TestLazyWriteConcat(t *testing.T) {

// Write additional data and flush the header.
body := []byte{5, 6, 7}
n, err = writer.Write(body)
if n != len(body) {
n64, err := readerFrom.ReadFrom(bytes.NewReader(body))
if int(n64) != len(body) {
t.Errorf("Wrong write size: %d", n)
}
if err != nil {
Expand All @@ -261,7 +262,7 @@ func TestLazyWriteConcat(t *testing.T) {
}

// Flush after write should have no effect
if err = writer.Flush(); err != nil {
if err = readerFrom.Flush(); err != nil {
t.Errorf("Flush failed: %v", err)
}
if buf.Len() != len1 {
Expand All @@ -288,7 +289,7 @@ func TestLazyWriteOversize(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
writer := NewReaderFrom(buf, key)
N := 25000 // More than one block, less than two.
data := make([]byte, N)
for i := range data {
Expand Down Expand Up @@ -329,7 +330,7 @@ func TestLazyWriteConcurrentFlush(t *testing.T) {
key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(t, err)
buf := new(bytes.Buffer)
writer := NewWriter(buf, key)
writer := NewReaderFrom(buf, key)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
require.NoError(t, err, "LazyWrite failed: %v", err)
Expand Down Expand Up @@ -382,8 +383,8 @@ func TestLazyWriteConcurrentFlush(t *testing.T) {

type nullIO struct{}

func (n *nullIO) Write(b []byte) (int, error) {
return len(b), nil
func (n *nullIO) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(io.Discard, r)
}

func (r *nullIO) Read(b []byte) (int, error) {
Expand All @@ -397,11 +398,11 @@ func BenchmarkWriter(b *testing.B) {

key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret")
require.NoError(b, err)
writer := NewWriter(new(nullIO), key)
readerFrom := NewReaderFrom(new(nullIO), key)

start := time.Now()
b.StartTimer()
io.CopyN(writer, new(nullIO), int64(b.N))
readerFrom.ReadFrom(&io.LimitedReader{R: new(nullIO), N: int64(b.N)})
b.StopTimer()
elapsed := time.Since(start)

Expand Down
30 changes: 15 additions & 15 deletions transport/socks5/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package socks5
import (
"errors"
"fmt"
"io"
"net"
"strconv"
)
Expand Down Expand Up @@ -72,16 +73,16 @@ const (
addrTypeIPv6 = 0x04
)

// appendSOCKS5Address adds the address to buffer b in SOCKS5 format,
// writeSOCKS5Address adds the address to buffer b in SOCKS5 format,
// as specified in https://datatracker.ietf.org/doc/html/rfc1928#section-4
func appendSOCKS5Address(b []byte, address string) ([]byte, error) {
func writeSOCKS5Address(w io.Writer, address string) error {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return nil, err
return err
}
portNum, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
return err
}
// The SOCKS address format is as follows:
// +------+----------+----------+
Expand All @@ -92,23 +93,22 @@ func appendSOCKS5Address(b []byte, address string) ([]byte, error) {
// See https://datatracker.ietf.org/doc/html/rfc1928#section-5 for DST.ADDR details.
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, addrTypeIPv4)
b = append(b, ip4...)
w.Write([]byte{addrTypeIPv4})
w.Write(ip4)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, addrTypeIPv6)
b = append(b, ip6...)
w.Write([]byte{addrTypeIPv6})
w.Write(ip6)
} else {
// This should never happen.
return nil, errors.New("IP address not IPv4 or IPv6")
return errors.New("IP address not IPv4 or IPv6")
}
} else {
if len(host) > 255 {
return nil, fmt.Errorf("domain name length = %v is over 255", len(host))
return fmt.Errorf("domain name length = %v is over 255", len(host))
}
b = append(b, addrTypeDomainName)
b = append(b, byte(len(host)))
b = append(b, host...)
w.Write([]byte{addrTypeDomainName, byte(len(host))})
w.Write([]byte(host))
}
b = append(b, byte(portNum>>8), byte(portNum))
return b, nil
w.Write([]byte{byte(portNum >> 8), byte(portNum)})
return nil
}
Loading

0 comments on commit 8257d95

Please sign in to comment.