Skip to content

Commit

Permalink
revise bounds checks (cc @reynir @palainp), also check off >= 0
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Jun 18, 2024
1 parent 3399544 commit 08a8b16
Showing 1 changed file with 35 additions and 67 deletions.
102 changes: 35 additions & 67 deletions src/cipher_block.ml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ module Counters = struct
end
end

let check_offset ~tag ~buf ~off ~len actual_len =
if off < 0 then
invalid_arg "%s: %s off %u < 0"
tag buf off;
if actual_len - off < len then
invalid_arg "%s: %s length %u - off %u < len %u"
tag buf actual_len off len
[@@inline]

module Modes = struct
module ECB_of (Core : Block.Core) : Block.ECB = struct

Expand All @@ -167,12 +176,8 @@ module Modes = struct
let ecb xform key src src_off dst dst_off len =
if len mod block_size <> 0 then
invalid_arg "ECB: length %u not of block size" len;
if String.length src - src_off < len then
invalid_arg "ECB: source length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "ECB: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
check_offset ~tag:"ECB" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"ECB" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
unsafe_ecb xform key src src_off dst dst_off len

let encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len =
Expand Down Expand Up @@ -209,15 +214,16 @@ module Modes = struct

let of_secret = Core.of_secret

let bounds_check ?(off = 0) ~iv cs =
let block_size_check ?(off = 0) ~iv cs =
if String.length iv <> block then
invalid_arg "CBC: IV length %u not of block size" (String.length iv);
if (String.length cs - off) mod block <> 0 then
invalid_arg "CBC: argument length %u (off %u) not of block size"
(String.length cs) off
[@@inline]

let next_iv ?(off = 0) cs ~iv =
bounds_check ~iv cs ~off ;
block_size_check ~iv cs ~off ;
if String.length cs > off then
String.sub cs (String.length cs - block_size) block_size
else iv
Expand All @@ -237,13 +243,9 @@ module Modes = struct
unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len

let encrypt_into ~key ~iv src ~src_off dst ~dst_off len =
bounds_check ~off:src_off ~iv src;
if String.length src - src_off < len then
invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)"
(Bytes.length dst) dst_off len;
block_size_check ~off:src_off ~iv src;
check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len

let encrypt ~key ~iv src =
Expand All @@ -260,13 +262,9 @@ module Modes = struct
end

let decrypt_into ~key ~iv src ~src_off dst ~dst_off len =
bounds_check ~off:src_off ~iv src;
if String.length src - src_off < len then
invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)"
(Bytes.length dst) dst_off len;
block_size_check ~off:src_off ~iv src;
check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len

let decrypt ~key ~iv src =
Expand Down Expand Up @@ -302,9 +300,7 @@ module Modes = struct
end

let stream_into ~key ~ctr buf ~off len =
if Bytes.length buf - off < len then
invalid_arg "CTR: buffer length %u - off %u < len %u"
(Bytes.length buf) off len;
check_offset ~tag:"CTR" ~buf:"buf" ~off ~len (Bytes.length buf);
unsafe_stream_into ~key ~ctr buf ~off len

let stream ~key ~ctr n =
Expand All @@ -317,12 +313,8 @@ module Modes = struct
Uncommon.unsafe_xor_into src ~src_off dst ~dst_off len

let encrypt_into ~key ~ctr src ~src_off dst ~dst_off len =
if String.length src - src_off < len then
invalid_arg "CTR: src length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "CTR: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
check_offset ~tag:"CTR" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"CTR" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len

let encrypt ~key ~ctr src =
Expand Down Expand Up @@ -416,15 +408,9 @@ module Modes = struct
unsafe_tag_into ~key ~hkey ~ctr ?adata (Bytes.unsafe_to_string dst) ~off:dst_off ~len dst ~tag_off

let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len =
if String.length src - src_off < len then
invalid_arg "GCM: source length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "GCM: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
if Bytes.length dst - tag_off < tag_size then
invalid_arg "GCM: dst length %u - tag_off %u < tag_size %u"
(Bytes.length dst) tag_off tag_size;
check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
check_offset ~tag:"GCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst);
unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len

let authenticate_encrypt ~key ~nonce ?adata data =
Expand All @@ -446,15 +432,9 @@ module Modes = struct
Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag)

let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len =
if String.length src - src_off < len then
invalid_arg "GCM: source length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "GCM: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
if String.length src - tag_off < tag_size then
invalid_arg "GCM: src length %u - tag_off %u < tag_size %u"
(String.length src) tag_off tag_size;
check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"GCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src);
check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len

let authenticate_decrypt ~key ~nonce ?adata cdata =
Expand Down Expand Up @@ -498,15 +478,9 @@ module Modes = struct
invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize

let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len =
if String.length src - src_off < len then
invalid_arg "CCM: source length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "CCM: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
if Bytes.length dst - tag_off < tag_size then
invalid_arg "CCM: dst length %u - tag_off %u < tag_size %u"
(Bytes.length dst) tag_off tag_size;
check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
check_offset ~tag:"CCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst);
valid_nonce nonce;
unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len

Expand All @@ -525,15 +499,9 @@ module Modes = struct
Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen:tag_size ~adata src ~src_off ~tag_off dst ~dst_off len

let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len =
if String.length src - src_off < len then
invalid_arg "CCM: source length %u - src_off %u < len %u"
(String.length src) src_off len;
if Bytes.length dst - dst_off < len then
invalid_arg "CCM: dst length %u - dst_off %u < len %u"
(Bytes.length dst) dst_off len;
if String.length src - tag_off < tag_size then
invalid_arg "CCM: src length %u - tag_off %u < tag_size %u"
(String.length src) tag_off tag_size;
check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src);
check_offset ~tag:"CCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src);
check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst);
valid_nonce nonce;
unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len

Expand Down

0 comments on commit 08a8b16

Please sign in to comment.