From e24d02fbba5fbe0a5680f840266f7a35551b9242 Mon Sep 17 00:00:00 2001 From: kares Date: Thu, 4 Apr 2024 15:44:44 +0200 Subject: [PATCH] [fix] revert readPrivateKey so public key is not lost (#292) --- .../java/org/jruby/ext/openssl/impl/PKey.java | 110 +++++++++--------- .../ext/openssl/x509store/PEMInputOutput.java | 24 ++-- 2 files changed, 69 insertions(+), 65 deletions(-) diff --git a/src/main/java/org/jruby/ext/openssl/impl/PKey.java b/src/main/java/org/jruby/ext/openssl/impl/PKey.java index 23a4cf2e..ad623610 100644 --- a/src/main/java/org/jruby/ext/openssl/impl/PKey.java +++ b/src/main/java/org/jruby/ext/openssl/impl/PKey.java @@ -39,13 +39,10 @@ import java.security.interfaces.DSAPrivateKey; import java.security.interfaces.DSAPublicKey; import java.security.interfaces.ECPrivateKey; -import java.security.interfaces.ECPublicKey; import java.security.interfaces.RSAPrivateCrtKey; import java.security.interfaces.RSAPublicKey; import java.security.spec.DSAPrivateKeySpec; import java.security.spec.DSAPublicKeySpec; -import java.security.spec.ECParameterSpec; -import java.security.spec.ECPrivateKeySpec; import java.security.spec.InvalidKeySpecException; import java.security.spec.KeySpec; import java.security.spec.PKCS8EncodedKeySpec; @@ -67,14 +64,10 @@ import org.bouncycastle.asn1.DLSequence; import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; -import org.bouncycastle.asn1.sec.ECPrivateKeyStructure; import org.bouncycastle.asn1.x509.AlgorithmIdentifier; import org.bouncycastle.asn1.x509.DSAParameter; import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; import org.bouncycastle.asn1.x9.X9ObjectIdentifiers; -import org.bouncycastle.jce.ECNamedCurveTable; -import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec; -import org.bouncycastle.jce.spec.ECPublicKeySpec; import org.bouncycastle.jcajce.provider.asymmetric.util.KeyUtil; import org.jruby.ext.openssl.SecurityHelper; @@ -87,44 +80,54 @@ */ public class PKey { - public static KeyPair readPrivateKey(final byte[] input, final String type) + public enum Type { RSA, DSA, EC; } + + public static KeyPair readPrivateKey(final Type type, final byte[] input) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException { - return readPrivateKey((ASN1Sequence) new ASN1InputStream(input).readObject(), type); + return readPrivateKey(type, mockPrivateKeyInfo(type, input)); + } + + private static PrivateKeyInfo mockPrivateKeyInfo(final Type type, final byte[] input) throws IOException { + return new PrivateKeyInfo(null, new ASN1InputStream(input).readObject()); } - public static KeyPair readPrivateKey(final ASN1Sequence seq, final String type) + public static KeyPair readPrivateKey(final Type type, final PrivateKeyInfo keyInfo) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException { - KeySpec pubSpec; KeySpec privSpec; - if ( type.equals("RSA") ) { - ASN1Integer mod = (ASN1Integer) seq.getObjectAt(1); - ASN1Integer pubExp = (ASN1Integer) seq.getObjectAt(2); - ASN1Integer privExp = (ASN1Integer) seq.getObjectAt(3); - ASN1Integer p1 = (ASN1Integer) seq.getObjectAt(4); - ASN1Integer p2 = (ASN1Integer) seq.getObjectAt(5); - ASN1Integer exp1 = (ASN1Integer) seq.getObjectAt(6); - ASN1Integer exp2 = (ASN1Integer) seq.getObjectAt(7); - ASN1Integer crtCoef = (ASN1Integer) seq.getObjectAt(8); - pubSpec = new RSAPublicKeySpec(mod.getValue(), pubExp.getValue()); - privSpec = new RSAPrivateCrtKeySpec(mod.getValue(), pubExp.getValue(), privExp.getValue(), p1.getValue(), p2.getValue(), exp1.getValue(), - exp2.getValue(), crtCoef.getValue()); - } - else if ( type.equals("DSA") ) { - ASN1Integer p = (ASN1Integer) seq.getObjectAt(1); - ASN1Integer q = (ASN1Integer) seq.getObjectAt(2); - ASN1Integer g = (ASN1Integer) seq.getObjectAt(3); - ASN1Integer y = (ASN1Integer) seq.getObjectAt(4); - ASN1Integer x = (ASN1Integer) seq.getObjectAt(5); - privSpec = new DSAPrivateKeySpec(x.getValue(), p.getValue(), q.getValue(), g.getValue()); - pubSpec = new DSAPublicKeySpec(y.getValue(), p.getValue(), q.getValue(), g.getValue()); + KeySpec pubSpec; KeySpec privSpec; ASN1Sequence seq; + switch (type) { + case RSA: + seq = (ASN1Sequence) keyInfo.parsePrivateKey(); + ASN1Integer mod = (ASN1Integer) seq.getObjectAt(1); + ASN1Integer pubExp = (ASN1Integer) seq.getObjectAt(2); + ASN1Integer privExp = (ASN1Integer) seq.getObjectAt(3); + ASN1Integer p1 = (ASN1Integer) seq.getObjectAt(4); + ASN1Integer p2 = (ASN1Integer) seq.getObjectAt(5); + ASN1Integer exp1 = (ASN1Integer) seq.getObjectAt(6); + ASN1Integer exp2 = (ASN1Integer) seq.getObjectAt(7); + ASN1Integer crtCoef = (ASN1Integer) seq.getObjectAt(8); + pubSpec = new RSAPublicKeySpec(mod.getValue(), pubExp.getValue()); + privSpec = new RSAPrivateCrtKeySpec( + mod.getValue(), pubExp.getValue(), privExp.getValue(), + p1.getValue(), p2.getValue(), + exp1.getValue(), exp2.getValue(), crtCoef.getValue()); + break; + case DSA: + seq = (ASN1Sequence) keyInfo.parsePrivateKey(); + ASN1Integer p = (ASN1Integer) seq.getObjectAt(1); + ASN1Integer q = (ASN1Integer) seq.getObjectAt(2); + ASN1Integer g = (ASN1Integer) seq.getObjectAt(3); + ASN1Integer y = (ASN1Integer) seq.getObjectAt(4); + ASN1Integer x = (ASN1Integer) seq.getObjectAt(5); + privSpec = new DSAPrivateKeySpec(x.getValue(), p.getValue(), q.getValue(), g.getValue()); + pubSpec = new DSAPublicKeySpec(y.getValue(), p.getValue(), q.getValue(), g.getValue()); + break; + case EC: + return readECPrivateKey(SecurityHelper.getKeyFactory("EC"), keyInfo); + default: + throw new AssertionError("unexpected key type: " + type); } - else if ( type.equals("EC") ) { - return readECPrivateKey(SecurityHelper.getKeyFactory("EC"), seq); - } - else { - throw new IllegalStateException("unsupported type: " + type); - } - KeyFactory fact = SecurityHelper.getKeyFactory(type); - return new KeyPair(fact.generatePublic(pubSpec), fact.generatePrivate(privSpec)); + final KeyFactory keyFactory = SecurityHelper.getKeyFactory(type.name()); + return new KeyPair(keyFactory.generatePublic(pubSpec), keyFactory.generatePrivate(privSpec)); } // d2i_PUBKEY_bio @@ -264,21 +267,23 @@ public static KeyPair readECPrivateKey(final byte[] input) public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final byte[] input) throws IOException, InvalidKeySpecException { - return readECPrivateKey(keyFactory, (ASN1Sequence) ASN1Primitive.fromByteArray(input)); + return readECPrivateKey(keyFactory, mockPrivateKeyInfo(Type.EC, input)); } - public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final ASN1Sequence input) + public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final PrivateKeyInfo keyInfo) throws IOException, InvalidKeySpecException { try { - org.bouncycastle.asn1.sec.ECPrivateKey pKey = org.bouncycastle.asn1.sec.ECPrivateKey.getInstance(input); - AlgorithmIdentifier algId = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey, pKey.getParametersObject().toASN1Primitive()); - PrivateKeyInfo privInfo = new PrivateKeyInfo(algId, pKey.toASN1Primitive()); - SubjectPublicKeyInfo pubInfo = new SubjectPublicKeyInfo(algId, pKey.getPublicKey().getBytes()); - PKCS8EncodedKeySpec privSpec = new PKCS8EncodedKeySpec(privInfo.getEncoded()); - X509EncodedKeySpec pubSpec = new X509EncodedKeySpec(pubInfo.getEncoded()); + org.bouncycastle.asn1.sec.ECPrivateKey key = org.bouncycastle.asn1.sec.ECPrivateKey.getInstance(keyInfo.parsePrivateKey()); + AlgorithmIdentifier algId = keyInfo.getPrivateKeyAlgorithm(); + // NOTE: should only happen when using mockPrivateKeyInfo(Type, byte[]) + if (algId == null) algId = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey); + + SubjectPublicKeyInfo pubInfo = new SubjectPublicKeyInfo(algId, key.getPublicKey().getBytes()); + PKCS8EncodedKeySpec privSpec = new PKCS8EncodedKeySpec(keyInfo.getEncoded()); + X509EncodedKeySpec pubSpec = new X509EncodedKeySpec(pubInfo.getEncoded()); ECPrivateKey privateKey = (ECPrivateKey) keyFactory.generatePrivate(privSpec); - if ( algId.getParameters() instanceof ASN1ObjectIdentifier ) { + if (algId.getParameters() instanceof ASN1ObjectIdentifier) { privateKey = ECPrivateKeyWithName.wrap(privateKey, (ASN1ObjectIdentifier) algId.getParameters()); } return new KeyPair(keyFactory.generatePublic(pubSpec), privateKey); @@ -311,10 +316,11 @@ public static byte[] toDerRSAKey(RSAPublicKey pubKey, RSAPrivateCrtKey privKey) return new DERSequence(vec).toASN1Primitive().getEncoded(ASN1Encoding.DER); } - public static ASN1Sequence toASN1Primitive(RSAPublicKey pubKey) { + public static ASN1Sequence toASN1Primitive(final RSAPublicKey publicKey) { + assert publicKey != null : "null public key"; ASN1EncodableVector vec = new ASN1EncodableVector(); - vec.add(new ASN1Integer(pubKey.getModulus())); - vec.add(new ASN1Integer(pubKey.getPublicExponent())); + vec.add(new ASN1Integer(publicKey.getModulus())); + vec.add(new ASN1Integer(publicKey.getPublicExponent())); return new DERSequence(vec); } diff --git a/src/main/java/org/jruby/ext/openssl/x509store/PEMInputOutput.java b/src/main/java/org/jruby/ext/openssl/x509store/PEMInputOutput.java index 99759f7a..b515a9d8 100644 --- a/src/main/java/org/jruby/ext/openssl/x509store/PEMInputOutput.java +++ b/src/main/java/org/jruby/ext/openssl/x509store/PEMInputOutput.java @@ -81,12 +81,10 @@ import javax.crypto.spec.PBEParameterSpec; import javax.crypto.spec.SecretKeySpec; -import org.bouncycastle.asn1.ASN1Encoding; import org.bouncycastle.asn1.ASN1TaggedObject; import org.bouncycastle.asn1.pkcs.PKCS12PBEParams; import org.bouncycastle.asn1.pkcs.EncryptedPrivateKeyInfo; import org.bouncycastle.asn1.ASN1Encodable; -import org.bouncycastle.asn1.ASN1Object; import org.bouncycastle.asn1.ASN1InputStream; import org.bouncycastle.asn1.ASN1OctetString; import org.bouncycastle.asn1.ASN1OutputStream; @@ -129,6 +127,7 @@ import org.jruby.ext.openssl.Cipher.Algorithm; import org.jruby.ext.openssl.impl.ASN1Registry; import org.jruby.ext.openssl.impl.CipherSpec; +import org.jruby.ext.openssl.impl.PKey.Type; import org.jruby.ext.openssl.impl.PKCS10Request; import org.jruby.ext.openssl.SecurityHelper; import org.jruby.ext.openssl.util.ByteArrayOutputStream; @@ -349,10 +348,9 @@ else if ( line.indexOf(BEG_STRING_ECPRIVATEKEY) != -1) { else if ( line.indexOf(BEG_STRING_PKCS8INF) != -1) { try { byte[] bytes = readBase64Bytes(reader, BEF_E + PEM_STRING_PKCS8INF); - PrivateKeyInfo pInfo = PrivateKeyInfo.getInstance(bytes); - KeyFactory keyFactory = getKeyFactory( pInfo.getPrivateKeyAlgorithm() ); - PrivateKey pKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(pInfo.getEncoded())); - return new KeyPair(null, pKey); + final PrivateKeyInfo keyInfo = PrivateKeyInfo.getInstance(bytes); + final Type type = getPrivateKeyType(keyInfo.getPrivateKeyAlgorithm()); + return org.jruby.ext.openssl.impl.PKey.readPrivateKey(type, keyInfo); } catch (Exception e) { throw mapReadException("problem creating private key: ", e); @@ -1270,7 +1268,7 @@ else if ( line.contains(endMarker) ) { } else { keyBytes = decoded; } - return org.jruby.ext.openssl.impl.PKey.readPrivateKey(keyBytes, type); + return org.jruby.ext.openssl.impl.PKey.readPrivateKey(Type.valueOf(type), keyBytes); } private static byte[] decrypt(byte[] decoded, String dekInfo, char[] passwd) @@ -1486,23 +1484,23 @@ private static CMSSignedData readPKCS7(BufferedReader in, char[] p, String endMa public static KeyFactory getKeyFactory(final AlgorithmIdentifier algId) throws NoSuchAlgorithmException { - return SecurityHelper.getKeyFactory(getPrivateKeyType(algId)); + return SecurityHelper.getKeyFactory(getPrivateKeyType(algId).name()); } - private static String getPrivateKeyType(final AlgorithmIdentifier algId) { + private static Type getPrivateKeyType(final AlgorithmIdentifier algId) { final ASN1ObjectIdentifier algIdentifier = algId.getAlgorithm(); if (X9ObjectIdentifiers.id_ecPublicKey.equals(algIdentifier)) { - return "EC"; + return Type.EC; } if (PKCSObjectIdentifiers.rsaEncryption.equals(algIdentifier)) { - return "RSA"; + return Type.RSA; } if (X9ObjectIdentifiers.id_dsa.equals(algIdentifier)) { - return "DSA"; + return Type.DSA; } - return algIdentifier.getId(); + return Type.valueOf(algIdentifier.getId()); } private static CertificateFactory getX509CertificateFactory() {