Skip to content

Commit

Permalink
Implementation read waiter for pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 6, 2023
1 parent c5c692f commit 67b9bf3
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 1 deletion.
3 changes: 3 additions & 0 deletions common/pipe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# pipe

mod from go1.21.4
243 changes: 243 additions & 0 deletions common/pipe/pipe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package pipe

import (
"io"
"net"
"os"
"sync"
"time"

"github.com/sagernet/sing/common/buf"
)

// pipeDeadline is an abstraction for handling timeouts.
type pipeDeadline struct {
mu sync.Mutex // Guards timer and cancel
timer *time.Timer
cancel chan struct{} // Must be non-nil
}

func makePipeDeadline() pipeDeadline {
return pipeDeadline{cancel: make(chan struct{})}
}

// set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by waiter.
// Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future.
//
// A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()

if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil

// Time is zero, then there is no deadline.
closed := isClosedChan(d.cancel)
if t.IsZero() {
if closed {
d.cancel = make(chan struct{})
}
return
}

// Time in the future, setup a timer to cancel in the future.
if dur := time.Until(t); dur > 0 {
if closed {
d.cancel = make(chan struct{})
}
d.timer = time.AfterFunc(dur, func() {
close(d.cancel)
})
return
}

// Time in the past, so close immediately.
if !closed {
close(d.cancel)
}
}

// wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
return d.cancel
}

func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
return true
default:
return false
}
}

type pipeAddr struct{}

func (pipeAddr) Network() string { return "pipe" }
func (pipeAddr) String() string { return "pipe" }

type pipe struct {
wrMu sync.Mutex // Serialize Write operations

// Used by local Read to interact with remote Write.
// Successful receive on rdRx is always followed by send on rdTx.
rdRx <-chan []byte
rdTx chan<- int

// Used by local Write to interact with remote Read.
// Successful send on wrTx is always followed by receive on wrRx.
wrTx chan<- []byte
wrRx <-chan int

once sync.Once // Protects closing localDone
localDone chan struct{}
remoteDone <-chan struct{}

readDeadline pipeDeadline
writeDeadline pipeDeadline

newBuffer func() *buf.Buffer
}

// Pipe creates a synchronous, in-memory, full duplex
// network connection; both ends implement the Conn interface.
// Reads on one end are matched with writes on the other,
// copying data directly between the two; there is no internal
// buffering.
func Pipe() (net.Conn, net.Conn) {
cb1 := make(chan []byte)
cb2 := make(chan []byte)
cn1 := make(chan int)
cn2 := make(chan int)
done1 := make(chan struct{})
done2 := make(chan struct{})

p1 := &pipe{
rdRx: cb1, rdTx: cn1,
wrTx: cb2, wrRx: cn2,
localDone: done1, remoteDone: done2,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
p2 := &pipe{
rdRx: cb2, rdTx: cn2,
wrTx: cb1, wrRx: cn1,
localDone: done2, remoteDone: done1,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
return p1, p2
}

func (*pipe) LocalAddr() net.Addr { return pipeAddr{} }
func (*pipe) RemoteAddr() net.Addr { return pipeAddr{} }

func (p *pipe) Read(b []byte) (int, error) {
n, err := p.read(b)
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
}
return n, err
}

func (p *pipe) read(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.EOF
case isClosedChan(p.readDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}

select {
case bw := <-p.rdRx:
nr := copy(b, bw)
p.rdTx <- nr
return nr, nil
case <-p.localDone:
return 0, io.ErrClosedPipe
case <-p.remoteDone:
return 0, io.EOF
case <-p.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}

func (p *pipe) Write(b []byte) (int, error) {
n, err := p.write(b)
if err != nil && err != io.ErrClosedPipe {
err = &net.OpError{Op: "write", Net: "pipe", Err: err}
}
return n, err
}

func (p *pipe) write(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.writeDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}

p.wrMu.Lock() // Ensure entirety of b is written together
defer p.wrMu.Unlock()
for once := true; once || len(b) > 0; once = false {
select {
case p.wrTx <- b:
nw := <-p.wrRx
b = b[nw:]
n += nw
case <-p.localDone:
return n, io.ErrClosedPipe
case <-p.remoteDone:
return n, io.ErrClosedPipe
case <-p.writeDeadline.wait():
return n, os.ErrDeadlineExceeded
}
}
return n, nil
}

func (p *pipe) SetDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
p.writeDeadline.set(t)
return nil
}

func (p *pipe) SetReadDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
return nil
}

func (p *pipe) SetWriteDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.writeDeadline.set(t)
return nil
}

func (p *pipe) Close() error {
p.once.Do(func() { close(p.localDone) })
return nil
}
53 changes: 53 additions & 0 deletions common/pipe/pipe_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package pipe

import (
"io"
"net"
"os"

"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)

var _ N.ReadWaiter = (*pipe)(nil)

func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
p.newBuffer = newBuffer
}

func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, err = p.waitReadBuffer()
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
}
return
}

func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
switch {
case isClosedChan(p.localDone):
return nil, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return nil, io.EOF
case isClosedChan(p.readDeadline.wait()):
return nil, os.ErrDeadlineExceeded
}
select {
case bw := <-p.rdRx:
buffer = p.newBuffer()
var nr int
nr, err = buffer.Write(bw)
if err != nil {
buffer.Release()
return
}
p.rdTx <- nr
return
case <-p.localDone:
return nil, io.ErrClosedPipe
case <-p.remoteDone:
return nil, io.EOF
case <-p.readDeadline.wait():
return nil, os.ErrDeadlineExceeded
}
}
3 changes: 2 additions & 1 deletion protocol/http/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)

type Handler = N.TCPConnectionHandler
Expand Down Expand Up @@ -102,7 +103,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := net.Pipe()
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
Expand Down

0 comments on commit 67b9bf3

Please sign in to comment.