Skip to content

Commit

Permalink
[fix] revert readPrivateKey so public key is not lost (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
kares committed Apr 8, 2024
1 parent d0f11af commit e24d02f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 65 deletions.
110 changes: 58 additions & 52 deletions src/main/java/org/jruby/ext/openssl/impl/PKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,10 @@
import java.security.interfaces.DSAPrivateKey;
import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.DSAPrivateKeySpec;
import java.security.spec.DSAPublicKeySpec;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.PKCS8EncodedKeySpec;
Expand All @@ -67,14 +64,10 @@
import org.bouncycastle.asn1.DLSequence;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.sec.ECPrivateKeyStructure;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.asn1.x509.DSAParameter;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.asn1.x9.X9ObjectIdentifiers;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
import org.bouncycastle.jce.spec.ECPublicKeySpec;
import org.bouncycastle.jcajce.provider.asymmetric.util.KeyUtil;

import org.jruby.ext.openssl.SecurityHelper;
Expand All @@ -87,44 +80,54 @@
*/
public class PKey {

public static KeyPair readPrivateKey(final byte[] input, final String type)
public enum Type { RSA, DSA, EC; }

public static KeyPair readPrivateKey(final Type type, final byte[] input)
throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
return readPrivateKey((ASN1Sequence) new ASN1InputStream(input).readObject(), type);
return readPrivateKey(type, mockPrivateKeyInfo(type, input));
}

private static PrivateKeyInfo mockPrivateKeyInfo(final Type type, final byte[] input) throws IOException {
return new PrivateKeyInfo(null, new ASN1InputStream(input).readObject());
}

public static KeyPair readPrivateKey(final ASN1Sequence seq, final String type)
public static KeyPair readPrivateKey(final Type type, final PrivateKeyInfo keyInfo)
throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
KeySpec pubSpec; KeySpec privSpec;
if ( type.equals("RSA") ) {
ASN1Integer mod = (ASN1Integer) seq.getObjectAt(1);
ASN1Integer pubExp = (ASN1Integer) seq.getObjectAt(2);
ASN1Integer privExp = (ASN1Integer) seq.getObjectAt(3);
ASN1Integer p1 = (ASN1Integer) seq.getObjectAt(4);
ASN1Integer p2 = (ASN1Integer) seq.getObjectAt(5);
ASN1Integer exp1 = (ASN1Integer) seq.getObjectAt(6);
ASN1Integer exp2 = (ASN1Integer) seq.getObjectAt(7);
ASN1Integer crtCoef = (ASN1Integer) seq.getObjectAt(8);
pubSpec = new RSAPublicKeySpec(mod.getValue(), pubExp.getValue());
privSpec = new RSAPrivateCrtKeySpec(mod.getValue(), pubExp.getValue(), privExp.getValue(), p1.getValue(), p2.getValue(), exp1.getValue(),
exp2.getValue(), crtCoef.getValue());
}
else if ( type.equals("DSA") ) {
ASN1Integer p = (ASN1Integer) seq.getObjectAt(1);
ASN1Integer q = (ASN1Integer) seq.getObjectAt(2);
ASN1Integer g = (ASN1Integer) seq.getObjectAt(3);
ASN1Integer y = (ASN1Integer) seq.getObjectAt(4);
ASN1Integer x = (ASN1Integer) seq.getObjectAt(5);
privSpec = new DSAPrivateKeySpec(x.getValue(), p.getValue(), q.getValue(), g.getValue());
pubSpec = new DSAPublicKeySpec(y.getValue(), p.getValue(), q.getValue(), g.getValue());
KeySpec pubSpec; KeySpec privSpec; ASN1Sequence seq;
switch (type) {
case RSA:
seq = (ASN1Sequence) keyInfo.parsePrivateKey();
ASN1Integer mod = (ASN1Integer) seq.getObjectAt(1);
ASN1Integer pubExp = (ASN1Integer) seq.getObjectAt(2);
ASN1Integer privExp = (ASN1Integer) seq.getObjectAt(3);
ASN1Integer p1 = (ASN1Integer) seq.getObjectAt(4);
ASN1Integer p2 = (ASN1Integer) seq.getObjectAt(5);
ASN1Integer exp1 = (ASN1Integer) seq.getObjectAt(6);
ASN1Integer exp2 = (ASN1Integer) seq.getObjectAt(7);
ASN1Integer crtCoef = (ASN1Integer) seq.getObjectAt(8);
pubSpec = new RSAPublicKeySpec(mod.getValue(), pubExp.getValue());
privSpec = new RSAPrivateCrtKeySpec(
mod.getValue(), pubExp.getValue(), privExp.getValue(),
p1.getValue(), p2.getValue(),
exp1.getValue(), exp2.getValue(), crtCoef.getValue());
break;
case DSA:
seq = (ASN1Sequence) keyInfo.parsePrivateKey();
ASN1Integer p = (ASN1Integer) seq.getObjectAt(1);
ASN1Integer q = (ASN1Integer) seq.getObjectAt(2);
ASN1Integer g = (ASN1Integer) seq.getObjectAt(3);
ASN1Integer y = (ASN1Integer) seq.getObjectAt(4);
ASN1Integer x = (ASN1Integer) seq.getObjectAt(5);
privSpec = new DSAPrivateKeySpec(x.getValue(), p.getValue(), q.getValue(), g.getValue());
pubSpec = new DSAPublicKeySpec(y.getValue(), p.getValue(), q.getValue(), g.getValue());
break;
case EC:
return readECPrivateKey(SecurityHelper.getKeyFactory("EC"), keyInfo);
default:
throw new AssertionError("unexpected key type: " + type);
}
else if ( type.equals("EC") ) {
return readECPrivateKey(SecurityHelper.getKeyFactory("EC"), seq);
}
else {
throw new IllegalStateException("unsupported type: " + type);
}
KeyFactory fact = SecurityHelper.getKeyFactory(type);
return new KeyPair(fact.generatePublic(pubSpec), fact.generatePrivate(privSpec));
final KeyFactory keyFactory = SecurityHelper.getKeyFactory(type.name());
return new KeyPair(keyFactory.generatePublic(pubSpec), keyFactory.generatePrivate(privSpec));
}

// d2i_PUBKEY_bio
Expand Down Expand Up @@ -264,21 +267,23 @@ public static KeyPair readECPrivateKey(final byte[] input)

public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final byte[] input)
throws IOException, InvalidKeySpecException {
return readECPrivateKey(keyFactory, (ASN1Sequence) ASN1Primitive.fromByteArray(input));
return readECPrivateKey(keyFactory, mockPrivateKeyInfo(Type.EC, input));
}

public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final ASN1Sequence input)
public static KeyPair readECPrivateKey(final KeyFactory keyFactory, final PrivateKeyInfo keyInfo)
throws IOException, InvalidKeySpecException {
try {
org.bouncycastle.asn1.sec.ECPrivateKey pKey = org.bouncycastle.asn1.sec.ECPrivateKey.getInstance(input);
AlgorithmIdentifier algId = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey, pKey.getParametersObject().toASN1Primitive());
PrivateKeyInfo privInfo = new PrivateKeyInfo(algId, pKey.toASN1Primitive());
SubjectPublicKeyInfo pubInfo = new SubjectPublicKeyInfo(algId, pKey.getPublicKey().getBytes());
PKCS8EncodedKeySpec privSpec = new PKCS8EncodedKeySpec(privInfo.getEncoded());
X509EncodedKeySpec pubSpec = new X509EncodedKeySpec(pubInfo.getEncoded());
org.bouncycastle.asn1.sec.ECPrivateKey key = org.bouncycastle.asn1.sec.ECPrivateKey.getInstance(keyInfo.parsePrivateKey());
AlgorithmIdentifier algId = keyInfo.getPrivateKeyAlgorithm();
// NOTE: should only happen when using mockPrivateKeyInfo(Type, byte[])
if (algId == null) algId = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey);

SubjectPublicKeyInfo pubInfo = new SubjectPublicKeyInfo(algId, key.getPublicKey().getBytes());
PKCS8EncodedKeySpec privSpec = new PKCS8EncodedKeySpec(keyInfo.getEncoded());
X509EncodedKeySpec pubSpec = new X509EncodedKeySpec(pubInfo.getEncoded());

ECPrivateKey privateKey = (ECPrivateKey) keyFactory.generatePrivate(privSpec);
if ( algId.getParameters() instanceof ASN1ObjectIdentifier ) {
if (algId.getParameters() instanceof ASN1ObjectIdentifier) {
privateKey = ECPrivateKeyWithName.wrap(privateKey, (ASN1ObjectIdentifier) algId.getParameters());
}
return new KeyPair(keyFactory.generatePublic(pubSpec), privateKey);
Expand Down Expand Up @@ -311,10 +316,11 @@ public static byte[] toDerRSAKey(RSAPublicKey pubKey, RSAPrivateCrtKey privKey)
return new DERSequence(vec).toASN1Primitive().getEncoded(ASN1Encoding.DER);
}

public static ASN1Sequence toASN1Primitive(RSAPublicKey pubKey) {
public static ASN1Sequence toASN1Primitive(final RSAPublicKey publicKey) {
assert publicKey != null : "null public key";
ASN1EncodableVector vec = new ASN1EncodableVector();
vec.add(new ASN1Integer(pubKey.getModulus()));
vec.add(new ASN1Integer(pubKey.getPublicExponent()));
vec.add(new ASN1Integer(publicKey.getModulus()));
vec.add(new ASN1Integer(publicKey.getPublicExponent()));
return new DERSequence(vec);
}

Expand Down
24 changes: 11 additions & 13 deletions src/main/java/org/jruby/ext/openssl/x509store/PEMInputOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,10 @@
import javax.crypto.spec.PBEParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.ASN1TaggedObject;
import org.bouncycastle.asn1.pkcs.PKCS12PBEParams;
import org.bouncycastle.asn1.pkcs.EncryptedPrivateKeyInfo;
import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1Object;
import org.bouncycastle.asn1.ASN1InputStream;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1OutputStream;
Expand Down Expand Up @@ -129,6 +127,7 @@
import org.jruby.ext.openssl.Cipher.Algorithm;
import org.jruby.ext.openssl.impl.ASN1Registry;
import org.jruby.ext.openssl.impl.CipherSpec;
import org.jruby.ext.openssl.impl.PKey.Type;
import org.jruby.ext.openssl.impl.PKCS10Request;
import org.jruby.ext.openssl.SecurityHelper;
import org.jruby.ext.openssl.util.ByteArrayOutputStream;
Expand Down Expand Up @@ -349,10 +348,9 @@ else if ( line.indexOf(BEG_STRING_ECPRIVATEKEY) != -1) {
else if ( line.indexOf(BEG_STRING_PKCS8INF) != -1) {
try {
byte[] bytes = readBase64Bytes(reader, BEF_E + PEM_STRING_PKCS8INF);
PrivateKeyInfo pInfo = PrivateKeyInfo.getInstance(bytes);
KeyFactory keyFactory = getKeyFactory( pInfo.getPrivateKeyAlgorithm() );
PrivateKey pKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(pInfo.getEncoded()));
return new KeyPair(null, pKey);
final PrivateKeyInfo keyInfo = PrivateKeyInfo.getInstance(bytes);
final Type type = getPrivateKeyType(keyInfo.getPrivateKeyAlgorithm());
return org.jruby.ext.openssl.impl.PKey.readPrivateKey(type, keyInfo);
}
catch (Exception e) {
throw mapReadException("problem creating private key: ", e);
Expand Down Expand Up @@ -1270,7 +1268,7 @@ else if ( line.contains(endMarker) ) {
} else {
keyBytes = decoded;
}
return org.jruby.ext.openssl.impl.PKey.readPrivateKey(keyBytes, type);
return org.jruby.ext.openssl.impl.PKey.readPrivateKey(Type.valueOf(type), keyBytes);
}

private static byte[] decrypt(byte[] decoded, String dekInfo, char[] passwd)
Expand Down Expand Up @@ -1486,23 +1484,23 @@ private static CMSSignedData readPKCS7(BufferedReader in, char[] p, String endMa

public static KeyFactory getKeyFactory(final AlgorithmIdentifier algId)
throws NoSuchAlgorithmException {
return SecurityHelper.getKeyFactory(getPrivateKeyType(algId));
return SecurityHelper.getKeyFactory(getPrivateKeyType(algId).name());
}

private static String getPrivateKeyType(final AlgorithmIdentifier algId) {
private static Type getPrivateKeyType(final AlgorithmIdentifier algId) {
final ASN1ObjectIdentifier algIdentifier = algId.getAlgorithm();

if (X9ObjectIdentifiers.id_ecPublicKey.equals(algIdentifier)) {
return "EC";
return Type.EC;
}
if (PKCSObjectIdentifiers.rsaEncryption.equals(algIdentifier)) {
return "RSA";
return Type.RSA;
}
if (X9ObjectIdentifiers.id_dsa.equals(algIdentifier)) {
return "DSA";
return Type.DSA;
}

return algIdentifier.getId();
return Type.valueOf(algIdentifier.getId());
}

private static CertificateFactory getX509CertificateFactory() {
Expand Down

0 comments on commit e24d02f

Please sign in to comment.