diff --git a/x/fake/stream_dialer.go b/x/fake/stream_dialer.go index f2752881..45e25d74 100644 --- a/x/fake/stream_dialer.go +++ b/x/fake/stream_dialer.go @@ -18,10 +18,8 @@ import ( "context" "errors" "fmt" - "github.com/Jigsaw-Code/outline-sdk/x/md5signature" - "net" - "github.com/Jigsaw-Code/outline-sdk/transport" + "net" ) // Example of fake data for TLS @@ -98,13 +96,34 @@ func (d *fakeDialer) DialStream(ctx context.Context, remoteAddr string) (transpo if err != nil { return nil, err } - if tcpInnerConn, isTcp := innerConn.(*net.TCPConn); isTcp && d.md5Sig { - err := md5signature.Add(tcpInnerConn, remoteAddr, tcpInnerConn.RemoteAddr().String()) - if err != nil { - return nil, fmt.Errorf("failed to add MD5 signature: %w", err) - } + tcpConn, ok := innerConn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("oob strategy only works with direct TCP connections") + } + + fd, err := getSocketDescriptor(tcpConn) + if err != nil { + return nil, fmt.Errorf("oob strategy was unable to get conn fd: %w", err) + } + + err = tcpConn.SetNoDelay(true) + if err != nil { + return nil, fmt.Errorf("setting tcp NO_DELAY failed: %w", err) } - return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.fakeData, d.fakeOffset, d.fakeBytes, d.fakeTtl)), nil + + return transport.WrapConn( + innerConn, + innerConn, + NewWriter( + tcpConn, + fd, + tcpConn, + d.fakeData, + d.fakeOffset, + d.fakeBytes, + d.fakeTtl, + ), + ), nil } func getDefaultFakeData(isHttp bool) []byte { diff --git a/x/fake/unix_ops.go b/x/fake/unix_ops.go new file mode 100644 index 00000000..474c6f07 --- /dev/null +++ b/x/fake/unix_ops.go @@ -0,0 +1,42 @@ +//go:build linux || darwin + +package fake + +import ( + "fmt" + "net" + "syscall" +) + +type SocketDescriptor int + +func setsockoptInt(fd SocketDescriptor, level, opt int, value int) error { + fmt.Printf("setsockoptInt: %x, %x, %d to value %d\n", fd, level, opt, value) + return syscall.SetsockoptInt(int(fd), level, opt, value) +} + +func setSocketLinger(fd SocketDescriptor, onoff int32, linger int32) error { + fmt.Printf("setSocketLinger: %d, %d\n", onoff, linger) + return syscall.SetsockoptLinger(int(fd), syscall.SOL_SOCKET, syscall.SO_LINGER, &syscall.Linger{ + Onoff: onoff, + Linger: linger, + }) +} + +func clearSocketLinger(fd SocketDescriptor) error { + fmt.Printf("clearSocketLinger\n") + return syscall.SetsockoptLinger(int(fd), syscall.SOL_SOCKET, syscall.SO_LINGER, nil) +} + +func sendTo(fd SocketDescriptor, data []byte, flags int) (err error) { + fmt.Printf("sendTo: %d, %v, %d\n", fd, data, flags) + return syscall.Sendto(int(fd), data, flags, nil) +} + +func getSocketDescriptor(conn *net.TCPConn) (SocketDescriptor, error) { + file, err := conn.File() + if err != nil { + return 0, err + } + return SocketDescriptor(file.Fd()), nil +} diff --git a/x/fake/writer.go b/x/fake/writer.go index 10b9428e..9fe1ecf7 100644 --- a/x/fake/writer.go +++ b/x/fake/writer.go @@ -5,12 +5,15 @@ import ( "bytes" "fmt" "github.com/Jigsaw-Code/outline-sdk/x/packet" - "github.com/Jigsaw-Code/outline-sdk/x/ttl" "io" "net" + "syscall" + "time" ) type fakeWriter struct { + conn *net.TCPConn + fd SocketDescriptor writer io.Writer fakeData []byte fakeOffset int64 @@ -31,8 +34,8 @@ var _ io.ReaderFrom = (*fakeWriterReaderFrom)(nil) // A write will end right after byte index FakeBytes - 1, before a write starting at byte index FakeBytes. // For example, if you have a write of [0123456789], FakeData = [abc], FakeOffset = 1, and FakeBytes = 3, // you will get writes [bc] and [0123456789]. If the input writer is a [io.ReaderFrom], the output writer will be too. -func NewWriter(writer io.Writer, fakeData []byte, fakeOffset int64, fakeBytes int64, fakeTtl int) io.Writer { - sw := &fakeWriter{writer, fakeData, fakeOffset, fakeBytes, fakeTtl} +func NewWriter(conn *net.TCPConn, fd SocketDescriptor, writer io.Writer, fakeData []byte, fakeOffset int64, fakeBytes int64, ttl int) io.Writer { + sw := &fakeWriter{conn, fd, writer, fakeData, fakeOffset, fakeBytes, ttl} if rf, ok := writer.(io.ReaderFrom); ok { return &fakeWriterReaderFrom{sw, rf} } @@ -40,59 +43,38 @@ func NewWriter(writer io.Writer, fakeData []byte, fakeOffset int64, fakeBytes in } func (w *fakeWriterReaderFrom) ReadFrom(source io.Reader) (written int64, err error) { - conn, isNetConn := w.writer.(net.Conn) - bufioReader := bufio.NewReader(source) - fakeData := w.getFakeData(bufioReader) - if fakeData != nil { - if isNetConn { - oldTtl, err := ttl.Set(conn, w.ttl) - if err != nil { - return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) - } - defer func() { - if _, err = ttl.Set(conn, oldTtl); err != nil { - err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) - } - }() - } - fakeN, err := w.rf.ReadFrom(bytes.NewReader(fakeData)) - written += fakeN - if err != nil { - return written, err - } - } - reader := io.MultiReader(io.LimitReader(source, w.fakeBytes), source) - n, err := w.rf.ReadFrom(reader) - written += n - return written, err + panic("implement me") } func (w *fakeWriter) Write(data []byte) (written int, err error) { - conn, isNetConn := w.writer.(net.Conn) fakeData := w.getFakeData(bufio.NewReader(bytes.NewReader(data))) if fakeData != nil { - if isNetConn { - oldTtl, err := ttl.Set(conn, w.ttl) - if err != nil { - return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) - } - defer func() { - if _, err = ttl.Set(conn, oldTtl); err != nil { - err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) - } - }() + if err := setsockoptInt(w.fd, syscall.IPPROTO_IP, syscall.IP_TTL, w.ttl); err != nil { + return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) + } + if err := setSocketLinger(w.fd, 1, 0); err != nil { + return written, fmt.Errorf("failed to set SO_LINGER before writing fake data: %w", err) } fmt.Printf("Writing fake data with TTL %d:\n---\n%s\n---\n", w.ttl, fakeData) fakeData = append(fakeData, make([]byte, len(data)-len(fakeData))...) - fakeN, err := w.writer.Write(fakeData) - written += fakeN + err := w.send(fakeData, 0) + written += len(fakeData) if err != nil { return written, err } + time.Sleep(200 * time.Millisecond) + //if err := setsockoptInt(w.fd, syscall.IPPROTO_IP, syscall.IP_TTL, 68); err != nil { + // err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) + //} + if err := clearSocketLinger(w.fd); err != nil { + err = fmt.Errorf("failed to restore SO_LINGER after writing fake data: %w", err) + } } fmt.Printf("Writing real data:\n---\n%s\n---\n", data) - //n, err := w.writer.Write(data) - //written += n + if err := w.send(data, 0); err != nil { + return written, fmt.Errorf("failed to send real data: %w", err) + } + written += len(data) return written, err } @@ -114,3 +96,24 @@ func (w *fakeWriter) getFakeData(dataReader *bufio.Reader) []byte { } return fakeData } + +func (w *fakeWriter) send(data []byte, flags int) error { + // Use SyscallConn to access the underlying file descriptor safely + rawConn, err := w.conn.SyscallConn() + if err != nil { + return fmt.Errorf("oob strategy was unable to get raw conn: %w", err) + } + + // Use Control to execute Sendto on the file descriptor + var sendErr error + err = rawConn.Control(func(fd uintptr) { + sendErr = sendTo(SocketDescriptor(fd), data, flags) + }) + if err != nil { + return fmt.Errorf("oob strategy was unable to control socket: %w", err) + } + if sendErr != nil { + return fmt.Errorf("oob strategy was unable to send data: %w", sendErr) + } + return nil +}