diff --git a/client.go b/client.go index 1c79acd..ceb563d 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ package rtmp import ( + "crypto/tls" "net" "github.com/pkg/errors" @@ -17,6 +18,13 @@ func Dial(protocol, addr string, config *ConnConfig) (*ClientConn, error) { return DialWithDialer(&net.Dialer{}, protocol, addr, config) } +func TLSDial(protocol, addr string, config *ConnConfig, tlsConfig *tls.Config) (*ClientConn, error) { + return DialWithTLSDialer(&tls.Dialer{ + NetDialer: &net.Dialer{}, + Config: tlsConfig, + }, protocol, addr, config) +} + func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) { if protocol != "rtmp" { return nil, errors.Errorf("Unknown protocol: %s", protocol) @@ -30,13 +38,15 @@ func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfi return newClientConnWithSetup(rwc, config) } -func makeValidAddr(addr string) (string, error) { - host, port, err := net.SplitHostPort(addr) +func DialWithTLSDialer(dialer *tls.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) { + if protocol != "rtmps" { + return nil, errors.Errorf("Unknown protocol: %s", protocol) + } + + rwc, err := dialer.Dial("tcp", addr) if err != nil { - if err, ok := err.(*net.AddrError); ok && err.Err == "missing port in address" { - return makeValidAddr(addr + ":1935") // Default RTMP port - } - return "", err + return nil, err } - return net.JoinHostPort(host, port), nil + + return newClientConnWithSetup(rwc, config) } diff --git a/client_test.go b/client_test.go deleted file mode 100644 index 93dacba..0000000 --- a/client_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// -// Copyright (c) 2018- yutopp (yutopp@gmail.com) -// -// Distributed under the Boost Software License, Version 1.0. (See accompanying -// file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt) -// - -package rtmp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestClientValidAddr(t *testing.T) { - addr, err := makeValidAddr("host:123") - require.Equal(t, nil, err) - require.Equal(t, "host:123", addr) - - addr, err = makeValidAddr("host") - require.Equal(t, nil, err) - require.Equal(t, "host:1935", addr) - - addr, err = makeValidAddr("host:") - require.Equal(t, nil, err) - require.Equal(t, "host:", addr) - - addr, err = makeValidAddr(":1111") - require.Equal(t, nil, err) - require.Equal(t, ":1111", addr) -}