diff --git a/src/main/java/org/jruby/ext/openssl/PKeyEC.java b/src/main/java/org/jruby/ext/openssl/PKeyEC.java index a70398a8..8f0c940c 100644 --- a/src/main/java/org/jruby/ext/openssl/PKeyEC.java +++ b/src/main/java/org/jruby/ext/openssl/PKeyEC.java @@ -11,6 +11,7 @@ import java.io.StringReader; import java.io.StringWriter; import java.math.BigInteger; +import java.security.AlgorithmParameters; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.KeyFactory; @@ -36,6 +37,7 @@ import java.util.Collections; import java.util.Enumeration; import java.util.List; +import java.util.Optional; import javax.crypto.KeyAgreement; import org.bouncycastle.asn1.ASN1EncodableVector; import org.bouncycastle.asn1.ASN1Encoding; @@ -82,7 +84,6 @@ import static org.jruby.ext.openssl.impl.PKey.readECPrivateKey; import org.jruby.ext.openssl.util.ByteArrayOutputStream; import org.jruby.ext.openssl.x509store.PEMInputOutput; -import org.jruby.util.ByteList; /** * OpenSSL::PKey::EC implementation. @@ -170,24 +171,21 @@ public static RubyArray builtin_curves(ThreadContext context, IRubyObject self) return curves; } - private static ASN1ObjectIdentifier getCurveOID(final String curveName) { + private static Optional getCurveOID(final String curveName) { ASN1ObjectIdentifier id; id = org.bouncycastle.asn1.sec.SECNamedCurves.getOID(curveName); - if ( id != null ) return id; + if ( id != null ) return Optional.of(id); id = org.bouncycastle.asn1.x9.X962NamedCurves.getOID(curveName); - if ( id != null ) return id; + if ( id != null ) return Optional.of(id); id = org.bouncycastle.asn1.nist.NISTNamedCurves.getOID(curveName); - if ( id != null ) return id; + if ( id != null ) return Optional.of(id); id = org.bouncycastle.asn1.teletrust.TeleTrusTNamedCurves.getOID(curveName); - if ( id != null ) return id; - throw new IllegalStateException("could not identify curve name: " + curveName); + if ( id != null ) return Optional.of(id); + return Optional.empty(); } private static boolean isCurveName(final String curveName) { - try { - return getCurveOID(curveName) != null; - } - catch (IllegalStateException ex) { return false; } + return getCurveOID(curveName).isPresent(); } private static String getCurveName(final ASN1ObjectIdentifier oid) { @@ -365,14 +363,18 @@ else if ( key instanceof ECPrivateKey ) { setPrivateKey((ECPrivateKey) key); } else if ( key instanceof ECPublicKey ) { - this.publicKey = (ECPublicKey) key; this.privateKey = null; + this.publicKey = (ECPublicKey) key; + this.privateKey = null; } else { throw newECError(runtime, "Neither PUB key nor PRIV key: " + key.getClass().getName()); } if ( publicKey != null ) { - publicKey.getParams().getCurve(); + final String oid = getCurveNameObjectIdFromKey(context, publicKey); + if (isCurveName(oid)) { + this.curveName = getCurveName(new ASN1ObjectIdentifier(oid)); + } } return this; @@ -391,6 +393,23 @@ private void unwrapPrivateKeyWithName() { } } + private static String getCurveNameObjectIdFromKey(final ThreadContext context, final ECPublicKey key) { + try { + AlgorithmParameters algParams = AlgorithmParameters.getInstance("EC"); + algParams.init(key.getParams()); + return algParams.getParameterSpec(ECGenParameterSpec.class).getName(); + } + catch (NoSuchAlgorithmException ex) { + throw newECError(context.runtime, ex.getMessage()); + } + catch (InvalidParameterSpecException ex) { + throw newECError(context.runtime, ex.toString()); + } + catch (Exception ex) { + throw (RaiseException) newECError(context.runtime, ex.toString()).initCause(ex); + } + } + private void setGroup(final Group group) { this.group = group; this.curveName = this.group.getCurveName(); @@ -806,7 +825,9 @@ public RubyString to_pem(final ThreadContext context, final IRubyObject[] args) try { final StringWriter writer = new StringWriter(); - PEMInputOutput.writeECParameters(writer, getCurveOID(getCurveName()), spec, passwd); + final ASN1ObjectIdentifier oid = getCurveOID(getCurveName()) + .orElseThrow(() -> newECError(context.runtime, "invalid curve name: " + getCurveName())); + PEMInputOutput.writeECParameters(writer, oid, spec, passwd); return RubyString.newString(context.runtime, writer.getBuffer()); } catch (IOException ex) {