Skip to content

Commit

Permalink
mirage-crypto-pk: rsa avoid bytes.unsafe_of_string @reynir
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Mar 5, 2024
1 parent 95d73d5 commit d24659b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
17 changes: 9 additions & 8 deletions pk/rsa.ml
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,15 @@ module MGF1 (H : Hash.S) = struct
(* Assumes len < 2^32 * H.digest_size. *)
let mgf ~seed len =
let rec go acc c = function
| 0 -> String.sub (String.concat "" (List.rev acc)) 0 len
| n -> let h = Cstruct.to_string (H.digesti (iter2 (Cstruct.of_string seed) (repr c))) in
| 0 -> Bytes.sub (Bytes.concat Bytes.empty (List.rev acc)) 0 len
| n -> let h = Cstruct.to_bytes (H.digesti (iter2 (Cstruct.of_string seed) (repr c))) in
go (h :: acc) Int32.(succ c) (pred n) in
go [] 0l (len // H.digest_size)

let mask ~seed buf = xor (mgf ~seed (String.length buf)) buf
let mask ~seed buf =
let mgf_data = mgf ~seed (String.length buf) in
xor_into buf mgf_data (String.length buf);
mgf_data
end

module OAEP (H : Hash.S) = struct
Expand All @@ -339,16 +342,16 @@ module OAEP (H : Hash.S) = struct
let seed = Cstruct.to_string (Mirage_crypto_rng.generate ?g hlen)
and pad = String.make (max_msg_bytes k - String.length msg) '\x00' in
let db = String.concat "" [ Cstruct.to_string (H.digest (Cstruct.of_string label)) ; pad ; bx01 ; msg ] in
let mdb = MGF.mask ~seed db in
let mseed = MGF.mask ~seed:mdb seed in
let mdb = Bytes.unsafe_to_string (MGF.mask ~seed db) in
let mseed = Bytes.unsafe_to_string (MGF.mask ~seed:mdb seed) in
String.concat "" [ bx00 ; mseed ; mdb ]

let eme_oaep_decode ?(label = "") msg =
let b0 = String.sub msg 0 1
and ms = String.sub msg 1 hlen
and mdb = String.sub msg (1 + hlen) (String.length msg - 1 - hlen)
in
let db = MGF.mask ~seed:(MGF.mask ~seed:mdb ms) mdb in
let db = Bytes.unsafe_to_string (MGF.mask ~seed:(Bytes.unsafe_to_string (MGF.mask ~seed:mdb ms)) mdb) in
let i = ct_find_uint8 ~default:0 ~off:hlen ~f:((<>) 0x00) db in
let c1 = Eqaf.equal (String.sub db 0 hlen) (Cstruct.to_string H.(digest (Cstruct.of_string label)))
and c2 = string_get_uint8 b0 0 = 0x00
Expand Down Expand Up @@ -393,7 +396,6 @@ module PSS (H: Hash.S) = struct
let h = digest ~salt msg in
let db = String.concat "" [ String.make (n - slen - hlen - 2) '\x00' ; bx01 ; Cstruct.to_string salt ] in
let mdb = MGF.mask ~seed:(Cstruct.to_string h) db in
let mdb = Bytes.unsafe_of_string mdb in
Bytes.set_uint8 mdb 0 @@ Bytes.get_uint8 mdb 0 land b0mask emlen ;
String.concat "" [ Bytes.unsafe_to_string mdb ; Cstruct.to_string h ; bxbc ]

Expand All @@ -403,7 +405,6 @@ module PSS (H: Hash.S) = struct
and bxx = String.sub em (String.length em - 1) 1
in
let db = MGF.mask ~seed:h mdb in
let db = Bytes.unsafe_of_string db in
Bytes.set_uint8 db 0 (Bytes.get_uint8 db 0 land b0mask emlen) ;
let db = Bytes.unsafe_to_string db in
let salt = String.sub db (String.length db - slen) slen in
Expand Down
1 change: 1 addition & 0 deletions src/mirage_crypto.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ module Uncommon : sig
val iter3 : 'a -> 'a -> 'a -> ('a -> unit) -> unit

val xor : string -> string -> string
val xor_into : string -> bytes -> int -> unit

val invalid_arg : ('a, Format.formatter, unit, unit, unit, 'b) format6 -> 'a
val failwith : ('a, Format.formatter, unit, unit, unit, 'b) format6 -> 'a
Expand Down

0 comments on commit d24659b

Please sign in to comment.