Skip to content

Commit

Permalink
[feat] partial support for PKey::EC::Point#to_octet_string(form)
Browse files Browse the repository at this point in the history
  • Loading branch information
kares committed Apr 8, 2024
1 parent 04de740 commit 4234dd5
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 27 deletions.
144 changes: 119 additions & 25 deletions src/main/java/org/jruby/ext/openssl/PKeyEC.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import javax.crypto.KeyAgreement;
import org.bouncycastle.asn1.ASN1EncodableVector;
Expand Down Expand Up @@ -63,6 +64,7 @@
import org.jruby.RubyModule;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.RubySymbol;
import org.jruby.anno.JRubyClass;
import org.jruby.anno.JRubyMethod;
import org.jruby.exceptions.RaiseException;
Expand Down Expand Up @@ -695,6 +697,14 @@ public RubyString to_pem(ThreadContext context, final IRubyObject[] args) {
}
}

private enum PointConversion {
COMPRESSED, UNCOMPRESSED, HYBRID;

String toRubyString() {
return super.toString().toLowerCase(Locale.ROOT);
}
}

@JRubyClass(name = "OpenSSL::PKey::EC::Group")
public static final class Group extends RubyObject {

Expand All @@ -713,6 +723,9 @@ static void createGroup(final Ruby runtime, final RubyClass EC, final RubyClass

private transient PKeyEC key;
private ECParameterSpec paramSpec;

private PointConversion conversionForm = PointConversion.UNCOMPRESSED;

private RubyString curve_name;

public Group(Ruby runtime, RubyClass type) {
Expand All @@ -725,11 +738,6 @@ public Group(Ruby runtime, RubyClass type) {
this.paramSpec = key.publicKey.getParams();
}

private String getCurveName() {
if (key != null) return key.getCurveName();
return curve_name.toString();
}

@JRubyMethod(rest = true, visibility = Visibility.PRIVATE)
public IRubyObject initialize(final ThreadContext context, final IRubyObject[] args) {
final Ruby runtime = context.runtime;
Expand All @@ -743,13 +751,20 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
return this;
}

this.curve_name = ((RubyString) arg);
this.curve_name = arg.convertToString();

// TODO PEM/DER parsing not implemented
final ECNamedCurveParameterSpec ecCurveParamSpec = ECNamedCurveTable.getParameterSpec(curve_name.toString());
final EllipticCurve curve = EC5Util.convertCurve(ecCurveParamSpec.getCurve(), ecCurveParamSpec.getSeed());
this.paramSpec = EC5Util.convertSpec(curve, ecCurveParamSpec);
}
return this;
}

private String getCurveName() {
if (key != null) return key.getCurveName();
return curve_name.toString();
}

@Override
@JRubyMethod(name = { "==", "eql?" })
public IRubyObject op_equal(final ThreadContext context, final IRubyObject obj) {
Expand Down Expand Up @@ -831,13 +846,35 @@ public RubyString to_pem(final ThreadContext context, final IRubyObject[] args)
}
}

final EllipticCurve getCurve() {
EllipticCurve getCurve() {
if (paramSpec == null) {
paramSpec = getParamSpec(getCurveName());
}
return paramSpec.getCurve();
}

@JRubyMethod
public RubySymbol point_conversion_form(final ThreadContext context) {
return context.runtime.newSymbol(this.conversionForm.toRubyString());
}

@JRubyMethod(name = "point_conversion_form=")
public IRubyObject set_point_conversion_form(final ThreadContext context, final IRubyObject form) {
this.conversionForm = parse_point_conversion_form(context.runtime, form);
return form;
}

static PointConversion parse_point_conversion_form(final Ruby runtime, final IRubyObject form) {
if (form instanceof RubySymbol) {
final String pointConversionForm = ((RubySymbol) form).asJavaString();
if ("uncompressed".equals(pointConversionForm)) return PointConversion.UNCOMPRESSED;
if ("compressed".equals(pointConversionForm)) return PointConversion.COMPRESSED;
if ("hybrid".equals(pointConversionForm)) return PointConversion.HYBRID;
}
throw runtime.newArgumentError("unsupported point conversion form: " + form.inspect());
}


// @Override
// @JRubyMethod
// @SuppressWarnings("unchecked")
Expand Down Expand Up @@ -874,14 +911,12 @@ public Point(Ruby runtime, RubyClass type) {
super(runtime, type);
}

// private transient ECPublicKey publicKey;
private ECPoint point;
//private int bitLength;
private Group group;

Point(Ruby runtime, ECPublicKey publicKey, Group group) {
this(runtime, _EC(runtime).getClass("Point"));
//this.publicKey = publicKey;
this.point = publicKey.getW();
this.group = group;
}
Expand Down Expand Up @@ -914,7 +949,12 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
this.group = (Group) arg;
}
if ( argc == 2 ) { // (group, bn)
final byte[] encoded = ((BN) args[1]).getValue().abs().toByteArray();
final byte[] encoded;
if (args[1] instanceof BN) {
encoded = ((BN) args[1]).getValue().abs().toByteArray();
} else {
encoded = args[1].convertToString().getBytes();
}
try {
this.point = ECPointUtil.decodePoint(group.getCurve(), encoded);
}
Expand Down Expand Up @@ -951,15 +991,54 @@ private ECPoint asECPoint() {
}

private int bitLength() {
assert group != null;
assert group.paramSpec != null;
return group.paramSpec.getOrder().bitLength();
}

private PointConversion getPointConversionForm() {
if (group == null) return null;
return group.conversionForm;
}

@JRubyMethod
public BN to_bn(final ThreadContext context) {
final byte[] encoded = encode(bitLength(), point);
return toBN(context, getPointConversionForm()); // group.point_conversion_form
}

@JRubyMethod
public BN to_bn(final ThreadContext context, final IRubyObject conversion_form) {
return toBN(context, Group.parse_point_conversion_form(context.runtime, conversion_form));
}

private BN toBN(final ThreadContext context, final PointConversion conversionForm) {
final byte[] encoded = encodePoint(conversionForm);
return BN.newBN(context.runtime, new BigInteger(1, encoded));
}

private byte[] encodePoint(final PointConversion conversionForm) {
final byte[] encoded;
switch (conversionForm) {
case UNCOMPRESSED:
encoded = encodeUncompressed(bitLength(), point);
break;
case COMPRESSED:
encoded = encodeCompressed(point);
break;
case HYBRID:
throw getRuntime().newNotImplementedError(":hybrid compression not implemented");
default:
throw new AssertionError("unexpected conversion form: " + conversionForm);
}
return encoded;
}

@JRubyMethod
public IRubyObject to_octet_string(final ThreadContext context, final IRubyObject conversion_form) {
final PointConversion conversionForm = Group.parse_point_conversion_form(context.runtime, conversion_form);
return StringHelper.newString(context.runtime, encodePoint(conversionForm));
}

private boolean isInfinity() {
return point == ECPoint.POINT_INFINITY;
}
Expand All @@ -986,34 +1065,49 @@ public IRubyObject inspect() {
}

static byte[] encode(final ECPublicKey pubKey) {
return encode(pubKey.getParams().getOrder().bitLength(), pubKey.getW());
return encodeUncompressed(pubKey.getParams().getOrder().bitLength(), pubKey.getW());
}

private static byte[] encode(final int bitLength, final ECPoint point) {
if ( point == ECPoint.POINT_INFINITY ) return new byte[1];
private static byte[] encodeUncompressed(final int fieldSize, final ECPoint point) {
if (point == ECPoint.POINT_INFINITY) return new byte[1];

final int bytesLength = (bitLength + 7) / 8;
byte[] encoded = new byte[1 + bytesLength + bytesLength];
final int expLength = (fieldSize + 7) / 8;

byte[] encoded = new byte[1 + expLength + expLength];

encoded[0] = 0x04;

addIntBytes(point.getAffineX(), expLength, encoded, 1);
addIntBytes(point.getAffineY(), expLength, encoded, 1 + expLength);

return encoded;
}

private static byte[] encodeCompressed(final ECPoint point) {
if (point == ECPoint.POINT_INFINITY) return new byte[1];

final int bytesLength = point.getAffineX().bitLength() / 8 + 1;

byte[] encoded = new byte[1 + bytesLength];

encoded[0] = (byte) (point.getAffineY().testBit(0) ? 0x03 : 0x02);

addIntBytes(point.getAffineX(), bytesLength, encoded, 1);
addIntBytes(point.getAffineY(), bytesLength, encoded, 1 + bytesLength);

return encoded;
}

private static void addIntBytes(BigInteger i, final int length, final byte[] dest, final int destOffset) {
final byte[] bytes = i.toByteArray();
private static void addIntBytes(final BigInteger value, final int length, final byte[] dest, final int destOffset) {
final byte[] in = value.toByteArray();

if (length < bytes.length) {
System.arraycopy(bytes, bytes.length - length, dest, destOffset, length);
if (length < in.length) {
System.arraycopy(in, in.length - length, dest, destOffset, length);
}
else if (length > bytes.length) {
System.arraycopy(bytes, 0, dest, destOffset + (length - bytes.length), bytes.length);
else if (length > in.length) {
System.arraycopy(in, 0, dest, destOffset + (length - in.length), in.length);
}
else {
System.arraycopy(bytes, 0, dest, destOffset, length);
System.arraycopy(in, 0, dest, destOffset, length);
}
}

Expand Down
31 changes: 29 additions & 2 deletions src/test/ruby/ec/test_ec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,35 @@ def test_point
client_public_key_bn = OpenSSL::BN.new('58089019511196532477248433747314139754458690644712400444716868601190212265537817278966641566813745621284958192417192818318052462970895792919572995957754854')

binary = "\x04U\x1D6|\xA9\x14\eC\x13\x99b\x96\x9B\x94f\x8F\xB0o\xE2\xD3\xBC%\x8E\xE0Xn\xF2|R\x99b\xBD\xBFB\x8FS\xCF\x13\x7F\x8C\x03N\x96\x9D&\xB2\xE1\xBDQ\b\xCE\x94!s\x06.\xC5?\x96\xC7q\xDA\x8B\xE6"
client_public_key = OpenSSL::PKey::EC::Point.new(group, client_public_key_bn)
assert_equal binary, client_public_key.to_bn.to_s(2)
point = OpenSSL::PKey::EC::Point.new(group, client_public_key_bn)
assert_equal binary, point.to_bn.to_s(2)
assert_equal binary, point.to_octet_string(:uncompressed)

point2 = OpenSSL::PKey::EC::Point.new(group, point.to_octet_string(:uncompressed))
assert_equal binary, point2.to_bn.to_s(2)

compressed = "\x02U\x1D6|\xA9\x14\eC\x13\x99b\x96\x9B\x94f\x8F\xB0o\xE2\xD3\xBC%\x8E\xE0Xn\xF2|R\x99b\xBD"
assert_equal compressed, point.to_octet_string(:compressed)

# TODO: not yet implemented
# hybrid = "\x06U\x1D6|\xA9\x14\eC\x13\x99b\x96\x9B\x94f\x8F\xB0o\xE2\xD3\xBC%\x8E\xE0Xn\xF2|R\x99b\xBD\xBFB\x8FS\xCF\x13\x7F\x8C\x03N\x96\x9D&\xB2\xE1\xBDQ\b\xCE\x94!s\x06.\xC5?\x96\xC7q\xDA\x8B\xE6"
# assert_equal hybrid, point.to_octet_string(:hybrid)
end

def test_random_point
group = OpenSSL::PKey::EC::Group.new("prime256v1")
key = OpenSSL::PKey::EC.generate(group)
point = key.public_key

point2 = OpenSSL::PKey::EC::Point.new(group, point.to_bn)
assert_equal point, point2
assert_equal point.to_bn, point2.to_bn
assert_equal point.to_octet_string(:uncompressed), point2.to_octet_string(:uncompressed)

point3 = OpenSSL::PKey::EC::Point.new(group, point.to_octet_string(:uncompressed))
assert_equal point, point3
assert_equal point.to_bn, point3.to_bn
assert_equal point.to_octet_string(:uncompressed), point3.to_octet_string(:uncompressed)
end

require File.expand_path('base64.rb', File.dirname(__FILE__))
Expand Down

0 comments on commit 4234dd5

Please sign in to comment.