diff --git a/src/cipher_block.ml b/src/cipher_block.ml index e196fc06..af958e28 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -152,6 +152,15 @@ module Counters = struct end end +let check_offset ~tag ~buf ~off ~len actual_len = + if off < 0 then + invalid_arg "%s: %s off %u < 0" + tag buf off; + if actual_len - off < len then + invalid_arg "%s: %s length %u - off %u < len %u" + tag buf actual_len off len +[@@inline] + module Modes = struct module ECB_of (Core : Block.Core) : Block.ECB = struct @@ -167,12 +176,8 @@ module Modes = struct let ecb xform key src src_off dst dst_off len = if len mod block_size <> 0 then invalid_arg "ECB: length %u not of block size" len; - if String.length src - src_off < len then - invalid_arg "ECB: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "ECB: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; + check_offset ~tag:"ECB" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"ECB" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_ecb xform key src src_off dst dst_off len let encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = @@ -209,15 +214,16 @@ module Modes = struct let of_secret = Core.of_secret - let bounds_check ?(off = 0) ~iv cs = + let block_size_check ?(off = 0) ~iv cs = if String.length iv <> block then invalid_arg "CBC: IV length %u not of block size" (String.length iv); if (String.length cs - off) mod block <> 0 then invalid_arg "CBC: argument length %u (off %u) not of block size" (String.length cs) off + [@@inline] let next_iv ?(off = 0) cs ~iv = - bounds_check ~iv cs ~off ; + block_size_check ~iv cs ~off ; if String.length cs > off then String.sub cs (String.length cs - block_size) block_size else iv @@ -237,13 +243,9 @@ module Modes = struct unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len let encrypt_into ~key ~iv src ~src_off dst ~dst_off len = - bounds_check ~off:src_off ~iv src; - if String.length src - src_off < len then - invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" - (Bytes.length dst) dst_off len; + block_size_check ~off:src_off ~iv src; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len let encrypt ~key ~iv src = @@ -260,13 +262,9 @@ module Modes = struct end let decrypt_into ~key ~iv src ~src_off dst ~dst_off len = - bounds_check ~off:src_off ~iv src; - if String.length src - src_off < len then - invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" - (Bytes.length dst) dst_off len; + block_size_check ~off:src_off ~iv src; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len let decrypt ~key ~iv src = @@ -302,9 +300,7 @@ module Modes = struct end let stream_into ~key ~ctr buf ~off len = - if Bytes.length buf - off < len then - invalid_arg "CTR: buffer length %u - off %u < len %u" - (Bytes.length buf) off len; + check_offset ~tag:"CTR" ~buf:"buf" ~off ~len (Bytes.length buf); unsafe_stream_into ~key ~ctr buf ~off len let stream ~key ~ctr n = @@ -317,12 +313,8 @@ module Modes = struct Uncommon.unsafe_xor_into src ~src_off dst ~dst_off len let encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "CTR: src length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CTR: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; + check_offset ~tag:"CTR" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CTR" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len let encrypt ~key ~ctr src = @@ -416,15 +408,9 @@ module Modes = struct unsafe_tag_into ~key ~hkey ~ctr ?adata (Bytes.unsafe_to_string dst) ~off:dst_off ~len dst ~tag_off let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - if String.length src - src_off < len then - invalid_arg "GCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "GCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if Bytes.length dst - tag_off < tag_size then - invalid_arg "GCM: dst length %u - tag_off %u < tag_size %u" - (Bytes.length dst) tag_off tag_size; + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"GCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = @@ -446,15 +432,9 @@ module Modes = struct Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "GCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "GCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if String.length src - tag_off < tag_size then - invalid_arg "GCM: src length %u - tag_off %u < tag_size %u" - (String.length src) tag_off tag_size; + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata cdata = @@ -498,15 +478,9 @@ module Modes = struct invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - if String.length src - src_off < len then - invalid_arg "CCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if Bytes.length dst - tag_off < tag_size then - invalid_arg "CCM: dst length %u - tag_off %u < tag_size %u" - (Bytes.length dst) tag_off tag_size; + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"CCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); valid_nonce nonce; unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len @@ -525,15 +499,9 @@ module Modes = struct Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen:tag_size ~adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "CCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if String.length src - tag_off < tag_size then - invalid_arg "CCM: src length %u - tag_off %u < tag_size %u" - (String.length src) tag_off tag_size; + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); valid_nonce nonce; unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len