From bf61e1468dc1ff28980f0b8aac24b02befbff31e Mon Sep 17 00:00:00 2001 From: Alim Shapiev Date: Sat, 2 Nov 2024 22:18:29 +0100 Subject: [PATCH] Fix TTL and TCP MD5 Signature --- x/configurl/fake.go | 5 +-- x/fake/md5_sig.go | 63 ---------------------------- x/fake/signature/signature.go | 79 +++++++++++++++++++++++++++++++++++ x/fake/stream_dialer.go | 12 +++--- x/fake/writer.go | 31 +++++++++++++- x/fake/writer_test.go | 8 ++-- x/ttl/ttl.go | 33 +++++++++++++++ 7 files changed, 153 insertions(+), 78 deletions(-) delete mode 100644 x/fake/md5_sig.go create mode 100644 x/fake/signature/signature.go create mode 100644 x/ttl/ttl.go diff --git a/x/configurl/fake.go b/x/configurl/fake.go index 03dd7d56..66077a87 100644 --- a/x/configurl/fake.go +++ b/x/configurl/fake.go @@ -67,13 +67,12 @@ func registerFakeStreamDialer(r TypeRegistry[transport.StreamDialer], typeID str if err != nil { return nil, fmt.Errorf("prefixBytes is not a number: %v. Fake config should be in fake: format", prefixBytesStr) } - // TODO: Read fake data from the CLI + // TODO: Read fake data from the CLI or use a default value (depending on the protocol). var fakeData []byte - // or use a default value (depending on the protocol). // TODO: Read fake offset from the CLI var fakeOffset int64 = 0 // TODO: Read fake TTL from the CLI or use a default value (8) - var fakeTtl int64 = 8 + var fakeTtl int = 8 // TODO: Read md5 signature from the CLI or use a default value (false). var md5Sig bool = false return fake.NewStreamDialer(sd, int64(prefixBytes), fakeData, fakeOffset, fakeTtl, md5Sig) diff --git a/x/fake/md5_sig.go b/x/fake/md5_sig.go deleted file mode 100644 index 14b9da0c..00000000 --- a/x/fake/md5_sig.go +++ /dev/null @@ -1,63 +0,0 @@ -package fake - -import ( - "fmt" - "golang.org/x/sys/unix" - "net" - "unsafe" -) - -const tcpMd5sigFlag = 14 - -type md5Signature struct { - Addr [16]byte - Len uint16 - Flags uint16 - Key [80]byte -} - -func setMd5Sig(conn *net.TCPConn, remoteAddr string, data string) error { - ip := net.ParseIP(remoteAddr) - if ip == nil { - return fmt.Errorf("invalid remote IP address: %s", remoteAddr) - } - - address, err := ip.To16().MarshalText() - if err != nil { - return fmt.Errorf("failed to marshal IP address: %v", err) - } - - key := []byte(data) - - md5sig := md5Signature{ - Addr: [16]byte(address), - Len: uint16(len(data)), - Key: [80]byte(key), - } - - if err := setSocketOption(conn, md5sig); err != nil { - return fmt.Errorf("failed to set socket option: %v", err) - } - - return nil -} - -func setSocketOption(conn *net.TCPConn, md5sig md5Signature) error { - file, err := conn.File() - if err != nil { - return fmt.Errorf("failed to get underlying file descriptor: %v", err) - } - defer file.Close() - - size := unsafe.Sizeof(md5sig) - - buffer := (*[unsafe.Sizeof(md5sig)]byte)(unsafe.Pointer(&md5sig))[:size] - fd := int(file.Fd()) - - err = unix.SetsockoptString(fd, unix.IPPROTO_TCP, tcpMd5sigFlag, string(buffer)) - if err != nil { - return fmt.Errorf("failed to set TCP_MD5SIG: %v", err) - } - - return nil -} diff --git a/x/fake/signature/signature.go b/x/fake/signature/signature.go new file mode 100644 index 00000000..0abc3166 --- /dev/null +++ b/x/fake/signature/signature.go @@ -0,0 +1,79 @@ +package signature + +import ( + "fmt" + "golang.org/x/sys/unix" + "net" + "unsafe" +) + +const socketFlag = 14 + +type signature struct { + Addr [16]byte + Len uint16 + Flags uint16 + Key [80]byte +} + +func Add(conn *net.TCPConn, remoteAddr string, data string) error { + ip := net.ParseIP(remoteAddr) + if ip == nil { + return fmt.Errorf("invalid remote IP address: %s", remoteAddr) + } + + address, err := ip.To16().MarshalText() + if err != nil { + return fmt.Errorf("failed to marshal IP address: %w", err) + } + + key := []byte(data) + + sig := signature{ + Addr: [16]byte(address), + Len: uint16(len(data)), + Key: [80]byte(key), + } + + if err := setOption(conn, sig); err != nil { + return fmt.Errorf("failed to set socket option: %w", err) + } + + return nil +} + +func setOption(conn *net.TCPConn, md5sig signature) error { + file, err := conn.File() + if err != nil { + return fmt.Errorf("failed to get file descriptor: %w", err) + } + defer file.Close() + + size := unsafe.Sizeof(md5sig) + buffer := (*[unsafe.Sizeof(md5sig)]byte)(unsafe.Pointer(&md5sig))[:size] + fd := int(file.Fd()) + + err = unix.SetsockoptString(fd, unix.IPPROTO_TCP, socketFlag, string(buffer)) + if err != nil { + return fmt.Errorf("failed to set TCP_MD5SIG: %w", err) + } + + return nil +} + +func Remove(conn *net.TCPConn) error { + file, err := conn.File() + if err != nil { + return fmt.Errorf("failed to get underlying file descriptor: %w", err) + } + defer file.Close() + + fd := int(file.Fd()) + + err = unix.SetsockoptString(fd, unix.IPPROTO_TCP, socketFlag, "") + if err != nil { + return fmt.Errorf("failed to clear TCP_MD5SIG: %w", err) + } + + return nil +} diff --git a/x/fake/stream_dialer.go b/x/fake/stream_dialer.go index 1570d4a8..e3ceb81c 100644 --- a/x/fake/stream_dialer.go +++ b/x/fake/stream_dialer.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/Jigsaw-Code/outline-sdk/x/fake/signature" "net" "github.com/Jigsaw-Code/outline-sdk/transport" @@ -28,7 +29,7 @@ type fakeDialer struct { splitPoint int64 fakeData []byte fakeOffset int64 - fakeTtl int64 + fakeTtl int md5Sig bool } @@ -41,7 +42,7 @@ func NewStreamDialer( prefixBytes int64, fakeData []byte, fakeOffset int64, - fakeTtl int64, + fakeTtl int, md5Sig bool, ) (transport.StreamDialer, error) { if dialer == nil { @@ -63,11 +64,10 @@ func (d *fakeDialer) DialStream(ctx context.Context, remoteAddr string) (transpo if err != nil { return nil, err } - if d.md5Sig { - conn := innerConn.(*net.TCPConn) - err := setMd5Sig(conn, remoteAddr, conn.RemoteAddr().String()) + if tcpInnerConn, isTcp := innerConn.(*net.TCPConn); isTcp && d.md5Sig { + err := signature.Add(tcpInnerConn, remoteAddr, tcpInnerConn.RemoteAddr().String()) if err != nil { - return nil, fmt.Errorf("failed to set MD5 signature: %w", err) + return nil, fmt.Errorf("failed to add MD5 signature: %w", err) } } return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.fakeData, d.fakeOffset, d.fakeTtl)), nil diff --git a/x/fake/writer.go b/x/fake/writer.go index 172f2661..d72539bb 100644 --- a/x/fake/writer.go +++ b/x/fake/writer.go @@ -2,7 +2,10 @@ package fake import ( "bytes" + "fmt" + "github.com/Jigsaw-Code/outline-sdk/x/ttl" "io" + "net" ) type fakeWriter struct { @@ -10,7 +13,7 @@ type fakeWriter struct { fakeBytes int64 fakeData []byte fakeOffset int64 - ttl int64 + ttl int } var _ io.Writer = (*fakeWriter)(nil) @@ -26,7 +29,7 @@ var _ io.ReaderFrom = (*fakeWriterReaderFrom)(nil) // A write will end right after byte index fakeBytes - 1, before a write starting at byte index fakeBytes. // For example, if you have a write of [0123456789], fakeData = [abc], fakeOffset = 1, and fakeBytes = 3, // you will get writes [bc] and [0123456789]. If the input writer is a [io.ReaderFrom], the output writer will be too. -func NewWriter(writer io.Writer, fakeBytes int64, fakeData []byte, fakeOffset int64, fakeTtl int64) io.Writer { +func NewWriter(writer io.Writer, fakeBytes int64, fakeData []byte, fakeOffset int64, fakeTtl int) io.Writer { sw := &fakeWriter{writer, fakeBytes, fakeData, fakeOffset, fakeTtl} if rf, ok := writer.(io.ReaderFrom); ok { return &fakeWriterReaderFrom{sw, rf} @@ -35,8 +38,20 @@ func NewWriter(writer io.Writer, fakeBytes int64, fakeData []byte, fakeOffset in } func (w *fakeWriterReaderFrom) ReadFrom(source io.Reader) (written int64, err error) { + conn, isNetConn := w.writer.(net.Conn) fakeData := w.getFakeData() if fakeData != nil { + if isNetConn { + oldTtl, err := ttl.Set(conn, w.ttl) + if err != nil { + return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) + } + defer func() { + if _, err = ttl.Set(conn, oldTtl); err != nil { + err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) + } + }() + } fakeN, err := w.rf.ReadFrom(bytes.NewReader(fakeData)) written += fakeN if err != nil { @@ -50,8 +65,20 @@ func (w *fakeWriterReaderFrom) ReadFrom(source io.Reader) (written int64, err er } func (w *fakeWriter) Write(data []byte) (written int, err error) { + conn, isNetConn := w.writer.(net.Conn) fakeData := w.getFakeData() if fakeData != nil { + if isNetConn { + oldTtl, err := ttl.Set(conn, w.ttl) + if err != nil { + return written, fmt.Errorf("failed to set TTL before writing fake data: %w", err) + } + defer func() { + if _, err = ttl.Set(conn, oldTtl); err != nil { + err = fmt.Errorf("failed to restore TTL after writing fake data: %w", err) + } + }() + } fakeN, err := w.writer.Write(fakeData) written += fakeN if err != nil { diff --git a/x/fake/writer_test.go b/x/fake/writer_test.go index 7470774c..8d048e11 100644 --- a/x/fake/writer_test.go +++ b/x/fake/writer_test.go @@ -105,14 +105,14 @@ func TestWrite_Compound(t *testing.T) { fakeData1 := []byte("F") fakeBytes1 := int64(1) fakeOffset1 := int64(0) - fakeTtl1 := int64(0) + fakeTtl1 := 0 writer1 := NewWriter(&innerWriter, fakeBytes1, fakeData1, fakeOffset1, fakeTtl1) // Second fakeWriter: fakeBytes=3, fakeData="ake d", fakeOffset=0 fakeData2 := []byte("ake") // Total fakeData now: "Fake d" fakeBytes2 := int64(3) fakeOffset2 := int64(0) - fakeTtl2 := int64(0) + fakeTtl2 := 0 fakeWriter := NewWriter(writer1, fakeBytes2, fakeData2, fakeOffset2, fakeTtl2) // Write "Request" @@ -205,14 +205,14 @@ func TestReadFrom_Compound(t *testing.T) { fakeData1 := []byte("Fake ") fakeBytes1 := int64(3) fakeOffset1 := int64(0) - fakeTtl1 := int64(0) + fakeTtl1 := 0 writer1 := NewWriter(&innerWriter, fakeBytes1, fakeData1, fakeOffset1, fakeTtl1) // Second fakeWriter: fakeBytes=5, fakeData="data", fakeOffset=0 fakeData2 := []byte("data") fakeBytes2 := int64(5) fakeOffset2 := int64(0) - fakeTtl2 := int64(0) + fakeTtl2 := 0 writer2 := NewWriter(writer1, fakeBytes2, fakeData2, fakeOffset2, fakeTtl2) n, err := writer2.Write([]byte("Request")) diff --git a/x/ttl/ttl.go b/x/ttl/ttl.go new file mode 100644 index 00000000..60817220 --- /dev/null +++ b/x/ttl/ttl.go @@ -0,0 +1,33 @@ +package ttl + +import ( + "fmt" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "net/netip" +) + +func Set(conn net.Conn, ttl int) (old int, err error) { + addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()) + if err != nil { + return 0, err + } + + switch { + case addr.Addr().Is4(): + conn := ipv4.NewConn(conn) + old, _ = conn.TTL() + if err := conn.SetTTL(ttl); err != nil { + return 0, fmt.Errorf("failed to set TTL: %w", err) + } + case addr.Addr().Is6(): + conn := ipv6.NewConn(conn) + old, _ = conn.HopLimit() + if err := conn.SetHopLimit(ttl); err != nil { + return 0, fmt.Errorf("failed to set hop limit: %w", err) + } + } + + return +}