Skip to content

Commit

Permalink
[feat] try resolving curve-name from EC public key
Browse files Browse the repository at this point in the history
  • Loading branch information
kares committed Feb 13, 2024
1 parent a59ebbf commit 831100e
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions src/main/java/org/jruby/ext/openssl/PKeyEC.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.io.StringReader;
import java.io.StringWriter;
import java.math.BigInteger;
import java.security.AlgorithmParameters;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
Expand All @@ -36,6 +37,7 @@
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Optional;
import javax.crypto.KeyAgreement;
import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.ASN1Encoding;
Expand Down Expand Up @@ -82,7 +84,6 @@
import static org.jruby.ext.openssl.impl.PKey.readECPrivateKey;
import org.jruby.ext.openssl.util.ByteArrayOutputStream;
import org.jruby.ext.openssl.x509store.PEMInputOutput;
import org.jruby.util.ByteList;

/**
* OpenSSL::PKey::EC implementation.
Expand Down Expand Up @@ -170,24 +171,21 @@ public static RubyArray builtin_curves(ThreadContext context, IRubyObject self)
return curves;
}

private static ASN1ObjectIdentifier getCurveOID(final String curveName) {
private static Optional<ASN1ObjectIdentifier> getCurveOID(final String curveName) {
ASN1ObjectIdentifier id;
id = org.bouncycastle.asn1.sec.SECNamedCurves.getOID(curveName);
if ( id != null ) return id;
if ( id != null ) return Optional.of(id);
id = org.bouncycastle.asn1.x9.X962NamedCurves.getOID(curveName);
if ( id != null ) return id;
if ( id != null ) return Optional.of(id);
id = org.bouncycastle.asn1.nist.NISTNamedCurves.getOID(curveName);
if ( id != null ) return id;
if ( id != null ) return Optional.of(id);
id = org.bouncycastle.asn1.teletrust.TeleTrusTNamedCurves.getOID(curveName);
if ( id != null ) return id;
throw new IllegalStateException("could not identify curve name: " + curveName);
if ( id != null ) return Optional.of(id);
return Optional.empty();
}

private static boolean isCurveName(final String curveName) {
try {
return getCurveOID(curveName) != null;
}
catch (IllegalStateException ex) { return false; }
return getCurveOID(curveName).isPresent();
}

private static String getCurveName(final ASN1ObjectIdentifier oid) {
Expand Down Expand Up @@ -365,14 +363,18 @@ else if ( key instanceof ECPrivateKey ) {
setPrivateKey((ECPrivateKey) key);
}
else if ( key instanceof ECPublicKey ) {
this.publicKey = (ECPublicKey) key; this.privateKey = null;
this.publicKey = (ECPublicKey) key;
this.privateKey = null;
}
else {
throw newECError(runtime, "Neither PUB key nor PRIV key: " + key.getClass().getName());
}

if ( publicKey != null ) {
publicKey.getParams().getCurve();
final String oid = getCurveNameObjectIdFromKey(context, publicKey);
if (isCurveName(oid)) {
this.curveName = getCurveName(new ASN1ObjectIdentifier(oid));
}
}

return this;
Expand All @@ -391,6 +393,23 @@ private void unwrapPrivateKeyWithName() {
}
}

private static String getCurveNameObjectIdFromKey(final ThreadContext context, final ECPublicKey key) {
try {
AlgorithmParameters algParams = AlgorithmParameters.getInstance("EC");
algParams.init(key.getParams());
return algParams.getParameterSpec(ECGenParameterSpec.class).getName();
}
catch (NoSuchAlgorithmException ex) {
throw newECError(context.runtime, ex.getMessage());
}
catch (InvalidParameterSpecException ex) {
throw newECError(context.runtime, ex.toString());
}
catch (Exception ex) {
throw (RaiseException) newECError(context.runtime, ex.toString()).initCause(ex);
}
}

private void setGroup(final Group group) {
this.group = group;
this.curveName = this.group.getCurveName();
Expand Down Expand Up @@ -806,7 +825,9 @@ public RubyString to_pem(final ThreadContext context, final IRubyObject[] args)

try {
final StringWriter writer = new StringWriter();
PEMInputOutput.writeECParameters(writer, getCurveOID(getCurveName()), spec, passwd);
final ASN1ObjectIdentifier oid = getCurveOID(getCurveName())
.orElseThrow(() -> newECError(context.runtime, "invalid curve name: " + getCurveName()));
PEMInputOutput.writeECParameters(writer, oid, spec, passwd);
return RubyString.newString(context.runtime, writer.getBuffer());
}
catch (IOException ex) {
Expand Down

0 comments on commit 831100e

Please sign in to comment.