From 34ccabd130ac688c648b3d24821d0518f03245b0 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Reich <jeanbaptiste.reich@gmail.com>
Date: Tue, 9 Jul 2024 21:51:39 +0200
Subject: [PATCH] Added rtmps protocol on client connection

---
 client.go      | 24 +++++++++++++++++-------
 client_test.go | 32 --------------------------------
 2 files changed, 17 insertions(+), 39 deletions(-)
 delete mode 100644 client_test.go

diff --git a/client.go b/client.go
index 1c79acd..ceb563d 100644
--- a/client.go
+++ b/client.go
@@ -8,6 +8,7 @@
 package rtmp
 
 import (
+	"crypto/tls"
 	"net"
 
 	"github.com/pkg/errors"
@@ -17,6 +18,13 @@ func Dial(protocol, addr string, config *ConnConfig) (*ClientConn, error) {
 	return DialWithDialer(&net.Dialer{}, protocol, addr, config)
 }
 
+func TLSDial(protocol, addr string, config *ConnConfig, tlsConfig *tls.Config) (*ClientConn, error) {
+	return DialWithTLSDialer(&tls.Dialer{
+		NetDialer: &net.Dialer{},
+		Config:    tlsConfig,
+	}, protocol, addr, config)
+}
+
 func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) {
 	if protocol != "rtmp" {
 		return nil, errors.Errorf("Unknown protocol: %s", protocol)
@@ -30,13 +38,15 @@ func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfi
 	return newClientConnWithSetup(rwc, config)
 }
 
-func makeValidAddr(addr string) (string, error) {
-	host, port, err := net.SplitHostPort(addr)
+func DialWithTLSDialer(dialer *tls.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) {
+	if protocol != "rtmps" {
+		return nil, errors.Errorf("Unknown protocol: %s", protocol)
+	}
+
+	rwc, err := dialer.Dial("tcp", addr)
 	if err != nil {
-		if err, ok := err.(*net.AddrError); ok && err.Err == "missing port in address" {
-			return makeValidAddr(addr + ":1935") // Default RTMP port
-		}
-		return "", err
+		return nil, err
 	}
-	return net.JoinHostPort(host, port), nil
+
+	return newClientConnWithSetup(rwc, config)
 }
diff --git a/client_test.go b/client_test.go
deleted file mode 100644
index 93dacba..0000000
--- a/client_test.go
+++ /dev/null
@@ -1,32 +0,0 @@
-//
-// Copyright (c) 2018- yutopp (yutopp@gmail.com)
-//
-// Distributed under the Boost Software License, Version 1.0. (See accompanying
-// file LICENSE_1_0.txt or copy at  https://www.boost.org/LICENSE_1_0.txt)
-//
-
-package rtmp
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestClientValidAddr(t *testing.T) {
-	addr, err := makeValidAddr("host:123")
-	require.Equal(t, nil, err)
-	require.Equal(t, "host:123", addr)
-
-	addr, err = makeValidAddr("host")
-	require.Equal(t, nil, err)
-	require.Equal(t, "host:1935", addr)
-
-	addr, err = makeValidAddr("host:")
-	require.Equal(t, nil, err)
-	require.Equal(t, "host:", addr)
-
-	addr, err = makeValidAddr(":1111")
-	require.Equal(t, nil, err)
-	require.Equal(t, ":1111", addr)
-}