Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix goroutine leak, deadlock, and DTLS deadline issue #274

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ app:
[ -d $(EXE_DIR) ] || mkdir -p $(EXE_DIR)
go build -o ${EXE_DIR}/application ./cmd/application

app-dbg:
[ -d $(EXE_DIR) ] || mkdir -p $(EXE_DIR)
go build -tags debug -o ${EXE_DIR}/application ./cmd/application

libtd:
cd ./libtapdance/ && make libtapdance.a

Expand Down
2 changes: 2 additions & 0 deletions cmd/application/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func main() {
flag.StringVar(&zmqAddress, "zmq-address", "ipc://@zmq-proxy", "Address of ZMQ proxy")
flag.Parse()

startPProf()

// Init stats
cj.Stat()

Expand Down
6 changes: 6 additions & 0 deletions cmd/application/pprof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//go:build !debug

package main

func startPProf() {
}
15 changes: 15 additions & 0 deletions cmd/application/pprof_dbg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build debug

package main

import (
"log"
"net/http"
_ "net/http/pprof"
)

func startPProf() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
}
47 changes: 36 additions & 11 deletions pkg/dtls/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
Expand Down Expand Up @@ -33,6 +34,40 @@ func Client(conn net.Conn, config *Config) (net.Conn, error) {

// DialWithContext creates a DTLS connection to the given network address using the given shared secret
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (net.Conn, error) {

dtlsConn, err := dtlsCtx(ctx, conn, config)
if err != nil {
return nil, fmt.Errorf("error creating dtls connection: %w", err)
}

ddl, ok := ctx.Deadline()
if ok {
err := conn.SetDeadline(ddl)
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
if err != nil {
dtlsConn.Close()
return nil, err
}

err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

err = wrappedConn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

return wrappedConn, nil
}

func dtlsCtx(ctx context.Context, conn net.Conn, config *Config) (net.Conn, error) {
clientCert, serverCert, err := certsFromSeed(config.PSK)

if err != nil {
Expand Down Expand Up @@ -68,16 +103,6 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (net.
VerifyPeerCertificate: verifyServerCertificate,
}

dtlsConn, err := dtls.ClientWithContext(ctx, conn, dtlsConf)

if err != nil {
return nil, fmt.Errorf("error creating dtls connection: %w", err)
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
if err != nil {
return nil, err
}
return dtls.ClientWithContext(ctx, conn, dtlsConf)

return wrappedConn, nil
}
2 changes: 1 addition & 1 deletion pkg/dtls/goroutine_leak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func passGoroutineLeak(testFunc func(*testing.T), t *testing.T) bool {
}

func TestGoroutineLeak(t *testing.T) {
testFuncs := []func(*testing.T){TestSend, TestServerFail, TestClientFail, TestListenSuccess, TestListenFail}
testFuncs := []func(*testing.T){TestSend, TestServerFail, TestClientFail, TestListenSuccess, TestListenFail, TestFailSCTP}

for _, test := range testFuncs {
require.True(t, passGoroutineLeak(test, t))
Expand Down
58 changes: 49 additions & 9 deletions pkg/dtls/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,36 @@ func (l *Listener) Addr() net.Addr {

func (l *Listener) verifyConnection(state *dtls.State) error {

certs, ok := l.connToCert[state.RemoteRandomBytes()]
if !ok {
return fmt.Errorf("no matching certificate found with client hello random")
certs, err := l.getCert(state.RemoteRandomBytes())
if err != nil {
return err
}

if len(state.PeerCertificates) != 1 {
return fmt.Errorf("expected 1 peer certificate, got %v", len(state.PeerCertificates))
}

err := verifyCert(state.PeerCertificates[0], certs.clientCert.Certificate[0])
err = verifyCert(state.PeerCertificates[0], certs.clientCert.Certificate[0])
if err != nil {
return fmt.Errorf("error verifying peer certificate: %v", err)
}

return nil
}

func (l *Listener) getCert(id [handshake.RandomBytesLength]byte) (*certPair, error) {
l.connToCertMutex.Lock()
defer l.connToCertMutex.Unlock()

certs, ok := l.connToCert[id]
if !ok {
return nil, fmt.Errorf("no matching certificate found with client hello random")
}

return certs, nil

}

// Accept accepts a connection with shared secret
func (l *Listener) Accept(config *Config) (net.Conn, error) {
// Call the new function with a background context
Expand All @@ -174,6 +187,37 @@ func (l *Listener) Accept(config *Config) (net.Conn, error) {

// AcceptWithContext accepts a connection with shared secret, with a context
func (l *Listener) AcceptWithContext(ctx context.Context, config *Config) (net.Conn, error) {

conn, err := l.acceptDTLSConn(ctx, config)
if err != nil {
return nil, err
}

ddl, ok := ctx.Deadline()
if ok {
err := conn.SetDeadline(ddl)
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}
}
wrappedConn, err := wrapSCTP(conn, config)
if err != nil {
conn.Close()
return nil, err
}
err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

err = wrappedConn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}
return wrappedConn, nil
}

func (l *Listener) acceptDTLSConn(ctx context.Context, config *Config) (net.Conn, error) {
clientCert, serverCert, err := certsFromSeed(config.PSK)
if err != nil {
return &dtls.Conn{}, fmt.Errorf("error generating certificatess from seed: %v", err)
Expand All @@ -198,11 +242,7 @@ func (l *Listener) AcceptWithContext(ctx context.Context, config *Config) (net.C

select {
case conn := <-connCh:
wrappedConn, err := wrapSCTP(conn, config)
if err != nil {
return nil, err
}
return wrappedConn, nil
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
Expand Down
8 changes: 3 additions & 5 deletions pkg/dtls/sctpconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ func (s *SCTPConn) Close() error {

s.closeOnce.Do(func() { close(s.closed) })

err := s.stream.Close()
if err != nil {
return err
}
return s.conn.Close()
s.stream.Close()
s.conn.Close()
return nil
}

func (s *SCTPConn) Write(b []byte) (int, error) {
Expand Down
20 changes: 20 additions & 0 deletions pkg/dtls/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"fmt"
"net"
"time"

"github.com/pion/dtls/v2"
)
Expand Down Expand Up @@ -51,11 +52,30 @@ func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (net.
return nil, err
}

ddl, ok := ctx.Deadline()
if ok {
err := conn.SetDeadline(ddl)
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}
}

wrappedConn, err := wrapSCTP(dtlsConn, config)
if err != nil {
dtlsConn.Close()
return nil, err
}

err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

err = wrappedConn.SetDeadline(time.Time{})
if err != nil {
return nil, fmt.Errorf("error setting deadline: %v", err)
}

return wrappedConn, nil
}

Expand Down
24 changes: 24 additions & 0 deletions pkg/dtls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,27 @@ func TestClientFail(t *testing.T) {
t.Fatalf("Connect does not respect context")
}
}

func TestFailSCTP(t *testing.T) {

ctxTime := 3 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), ctxTime)
defer cancel()

server, client := net.Pipe()

before := time.Now()

go func() {
_, _ = dtlsCtx(ctx, client, &Config{PSK: sharedSecret, SCTP: ClientOpen})
}()

_, err := ServerWithContext(ctx, server, &Config{PSK: sharedSecret, SCTP: ServerAccept})

require.NotNil(t, err)
dur := time.Since(before)
if dur > ctxTime*2 {
t.Fatalf("Connect does not respect context")
}

}
2 changes: 1 addition & 1 deletion pkg/station/lib/registration_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestConjureLibConfigResolveBlocklisted(t *testing.T) {
goodTestCases := map[string][]string{
"128.0.2.1:25": {"128.0.2.1:25"},
"[2001:db8::1]:80": {"[2001:db8::1]:80"},
"example.com:1234": {"93.184.216.34:1234", "[2606:2800:220:1:248:1893:25c8:1946]:1234"},
"example.com:1234": {"93.184.215.14:1234", "[2606:2800:21f:cb07:6820:80da:af6b:8b2c]:1234"},
"[::2]:443": {"[::2]:443"},
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/station/lib/registration_ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ func handleConnectingTpReg(regManager *RegistrationManager, reg *DecoyRegistrati
return
}

defer conn.Close()

// regManager.connectingStats.AddCreatedToSuccessfulConnecting(asn, cc, transport.Name())

Stat().AddConn()
Expand Down
49 changes: 36 additions & 13 deletions pkg/transports/connecting/dtls/dtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ func (t *Transport) Connect(ctx context.Context, reg transports.Registration) (n
return nil, fmt.Errorf("transport params is not *pb.DTLSTransportParams")
}

connCh := make(chan net.Conn, 2)
errCh := make(chan error, 2)
connCh := make(chan net.Conn)
errCh := make(chan error)

ctxCancel, cancel := context.WithCancel(ctx)
defer cancel()

go func() {

Expand All @@ -94,7 +97,10 @@ func (t *Transport) Connect(ctx context.Context, reg transports.Registration) (n

err := t.DNAT.AddEntry(&clientAddr.IP, uint16(clientAddr.Port), reg.PhantomIP(), reg.GetDstPort())
if err != nil {
errCh <- fmt.Errorf("error adding DNAT entry: %v", err)
select {
case errCh <- fmt.Errorf("error adding DNAT entry: %v", err):
case <-ctxCancel.Done():
}
return
}

Expand All @@ -108,30 +114,47 @@ func (t *Transport) Connect(ctx context.Context, reg transports.Registration) (n

udpConn, err := reuseport.Dial("udp", laddr.String(), clientAddr.String())
if err != nil {
errCh <- fmt.Errorf("error connecting to dtls client: %v", err)
select {
case errCh <- fmt.Errorf("error connecting to dtls client: %v", err):
case <-ctxCancel.Done():
}
return
}

dtlsConn, err := dtls.ClientWithContext(ctx, udpConn, &dtls.Config{PSK: reg.SharedSecret(), SCTP: dtls.ServerAccept, Unordered: params.GetUnordered()})
dtlsConn, err := dtls.ClientWithContext(ctxCancel, udpConn, &dtls.Config{PSK: reg.SharedSecret(), SCTP: dtls.ServerAccept, Unordered: params.GetUnordered()})
if err != nil {
errCh <- fmt.Errorf("error connecting to dtls client: %v", err)
select {
case errCh <- fmt.Errorf("error connecting to dtls client: %v", err):
case <-ctxCancel.Done():
}
return
}
t.logDialSuccess(&clientAddr.IP)

connCh <- dtlsConn
select {
case connCh <- dtlsConn:
t.logDialSuccess(&clientAddr.IP)
case <-ctxCancel.Done():
dtlsConn.Close()
}
}()

go func() {
conn, err := t.dtlsListener.AcceptWithContext(ctx, &dtls.Config{PSK: reg.SharedSecret(), SCTP: dtls.ServerAccept, Unordered: params.GetUnordered()})
conn, err := t.dtlsListener.AcceptWithContext(ctxCancel, &dtls.Config{PSK: reg.SharedSecret(), SCTP: dtls.ServerAccept, Unordered: params.GetUnordered()})
if err != nil {
errCh <- fmt.Errorf("error accepting dtls connection from secret: %v", err)
select {
case errCh <- fmt.Errorf("error accepting dtls connection from secret: %v", err):
case <-ctxCancel.Done():
}
return
}
logip := net.ParseIP(reg.GetRegistrationAddress())
t.logListenSuccess(&logip)

connCh <- conn
select {
case connCh <- conn:
logip := net.ParseIP(reg.GetRegistrationAddress())
t.logListenSuccess(&logip)
case <-ctxCancel.Done():
conn.Close()
}
}()

var errs []error
Expand Down
Loading
Loading