diff --git a/lib/jopenssl/_compat23.rb b/lib/jopenssl/_compat23.rb index 78f9d0ef..e080c840 100644 --- a/lib/jopenssl/_compat23.rb +++ b/lib/jopenssl/_compat23.rb @@ -42,30 +42,6 @@ def set_pqg(p, q, g) end - class RSA - - def set_key(n, e, d) - self.n = n - self.e = e - self.d = d - self - end - - def set_factors(p, q) - self.p = p - self.q = q - self - end - - def set_crt_params(dmp1, dmq1, iqmp) - self.dmp1 = dmp1 - self.dmq1 = dmq1 - self.iqmp = iqmp - self - end - - end - end end diff --git a/src/main/java/org/jruby/ext/openssl/PKeyRSA.java b/src/main/java/org/jruby/ext/openssl/PKeyRSA.java index afae1f2e..4bd9d796 100644 --- a/src/main/java/org/jruby/ext/openssl/PKeyRSA.java +++ b/src/main/java/org/jruby/ext/openssl/PKeyRSA.java @@ -40,12 +40,13 @@ import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.PublicKey; -import java.security.SecureRandom; import java.security.interfaces.RSAPrivateCrtKey; +import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; import java.security.spec.InvalidKeySpecException; import java.security.spec.RSAKeyGenParameterSpec; import java.security.spec.RSAPrivateCrtKeySpec; +import java.security.spec.RSAPrivateKeySpec; import java.security.spec.RSAPublicKeySpec; import static javax.crypto.Cipher.*; @@ -108,7 +109,11 @@ public static RaiseException newRSAError(Ruby runtime, String message) { } static RaiseException newRSAError(Ruby runtime, Throwable cause) { - return Utils.newError(runtime, _PKey(runtime).getClass("RSAError"), cause.getMessage(), cause); + return newRSAError(runtime, cause.getMessage(), cause); + } + + static RaiseException newRSAError(Ruby runtime, String message, Throwable cause) { + return Utils.newError(runtime, _PKey(runtime).getClass("RSAError"), message, cause); } public PKeyRSA(Ruby runtime, RubyClass type) { @@ -126,7 +131,7 @@ public PKeyRSA(Ruby runtime, RubyClass type, RSAPrivateCrtKey privKey, RSAPublic } private volatile RSAPublicKey publicKey; - private volatile transient RSAPrivateCrtKey privateKey; + private volatile transient RSAPrivateKey privateKey; // fields to hold individual RSAPublicKeySpec components. this allows // a public key to be constructed incrementally, as required by the @@ -317,8 +322,9 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a } else if ( key instanceof RSAPrivateCrtKey ) { this.privateKey = (RSAPrivateCrtKey) key; + BigInteger exponent = ((RSAPrivateCrtKey) key).getPublicExponent(); try { - this.publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(privateKey.getModulus(), privateKey.getPublicExponent())); + this.publicKey = (RSAPublicKey) rsaFactory.generatePublic(new RSAPublicKeySpec(privateKey.getModulus(), exponent)); } catch (GeneralSecurityException e) { throw newRSAError(runtime, e.getMessage()); } catch (RuntimeException e) { @@ -355,7 +361,7 @@ public RubyBoolean private_p() { public RubyString to_der() { final byte[] bytes; try { - bytes = toDerRSAKey(publicKey, privateKey); + bytes = toDerRSAKey(publicKey, privateKey instanceof RSAPrivateCrtKey ? (RSAPrivateCrtKey) privateKey : null); } catch (NoClassDefFoundError e) { throw newRSAError(getRuntime(), bcExceptionMessage(e)); @@ -380,7 +386,8 @@ public PKeyRSA public_key() { public IRubyObject params(final ThreadContext context) { final Ruby runtime = context.runtime; RubyHash hash = RubyHash.newHash(runtime); - if ( privateKey != null ) { + if (privateKey instanceof RSAPrivateCrtKey) { + RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) this.privateKey; hash.op_aset(context, runtime.newString("iqmp"), BN.newBN(runtime, privateKey.getCrtCoefficient())); hash.op_aset(context, runtime.newString("n"), BN.newBN(runtime, privateKey.getModulus())); hash.op_aset(context, runtime.newString("d"), BN.newBN(runtime, privateKey.getPrivateExponent())); @@ -406,7 +413,8 @@ public IRubyObject params(final ThreadContext context) { @JRubyMethod public RubyString to_text() { StringBuilder result = new StringBuilder(); - if (privateKey != null) { + if (privateKey instanceof RSAPrivateCrtKey) { + RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) this.privateKey; int len = privateKey.getModulus().bitLength(); result.append("Private-Key: (").append(len).append(" bit)").append('\n'); result.append("modulus:"); @@ -446,8 +454,8 @@ public RubyString to_pem(ThreadContext context, final IRubyObject[] args) { try { final StringWriter writer = new StringWriter(); - if ( privateKey != null ) { - PEMInputOutput.writeRSAPrivateKey(writer, privateKey, spec, passwd); + if (privateKey instanceof RSAPrivateCrtKey) { + PEMInputOutput.writeRSAPrivateKey(writer, (RSAPrivateCrtKey) privateKey, spec, passwd); } else { PEMInputOutput.writeRSAPublicKey(writer, publicKey); @@ -603,8 +611,8 @@ public synchronized IRubyObject set_iqmp(final ThreadContext context, IRubyObjec @JRubyMethod(name="iqmp") public synchronized IRubyObject get_iqmp() { BigInteger iqmp; - if (privateKey != null) { - iqmp = privateKey.getCrtCoefficient(); + if (privateKey instanceof RSAPrivateCrtKey) { + iqmp = ((RSAPrivateCrtKey) privateKey).getCrtCoefficient(); } else { iqmp = rsa_iqmp; } @@ -617,8 +625,8 @@ public synchronized IRubyObject get_iqmp() { @JRubyMethod(name="dmp1") public synchronized IRubyObject get_dmp1() { BigInteger dmp1; - if (privateKey != null) { - dmp1 = privateKey.getPrimeExponentP(); + if (privateKey instanceof RSAPrivateCrtKey) { + dmp1 = ((RSAPrivateCrtKey) privateKey).getPrimeExponentP(); } else { dmp1 = rsa_dmp1; } @@ -631,8 +639,8 @@ public synchronized IRubyObject get_dmp1() { @JRubyMethod(name="dmq1") public synchronized IRubyObject get_dmq1() { BigInteger dmq1; - if (privateKey != null) { - dmq1 = privateKey.getPrimeExponentQ(); + if (privateKey instanceof RSAPrivateCrtKey) { + dmq1 = ((RSAPrivateCrtKey) privateKey).getPrimeExponentQ(); } else { dmq1 = rsa_dmq1; } @@ -659,8 +667,8 @@ public synchronized IRubyObject get_d() { @JRubyMethod(name="p") public synchronized IRubyObject get_p() { BigInteger p; - if (privateKey != null) { - p = privateKey.getPrimeP(); + if (privateKey instanceof RSAPrivateCrtKey) { + p = ((RSAPrivateCrtKey) privateKey).getPrimeP(); } else { p = rsa_p; } @@ -673,8 +681,8 @@ public synchronized IRubyObject get_p() { @JRubyMethod(name="q") public synchronized IRubyObject get_q() { BigInteger q; - if (privateKey != null) { - q = privateKey.getPrimeQ(); + if (privateKey instanceof RSAPrivateCrtKey) { + q = ((RSAPrivateCrtKey) privateKey).getPrimeQ(); } else { q = rsa_q; } @@ -687,8 +695,8 @@ public synchronized IRubyObject get_q() { private BigInteger getPublicExponent() { if (publicKey != null) { return publicKey.getPublicExponent(); - } else if (privateKey != null) { - return privateKey.getPublicExponent(); + } else if (privateKey instanceof RSAPrivateCrtKey) { + return ((RSAPrivateCrtKey) privateKey).getPublicExponent(); } else { return rsa_e; } @@ -750,6 +758,32 @@ public synchronized IRubyObject set_n(final ThreadContext context, IRubyObject v return value; } + @JRubyMethod + public IRubyObject set_key(final ThreadContext context, IRubyObject n, IRubyObject e, IRubyObject d) { + this.rsa_n = BN.getBigInteger(n); + this.rsa_e = BN.getBigInteger(e); + this.rsa_d = BN.getBigInteger(d); + generatePrivateKeyIfParams(context); + return this; + } + + @JRubyMethod + public IRubyObject set_factors(final ThreadContext context, IRubyObject p, IRubyObject q) { + this.rsa_p = BN.getBigInteger(p); + this.rsa_q = BN.getBigInteger(q); + generatePrivateKeyIfParams(context); + return this; + } + + @JRubyMethod + public IRubyObject set_crt_params(final ThreadContext context, IRubyObject dmp1, IRubyObject dmq1, IRubyObject iqmp) { + this.rsa_dmp1 = BN.asBigInteger(dmp1); + this.rsa_dmq1 = BN.asBigInteger(dmq1); + this.rsa_iqmp = BN.asBigInteger(iqmp); + generatePrivateKeyIfParams(context); + return this; + } + private void generatePublicKeyIfParams(final ThreadContext context) { final Ruby runtime = context.runtime; @@ -783,14 +817,12 @@ private void generatePublicKeyIfParams(final ThreadContext context) { private void generatePrivateKeyIfParams(final ThreadContext context) { final Ruby runtime = context.runtime; - if ( privateKey != null ) throw newRSAError(runtime, "illegal modification"); - // Don't access the rsa_n and rsa_e fields directly. They may have // already been consumed and cleared by generatePublicKeyIfParams. BigInteger _rsa_n = getModulus(); BigInteger _rsa_e = getPublicExponent(); - if (_rsa_n != null && _rsa_e != null && rsa_p != null && rsa_q != null && rsa_d != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) { + if (_rsa_n != null && _rsa_e != null && rsa_d != null) { final KeyFactory rsaFactory; try { rsaFactory = SecurityHelper.getKeyFactory("RSA"); @@ -799,17 +831,24 @@ private void generatePrivateKeyIfParams(final ThreadContext context) { throw runtime.newLoadError("unsupported key algorithm (RSA)"); } - try { - privateKey = (RSAPrivateCrtKey) rsaFactory.generatePrivate( - new RSAPrivateCrtKeySpec(_rsa_n, _rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp) - ); - } - catch (InvalidKeySpecException e) { - throw newRSAError(runtime, "invalid parameters"); + if (rsa_p != null && rsa_q != null && rsa_dmp1 != null && rsa_dmq1 != null && rsa_iqmp != null) { + try { + privateKey = (RSAPrivateCrtKey) rsaFactory.generatePrivate( + new RSAPrivateCrtKeySpec(_rsa_n, _rsa_e, rsa_d, rsa_p, rsa_q, rsa_dmp1, rsa_dmq1, rsa_iqmp) + ); + } catch (InvalidKeySpecException e) { + throw newRSAError(runtime, "invalid parameters", e); + } + rsa_n = null; rsa_e = null; rsa_d = null; + rsa_p = null; rsa_q = null; + rsa_dmp1 = null; rsa_dmq1 = null; rsa_iqmp = null; + } else { + try { + privateKey = (RSAPrivateKey) rsaFactory.generatePrivate(new RSAPrivateKeySpec(_rsa_n, rsa_d)); + } catch (InvalidKeySpecException e) { + throw newRSAError(runtime, "invalid parameters", e); + } } - rsa_n = null; rsa_e = null; - rsa_d = null; rsa_p = null; rsa_q = null; - rsa_dmp1 = null; rsa_dmq1 = null; rsa_iqmp = null; } } diff --git a/src/test/ruby/rsa/test_rsa.rb b/src/test/ruby/rsa/test_rsa.rb index 20f04eeb..efe9daa5 100644 --- a/src/test/ruby/rsa/test_rsa.rb +++ b/src/test/ruby/rsa/test_rsa.rb @@ -9,6 +9,37 @@ def setup require 'base64' end + def test_private + # Generated by key size and public exponent + key = OpenSSL::PKey::RSA.new(512, 3) + assert(key.private?) + + # Generated by DER + key2 = OpenSSL::PKey::RSA.new(key.to_der) + assert(key2.private?) + + # public key + key3 = key.public_key + assert(!key3.private?) + + # Generated by public key DER + key4 = OpenSSL::PKey::RSA.new(key3.to_der) + assert(!key4.private?) + rsa1024 = Fixtures.pkey("rsa1024") + + #if !openssl?(3, 0, 0) + # Generated by RSA#set_key + key5 = OpenSSL::PKey::RSA.new + key5.set_key(rsa1024.n, rsa1024.e, rsa1024.d) + assert(key5.private?) + + # Generated by RSA#set_key, without d + key6 = OpenSSL::PKey::RSA.new + key6.set_key(rsa1024.n, rsa1024.e, nil) + assert(!key6.private?) + #end + end + def test_oid key = OpenSSL::PKey::RSA.new assert_equal 'rsaEncryption', key.oid @@ -72,9 +103,14 @@ def test_rsa_from_params_public_first rsa = OpenSSL::PKey::RSA.new rsa.e, rsa.n = key.e, key.n assert_nothing_raised { rsa.public_encrypt('Test string') } - [:e, :n].each {|param| assert_equal(key.send(param), rsa.send(param)) } + [:e, :n].each { |param| assert_equal(key.send(param), rsa.send(param)) } + + rsa = OpenSSL::PKey::RSA.new + rsa.set_key(key.n, key.e, key.d) + # rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1 + rsa.set_factors(key.p, key.q) + rsa.set_crt_params(key.dmp1, key.dmq1, key.iqmp) - rsa.d, rsa.p, rsa.q, rsa.iqmp, rsa.dmp1, rsa.dmq1 = key.d, key.p, key.q, key.iqmp, key.dmp1, key.dmq1 assert_nothing_raised { rsa.private_encrypt('Test string') } [:e, :n, :d, :p, :q, :iqmp, :dmp1, :dmq1].each do |param| assert_equal(key.send(param), rsa.send(param), param)