Skip to content

Commit

Permalink
[fix] RSA private key should generate after set_key
Browse files Browse the repository at this point in the history
far from complete but should improve RSA key behavior
  • Loading branch information
kares committed Dec 6, 2024
1 parent 6a0bc28 commit a014686
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 60 deletions.
24 changes: 0 additions & 24 deletions lib/jopenssl/_compat23.rb
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,6 @@ def set_pqg(p, q, g)

end

class RSA

def set_key(n, e, d)
self.n = n
self.e = e
self.d = d
self
end

def set_factors(p, q)
self.p = p
self.q = q
self
end

def set_crt_params(dmp1, dmq1, iqmp)
self.dmp1 = dmp1
self.dmq1 = dmq1
self.iqmp = iqmp
self
end

end

end

end
107 changes: 73 additions & 34 deletions src/main/java/org/jruby/ext/openssl/PKeyRSA.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAKeyGenParameterSpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPrivateKeySpec;
import java.security.spec.RSAPublicKeySpec;

import static javax.crypto.Cipher.*;
Expand Down Expand Up @@ -108,7 +109,11 @@ public static RaiseException newRSAError(Ruby runtime, String message) {
}

static RaiseException newRSAError(Ruby runtime, Throwable cause) {
return Utils.newError(runtime, _PKey(runtime).getClass("RSAError"), cause.getMessage(), cause);
return newRSAError(runtime, cause.getMessage(), cause);
}

static RaiseException newRSAError(Ruby runtime, String message, Throwable cause) {
return Utils.newError(runtime, _PKey(runtime).getClass("RSAError"), message, cause);
}

public PKeyRSA(Ruby runtime, RubyClass type) {
Expand All @@ -126,7 +131,7 @@ public PKeyRSA(Ruby runtime, RubyClass type, RSAPrivateCrtKey privKey, RSAPublic
}

private volatile RSAPublicKey publicKey;
private volatile transient RSAPrivateCrtKey privateKey;
private volatile transient RSAPrivateKey privateKey;

// fields to hold individual RSAPublicKeySpec components. this allows
// a public key to be constructed incrementally, as required by the
Expand Down Expand Up @@ -317,8 +322,9 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
}
else if ( key instanceof RSAPrivateCrtKey ) {
this.privateKey = (RSAPrivateCrtKey) key;
BigInteger exponent = ((RSAPrivateCrtKey) key).getPublicExponent();
try {
this.publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(privateKey.getModulus(), privateKey.getPublicExponent()));
this.publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(privateKey.getModulus(), exponent));
} catch (GeneralSecurityException e) {
throw newRSAError(runtime, e.getMessage());
} catch (RuntimeException e) {
Expand Down Expand Up @@ -355,7 +361,7 @@ public RubyBoolean private_p() {
public RubyString to_der() {
final byte[] bytes;
try {
bytes = toDerRSAKey(publicKey, privateKey);
bytes = toDerRSAKey(publicKey, privateKey instanceof RSAPrivateCrtKey ? (RSAPrivateCrtKey) privateKey : null);
}
catch (NoClassDefFoundError e) {
throw newRSAError(getRuntime(), bcExceptionMessage(e));
Expand All @@ -380,7 +386,8 @@ public PKeyRSA public_key() {
public IRubyObject params(final ThreadContext context) {
final Ruby runtime = context.runtime;
RubyHash hash = RubyHash.newHash(runtime);
if ( privateKey != null ) {
if (privateKey instanceof RSAPrivateCrtKey) {
RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) this.privateKey;
hash.op_aset(context, runtime.newString("iqmp"), BN.newBN(runtime, privateKey.getCrtCoefficient()));
hash.op_aset(context, runtime.newString("n"), BN.newBN(runtime, privateKey.getModulus()));
hash.op_aset(context, runtime.newString("d"), BN.newBN(runtime, privateKey.getPrivateExponent()));
Expand All @@ -406,7 +413,8 @@ public IRubyObject params(final ThreadContext context) {
@JRubyMethod
public RubyString to_text() {
StringBuilder result = new StringBuilder();
if (privateKey != null) {
if (privateKey instanceof RSAPrivateCrtKey) {
RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) this.privateKey;
int len = privateKey.getModulus().bitLength();
result.append("Private-Key: (").append(len).append(" bit)").append('\n');
result.append("modulus:");
Expand Down Expand Up @@ -446,8 +454,8 @@ public RubyString to_pem(ThreadContext context, final IRubyObject[] args) {

try {
final StringWriter writer = new StringWriter();
if ( privateKey != null ) {
PEMInputOutput.writeRSAPrivateKey(writer, privateKey, spec, passwd);
if (privateKey instanceof RSAPrivateCrtKey) {
PEMInputOutput.writeRSAPrivateKey(writer, (RSAPrivateCrtKey) privateKey, spec, passwd);
}
else {
PEMInputOutput.writeRSAPublicKey(writer, publicKey);
Expand Down Expand Up @@ -603,8 +611,8 @@ public synchronized IRubyObject set_iqmp(final ThreadContext context, IRubyObjec
@JRubyMethod(name="iqmp")
public synchronized IRubyObject get_iqmp() {
BigInteger iqmp;
if (privateKey != null) {
iqmp = privateKey.getCrtCoefficient();
if (privateKey instanceof RSAPrivateCrtKey) {
iqmp = ((RSAPrivateCrtKey) privateKey).getCrtCoefficient();
} else {
iqmp = rsa_iqmp;
}
Expand All @@ -617,8 +625,8 @@ public synchronized IRubyObject get_iqmp() {
@JRubyMethod(name="dmp1")
public synchronized IRubyObject get_dmp1() {
BigInteger dmp1;
if (privateKey != null) {
dmp1 = privateKey.getPrimeExponentP();
if (privateKey instanceof RSAPrivateCrtKey) {
dmp1 = ((RSAPrivateCrtKey) privateKey).getPrimeExponentP();
} else {
dmp1 = rsa_dmp1;
}
Expand All @@ -631,8 +639,8 @@ public synchronized IRubyObject get_dmp1() {
@JRubyMethod(name="dmq1")
public synchronized IRubyObject get_dmq1() {
BigInteger dmq1;
if (privateKey != null) {
dmq1 = privateKey.getPrimeExponentQ();
if (privateKey instanceof RSAPrivateCrtKey) {
dmq1 = ((RSAPrivateCrtKey) privateKey).getPrimeExponentQ();
} else {
dmq1 = rsa_dmq1;
}
Expand All @@ -659,8 +667,8 @@ public synchronized IRubyObject get_d() {
@JRubyMethod(name="p")
public synchronized IRubyObject get_p() {
BigInteger p;
if (privateKey != null) {
p = privateKey.getPrimeP();
if (privateKey instanceof RSAPrivateCrtKey) {
p = ((RSAPrivateCrtKey) privateKey).getPrimeP();
} else {
p = rsa_p;
}
Expand All @@ -673,8 +681,8 @@ public synchronized IRubyObject get_p() {
@JRubyMethod(name="q")
public synchronized IRubyObject get_q() {
BigInteger q;
if (privateKey != null) {
q = privateKey.getPrimeQ();
if (privateKey instanceof RSAPrivateCrtKey) {
q = ((RSAPrivateCrtKey) privateKey).getPrimeQ();
} else {
q = rsa_q;
}
Expand All @@ -687,8 +695,8 @@ public synchronized IRubyObject get_q() {
private BigInteger getPublicExponent() {
if (publicKey != null) {
return publicKey.getPublicExponent();
} else if (privateKey != null) {
return privateKey.getPublicExponent();
} else if (privateKey instanceof RSAPrivateCrtKey) {
return ((RSAPrivateCrtKey) privateKey).getPublicExponent();
} else {
return rsa_e;
}
Expand Down Expand Up @@ -750,6 +758,32 @@ public synchronized IRubyObject set_n(final ThreadContext context, IRubyObject v
return value;
}

@JRubyMethod
public IRubyObject set_key(final ThreadContext context, IRubyObject n, IRubyObject e, IRubyObject d) {
this.rsa_n = BN.getBigInteger(n);
this.rsa_e = BN.getBigInteger(e);
this.rsa_d = BN.getBigInteger(d);
generatePrivateKeyIfParams(context);
return this;
}

@JRubyMethod
public IRubyObject set_factors(final ThreadContext context, IRubyObject p, IRubyObject q) {
this.rsa_p = BN.getBigInteger(p);
this.rsa_q = BN.getBigInteger(q);
generatePrivateKeyIfParams(context);
return this;
}

@JRubyMethod
public IRubyObject set_crt_params(final ThreadContext context, IRubyObject dmp1, IRubyObject dmq1, IRubyObject iqmp) {
this.rsa_dmp1 = BN.asBigInteger(dmp1);
this.rsa_dmq1 = BN.asBigInteger(dmq1);
this.rsa_iqmp = BN.asBigInteger(iqmp);
generatePrivateKeyIfParams(context);
return this;
}

private void generatePublicKeyIfParams(final ThreadContext context) {
final Ruby runtime = context.runtime;

Expand Down Expand Up @@ -783,14 +817,12 @@ private void generatePublicKeyIfParams(final ThreadContext context) {
private void generatePrivateKeyIfParams(final ThreadContext context) {
final Ruby runtime = context.runtime;

if ( privateKey != null ) throw newRSAError(runtime, "illegal modification");

// Don't access the rsa_n and rsa_e fields directly. They may have
// already been consumed and cleared by generatePublicKeyIfParams.
BigInteger _rsa_n = getModulus();
BigInteger _rsa_e = getPublicExponent();

if (_rsa_n != null && _rsa_e != null && rsa_p != null && rsa_q != null && rsa_d != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) {
if (_rsa_n != null && _rsa_e != null && rsa_d != null) {
final KeyFactory rsaFactory;
try {
rsaFactory = SecurityHelper.getKeyFactory("RSA");
Expand All @@ -799,17 +831,24 @@ private void generatePrivateKeyIfParams(final ThreadContext context) {
throw runtime.newLoadError("unsupported key algorithm (RSA)");
}

try {
privateKey = (RSAPrivateCrtKey) rsaFactory.generatePrivate(
new RSAPrivateCrtKeySpec(_rsa_n, _rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp)
);
}
catch (InvalidKeySpecException e) {
throw newRSAError(runtime, "invalid parameters");
if (rsa_p != null && rsa_q != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) {
try {
privateKey = (RSAPrivateCrtKey) rsaFactory.generatePrivate(
new RSAPrivateCrtKeySpec(_rsa_n, _rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp)
);
} catch (InvalidKeySpecException e) {
throw newRSAError(runtime, "invalid parameters", e);
}
rsa_n = null; rsa_e = null; rsa_d = null;
rsa_p = null; rsa_q = null;
rsa_dmp1 = null; rsa_dmq1 = null; rsa_iqmp = null;
} else {
try {
privateKey = (RSAPrivateKey) rsaFactory.generatePrivate(new RSAPrivateKeySpec(_rsa_n, rsa_d));
} catch (InvalidKeySpecException e) {
throw newRSAError(runtime, "invalid parameters", e);
}
}
rsa_n = null; rsa_e = null;
rsa_d = null; rsa_p = null; rsa_q = null;
rsa_dmp1 = null; rsa_dmq1 = null; rsa_iqmp = null;
}
}

Expand Down
40 changes: 38 additions & 2 deletions src/test/ruby/rsa/test_rsa.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,37 @@ def setup
require 'base64'
end

def test_private
# Generated by key size and public exponent
key = OpenSSL::PKey::RSA.new(512, 3)
assert(key.private?)

# Generated by DER
key2 = OpenSSL::PKey::RSA.new(key.to_der)
assert(key2.private?)

# public key
key3 = key.public_key
assert(!key3.private?)

# Generated by public key DER
key4 = OpenSSL::PKey::RSA.new(key3.to_der)
assert(!key4.private?)
rsa1024 = Fixtures.pkey("rsa1024")

#if !openssl?(3, 0, 0)
# Generated by RSA#set_key
key5 = OpenSSL::PKey::RSA.new
key5.set_key(rsa1024.n, rsa1024.e, rsa1024.d)
assert(key5.private?)

# Generated by RSA#set_key, without d
key6 = OpenSSL::PKey::RSA.new
key6.set_key(rsa1024.n, rsa1024.e, nil)
assert(!key6.private?)
#end
end

def test_oid
key = OpenSSL::PKey::RSA.new
assert_equal 'rsaEncryption', key.oid
Expand Down Expand Up @@ -72,9 +103,14 @@ def test_rsa_from_params_public_first
rsa = OpenSSL::PKey::RSA.new
rsa.e, rsa.n = key.e, key.n
assert_nothing_raised { rsa.public_encrypt('Test string') }
[:e, :n].each {|param| assert_equal(key.send(param), rsa.send(param)) }
[:e, :n].each { |param| assert_equal(key.send(param), rsa.send(param)) }

rsa = OpenSSL::PKey::RSA.new
rsa.set_key(key.n, key.e, key.d)
# rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1
rsa.set_factors(key.p, key.q)
rsa.set_crt_params(key.dmp1, key.dmq1, key.iqmp)

rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1
assert_nothing_raised { rsa.private_encrypt('Test string') }
[:e, :n, :d, :p, :q, :iqmp, :dmp1, :dmq1].each do |param|
assert_equal(key.send(param), rsa.send(param), param)
Expand Down

0 comments on commit a014686

Please sign in to comment.