Skip to content

Commit

Permalink
ssl: Add more guards to api functions
Browse files Browse the repository at this point in the history
Gives badarg earlier instead of strange crashes/error messages later.
  • Loading branch information
dgud committed Feb 19, 2024
1 parent a5566ca commit 110b699
Showing 1 changed file with 51 additions and 36 deletions.
87 changes: 51 additions & 36 deletions lib/ssl/src/ssl.erl
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,9 @@ can verify.
-type tls_options_name() :: atom().
%% -------------------------------------------------------------------------------------------------------

-define(IS_TIMEOUT(Timeout),
((is_integer(Timeout) andalso Timeout >= 0) orelse (Timeout == infinity))).

%%%--------------------------------------------------------------------
%%% API
%%%--------------------------------------------------------------------
Expand Down Expand Up @@ -1818,7 +1821,8 @@ stop() ->
TCPSocket :: socket(),
TLSOptions :: [tls_client_option()].

connect(Socket, SslOptions) ->
connect(Socket, SslOptions)
when is_list(SslOptions) ->
connect(Socket, SslOptions, infinity).

-doc """
Expand Down Expand Up @@ -1857,8 +1861,8 @@ owning the sslsocket will receive messages of type `t:active_msgs/0`
Port :: inet:port_number(),
TLSOptions :: [tls_client_option()].

connect(Socket, SslOptions0, Timeout) when is_list(SslOptions0) andalso
(is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
connect(Socket, SslOptions0, Timeout)
when is_list(SslOptions0), ?IS_TIMEOUT(Timeout) ->

try
CbInfo = handle_option_cb_info(SslOptions0, tls),
Expand All @@ -1868,8 +1872,9 @@ connect(Socket, SslOptions0, Timeout) when is_list(SslOptions0) andalso
catch
_:{error, Reason} ->
{error, Reason}
end;
connect(Host, Port, Options) ->
end;
connect(Host, Port, Options)
when is_integer(Port), is_list(Options) ->
connect(Host, Port, Options, infinity).

-doc """
Expand Down Expand Up @@ -1914,7 +1919,8 @@ owning the sslsocket will receive messages of type `t:active_msgs/0`
TLSOptions :: [tls_client_option()],
Timeout :: timeout().

connect(Host, Port, Options, Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
connect(Host, Port, Options, Timeout)
when is_integer(Port), is_list(Options), ?IS_TIMEOUT(Timeout) ->
try
{ok, Config} = handle_options(Options, client, Host),
case Config#config.connection_cb of
Expand All @@ -1940,7 +1946,8 @@ connect(Host, Port, Options, Timeout) when (is_integer(Timeout) andalso Timeout
%%--------------------------------------------------------------------
listen(_Port, []) ->
{error, nooptions};
listen(Port, Options0) ->
listen(Port, Options0)
when is_integer(Port), is_list(Options0) ->
try
{ok, Config} = handle_options(Options0, server, undefined),
do_listen(Port, Config, Config#config.connection_cb)
Expand Down Expand Up @@ -1984,8 +1991,8 @@ connection is accepted within the given time, `{error, timeout}` is returned.
SslSocket :: sslsocket().

transport_accept(#sslsocket{pid = {ListenSocket,
#config{connection_cb = ConnectionCb} = Config}}, Timeout)
when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
#config{connection_cb = ConnectionCb} = Config}}, Timeout)
when ?IS_TIMEOUT(Timeout) ->
case ConnectionCb of
tls_gen_connection ->
tls_socket:accept(ListenSocket, Config, Timeout);
Expand Down Expand Up @@ -2039,8 +2046,8 @@ owning the sslsocket will receive messages of type `t:active_msgs/0`
Ext :: protocol_extensions(),
Reason :: closed | timeout | error_alert().

handshake(#sslsocket{} = Socket, Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or
(Timeout == infinity) ->
handshake(#sslsocket{} = Socket, Timeout)
when ?IS_TIMEOUT(Timeout) ->
ssl_gen_statem:handshake(Socket, Timeout);

%% If Socket is a ordinary socket(): upgrades a gen_tcp, or equivalent, socket to
Expand Down Expand Up @@ -2091,11 +2098,11 @@ owning the sslsocket will receive messages of type `t:active_msgs/0`
Ext :: protocol_extensions(),
Reason :: closed | timeout | {options, any()} | error_alert().

handshake(#sslsocket{} = Socket, [], Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or
(Timeout == infinity)->
handshake(#sslsocket{} = Socket, [], Timeout)
when ?IS_TIMEOUT(Timeout) ->
handshake(Socket, Timeout);
handshake(#sslsocket{fd = {_, _, _, Trackers}} = Socket, SslOpts, Timeout) when
(is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity)->
handshake(#sslsocket{fd = {_, _, _, Trackers}} = Socket, SslOpts, Timeout)
when is_list(SslOpts), ?IS_TIMEOUT(Timeout) ->
try
Tracker = proplists:get_value(option_tracker, Trackers),
{ok, EmOpts, _} = tls_socket:get_all_opts(Tracker),
Expand All @@ -2104,16 +2111,17 @@ handshake(#sslsocket{fd = {_, _, _, Trackers}} = Socket, SslOpts, Timeout) when
catch
Error = {error, _Reason} -> Error
end;
handshake(#sslsocket{pid = [Pid|_], fd = {_, _, _}} = Socket, SslOpts, Timeout) when
(is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity)->
handshake(#sslsocket{pid = [Pid|_], fd = {_, _, _}} = Socket, SslOpts, Timeout)
when is_list(SslOpts), ?IS_TIMEOUT(Timeout) ->
try
{ok, EmOpts, _} = dtls_packet_demux:get_all_opts(Pid),
ssl_gen_statem:handshake(Socket, {SslOpts,
tls_socket:emulated_socket_options(EmOpts, #socket_options{})}, Timeout)
catch
Error = {error, _Reason} -> Error
end;
handshake(Socket, SslOptions, Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
handshake(Socket, SslOptions, Timeout)
when is_list(SslOptions), ?IS_TIMEOUT(Timeout) ->
try
CbInfo = handle_option_cb_info(SslOptions, tls),
Transport = element(1, CbInfo),
Expand Down Expand Up @@ -2147,6 +2155,7 @@ handshake(Socket, SslOptions, Timeout) when (is_integer(Timeout) andalso Timeout
%%--------------------------------------------------------------------
handshake_continue(Socket, SSLOptions) ->
handshake_continue(Socket, SSLOptions, infinity).

%%--------------------------------------------------------------------
-doc "Continue the TLS handshake, possibly with new, additional or changed options.".
-doc(#{since => <<"OTP 21.0">>}).
Expand All @@ -2161,8 +2170,10 @@ handshake_continue(Socket, SSLOptions) ->
%%
%% Description: Continues the handshake possible with newly supplied options.
%%--------------------------------------------------------------------
handshake_continue(Socket, SSLOptions, Timeout) ->
handshake_continue(Socket, SSLOptions, Timeout)
when is_list(SSLOptions), ?IS_TIMEOUT(Timeout) ->
ssl_gen_statem:handshake_continue(Socket, SSLOptions, Timeout).

%%--------------------------------------------------------------------
-doc "Cancel the handshake with a fatal `USER_CANCELED` alert.".
-doc(#{since => <<"OTP 21.0">>}).
Expand Down Expand Up @@ -2208,19 +2219,17 @@ connection.
%%
%% Description: Close an ssl connection
%%--------------------------------------------------------------------
close(#sslsocket{pid = [TLSPid|_]},
{Pid, Timeout} = DownGrade) when is_pid(TLSPid),
is_pid(Pid),
(is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
close(#sslsocket{pid = [TLSPid|_]}, {Pid, Timeout} = DownGrade)
when is_pid(TLSPid), is_pid(Pid), ?IS_TIMEOUT(Timeout) ->
case ssl_gen_statem:close(TLSPid, {close, DownGrade}) of
ok -> %% In normal close {error, closed} is regarded as ok, as it is not interesting which side
%% that got to do the actual close. But in the downgrade case only {ok, Port} is a success.
{error, closed};
Other ->
Other
end;
close(#sslsocket{pid = [TLSPid|_]}, Timeout) when is_pid(TLSPid),
(is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->
close(#sslsocket{pid = [TLSPid|_]}, Timeout)
when is_pid(TLSPid), ?IS_TIMEOUT(Timeout) ->
ssl_gen_statem:close(TLSPid, {close, Timeout});
close(#sslsocket{pid = {dtls = ListenSocket, #config{transport_info={Transport,_,_,_,_}}}}, _) ->
dtls_socket:close(Transport, ListenSocket);
Expand Down Expand Up @@ -2287,9 +2296,7 @@ value is `infinity`.
HttpPacket :: any().

recv(#sslsocket{pid = [Pid|_]}, Length, Timeout)
when is_pid(Pid) andalso
(is_integer(Length) andalso Length >= 0) andalso
((is_integer(Timeout) andalso Timeout >= 0) orelse Timeout == infinity) ->
when is_pid(Pid), (is_integer(Length) andalso Length >= 0), ?IS_TIMEOUT(Timeout) ->
ssl_gen_statem:recv(Pid, Length, Timeout);
recv(#sslsocket{pid = {dtls,_}}, _, _) ->
{error,enotconn};
Expand All @@ -2311,14 +2318,16 @@ the owner of an SSL socket, and receives all messages from the socket.
%% Description: Changes process that receives the messages when active = true
%% or once.
%%--------------------------------------------------------------------
controlling_process(#sslsocket{pid = [Pid|_]}, NewOwner) when is_pid(Pid), is_pid(NewOwner) ->
controlling_process(#sslsocket{pid = [Pid|_]}, NewOwner)
when is_pid(Pid), is_pid(NewOwner) ->
ssl_gen_statem:new_user(Pid, NewOwner);
controlling_process(#sslsocket{pid = {dtls, _}},
NewOwner) when is_pid(NewOwner) ->
controlling_process(#sslsocket{pid = {dtls, _}}, NewOwner)
when is_pid(NewOwner) ->
ok; %% Meaningless but let it be allowed to conform with TLS
controlling_process(#sslsocket{pid = {Listen,
#config{transport_info = {Transport,_,_,_,_}}}},
NewOwner) when is_pid(NewOwner) ->
NewOwner)
when is_pid(NewOwner) ->
%% Meaningless but let it be allowed to conform with normal sockets
Transport:controlling_process(Listen, NewOwner).

Expand Down Expand Up @@ -2376,7 +2385,8 @@ set to `true`.
%%
%% Description: Return SSL information for the connection
%%--------------------------------------------------------------------
connection_information(#sslsocket{pid = [Pid|_]}, Items) when is_pid(Pid) ->
connection_information(#sslsocket{pid = [Pid|_]}, Items)
when is_pid(Pid), is_list(Items) ->
case ssl_gen_statem:connection_information(Pid, include_security_info(Items)) of
{ok, Info} ->
{ok, [Item || Item = {Key, Value} <- Info, lists:member(Key, Items),
Expand Down Expand Up @@ -2751,7 +2761,9 @@ groups(default) ->
%%--------------------------------------------------------------------
getopts(#sslsocket{pid = [Pid|_]}, OptionTags) when is_pid(Pid), is_list(OptionTags) ->
ssl_gen_statem:get_opts(Pid, OptionTags);
getopts(#sslsocket{pid = {dtls, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket, OptionTags) when is_list(OptionTags) ->
getopts(#sslsocket{pid = {dtls, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket,
OptionTags)
when is_list(OptionTags) ->
try dtls_socket:getopts(Transport, ListenSocket, OptionTags) of
{ok, _} = Result ->
Result;
Expand Down Expand Up @@ -2813,7 +2825,9 @@ setopts(#sslsocket{pid = [Pid|_]}, Options0) when is_pid(Pid), is_list(Options0)
_:_ ->
{error, {options, {not_a_proplist, Options0}}}
end;
setopts(#sslsocket{pid = {dtls, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket, Options) when is_list(Options) ->
setopts(#sslsocket{pid = {dtls, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket,
Options)
when is_list(Options) ->
try dtls_socket:setopts(Transport, ListenSocket, Options) of
ok ->
ok;
Expand All @@ -2823,7 +2837,8 @@ setopts(#sslsocket{pid = {dtls, #config{transport_info = {Transport,_,_,_,_}}}}
_:Error ->
{error, {options, {socket_options, Options, Error}}}
end;
setopts(#sslsocket{pid = {_, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket, Options) when is_list(Options) ->
setopts(#sslsocket{pid = {_, #config{transport_info = {Transport,_,_,_,_}}}} = ListenSocket, Options)
when is_list(Options) ->
try tls_socket:setopts(Transport, ListenSocket, Options) of
ok ->
ok;
Expand Down

0 comments on commit 110b699

Please sign in to comment.