diff --git a/mirage-crypto-ec.opam b/mirage-crypto-ec.opam index 4fab25a5..fb7d8b51 100644 --- a/mirage-crypto-ec.opam +++ b/mirage-crypto-ec.opam @@ -36,6 +36,7 @@ depends: [ "ppx_deriving_yojson" {with-test} "ppx_deriving" {with-test} "yojson" {with-test & >= "1.6.0"} + "asn1-combinators" {with-test & >= "0.3.1"} ] conflicts: [ "ocaml-freestanding" diff --git a/tests/dune b/tests/dune index 05ae7adb..55bfc3e8 100644 --- a/tests/dune +++ b/tests/dune @@ -51,13 +51,14 @@ (test (name test_ec_wycheproof) + (modes native) (modules test_ec_wycheproof) (deps ecdh_secp256r1_test.json ecdsa_secp256r1_sha256_test.json ecdsa_secp256r1_sha512_test.json ecdh_secp384r1_test.json ecdsa_secp384r1_sha384_test.json ecdsa_secp384r1_sha512_test.json ecdh_secp521r1_test.json ecdsa_secp521r1_sha512_test.json x25519_test.json eddsa_test.json) - (libraries alcotest mirage-crypto-ec wycheproof digestif) + (libraries alcotest mirage-crypto-ec wycheproof digestif asn1-combinators) (package mirage-crypto-ec)) (tests diff --git a/tests/test_ec_wycheproof.ml b/tests/test_ec_wycheproof.ml index 8d8010a1..f62b8976 100644 --- a/tests/test_ec_wycheproof.ml +++ b/tests/test_ec_wycheproof.ml @@ -15,138 +15,33 @@ let string_get_uint8 d off = let hex = Alcotest.testable Wycheproof.pp_hex Wycheproof.equal_hex module Asn = struct - (* This is a handcrafted asn1 parser, sufficient for the wycheproof tests. - The underlying reason is to avoid a dependency on asn1-grammars and - mirage-crypto-pk (which depends on gmp and zarith, and are cumbersome to - build on windows with CL.EXE). *) - - let guard p e = if p then Ok () else Error e - - let decode_len start_off buf = - let len = string_get_uint8 buf start_off in - if len >= 0x80 then - let bytes = len - 0x80 in - let rec g acc off = - if off = bytes then - Ok (acc, bytes + start_off + 1) - else - let this = string_get_uint8 buf (start_off + 1 + off) in - let* () = guard (off = 0 && this >= 0x80) "badly encoded length" in - let acc' = acc lsl 8 + this in - let* () = guard (acc <= acc') "decode_len overflow in acc" in - g (acc lsl 8 + string_get_uint8 buf (start_off + 1 + off)) (succ off) - in - g 0 0 - else - Ok (len, start_off + 1) - - let decode_seq data = - let* () = guard (String.length data > 2) "decode_seq: data too short" in - let tag = string_get_uint8 data 0 in - let* () = guard (tag = 0x30) "decode_seq: bad tag (should be 0x30)" in - let* len, off = decode_len 1 data in - let* () = guard (String.length data - off >= len) "decode_seq: too short" in - Ok (String.sub data off len, - if String.length data - off > len then - Some (String.sub data (off + len) (String.length data - len - off)) - else - None) - - let decode_2_oid data = - let decode_one off = - let tag = string_get_uint8 data off in - let* () = guard (tag = 0x06) "decode_oid: bad tag (should be 0x06)" in - let len = string_get_uint8 data (off + 1) in - let* () = guard (String.length data - 2 - off >= len) "decode_oid: data too short" in - Ok (String.sub data (off + 2) len, off + 2 + len) - in - let* first, off = decode_one 0 in - let* second, off = decode_one off in - let* () = guard (off = String.length data) "decode_oid: leftover data" in - Ok (first, second) - - let decode_bit_string data = - let tag = string_get_uint8 data 0 in - let* () = guard (tag = 0x03) "decode_bit_string: bad tag (expected 0x03)" in - let* len, off = decode_len 1 data in - let* () = guard (String.length data - off = len) "decode_bit_string: leftover or too short data" in - let unused = string_get_uint8 data off in - let* () = guard (unused = 0) "unused is not 0" in - Ok (String.sub data (off + 1) (len - 1)) - - let decode_int_pair data = - let decode_int off = - let* () = guard (String.length data - off > 2) "decode_int: data too short" in - let tag = string_get_uint8 data off in - let* () = guard (tag = 0x02) "decode_int: bad tag (should be 0x02)" in - let len = string_get_uint8 data (off + 1) in - let* () = guard (String.length data - off - 2 >= len) "decode_int: too short" in - let fix_one = if string_get_uint8 data (off + 2) = 0x00 then 1 else 0 in - let* () = guard (string_get_uint8 data (off + 2) land 0x80 = 0) "decode_int: negative number" in - let* () = - if String.length data > off + 3 && fix_one = 1 then - guard (string_get_uint8 data (off + 3) <> 0x00) "decode_int: leading extra 0 byte" - else - Ok () - in - Ok (String.sub data (fix_one + off + 2) (len - fix_one), off + len + 2) - in - let* first, off = decode_int 0 in - let* second, off = decode_int off in - let* () = guard (off = String.length data) "decode_int: leftover data" in - Ok (first, second) - - let encode_oid = function - | first :: second :: rt -> - let oct1 = 40 * first + second in - let octs = concat_map (fun x -> - let fst = x / 16384 - and snd = x / 128 - and thr = x mod 128 - in - assert (fst < 128); - (if fst > 0 then [ 128 (* set high bit *) + fst ] else []) @ - (if snd > 0 then [ 128 + snd ] else []) @ - [ thr ]) - rt - in - String.init (1 + List.length octs) (function - | 0 -> char_of_int oct1 - | n -> char_of_int (List.nth octs (pred n))) - | _ -> assert false - let parse_point curve s = - let ec_public_key = encode_oid [ 1 ; 2 ; 840; 10045; 2; 1 ] in - let prime_oid = encode_oid (match curve with - | "secp256r1" -> [ 1 ; 2 ; 840; 10045; 3; 1; 7 ] - | "secp384r1" -> [ 1 ; 3 ; 132; 0; 34 ] - | "secp521r1" -> [ 1 ; 3 ; 132; 0; 35 ] - | _ -> assert false) + let seq2 a b = Asn.S.(sequence2 (required a) (required b)) in + let term = Asn.S.(seq2 (seq2 oid oid) bit_string_octets) in + let ec_public_key = Asn.OID.(base 1 2 <|| [ 840; 10045; 2; 1 ]) in + let prime_oid = match curve with + | "secp256r1" -> Asn.OID.(base 1 2 <|| [ 840; 10045; 3; 1; 7 ]) + | "secp384r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 34 ]) + | "secp521r1" -> Asn.OID.(base 1 3 <|| [ 132; 0; 35 ]) + | _ -> assert false in - let* r = decode_seq s in - match r with - | _data, Some _ -> Error "expected no leftover" - | data, None -> - let* r = decode_seq data in - match r with - | _oids, None -> Error "expected some data" - | oids, Some data -> - let* oid1, oid2 = decode_2_oid oids in - let* data = decode_bit_string data in - if not (String.equal oid1 ec_public_key) then - Error "ASN1: wrong oid 1" - else if not (String.equal oid2 prime_oid) then - Error "ASN1: wrong oid 2" - else - Ok data - - let parse_signature s = - let* r = decode_seq s in - match r with - | _data, Some _ -> Error "expected no leftover" - | data, None -> - let* r, s = decode_int_pair data in - Ok (r, s) + match Asn.decode (Asn.codec Asn.ber term) s with + | Error _ -> Error "ASN1 parse error" + | Ok (((oid1, oid2), data), rest) -> + if String.length rest <> 0 then Error "ASN1 leftover" + else if not (Asn.OID.equal oid1 ec_public_key) then + Error "ASN1: wrong oid 1" + else if not (Asn.OID.equal oid2 prime_oid) then Error "ASN1: wrong oid 2" + else Ok data + + let parse_signature cs = + let asn = Asn.S.(sequence2 (required unsigned_integer) (required unsigned_integer)) in + match Asn.(decode (codec der asn) cs) with + | Error _ -> Error "ASN1 parse error" + | Ok (r_s, rest) -> + if String.length rest <> 0 then Error "ASN1 leftover" + else + Ok r_s end let to_string_result ~pp_error = function