diff --git a/go/tdh2/tdh2/tdh2.go b/go/tdh2/tdh2/tdh2.go index 8eb3da8..7c31d98 100644 --- a/go/tdh2/tdh2/tdh2.go +++ b/go/tdh2/tdh2/tdh2.go @@ -26,10 +26,6 @@ func parseGroup(group string) (group.Group, error) { switch group { case nist.NewP256().String(): return nist.NewP256(), nil - case nist.NewP384().String(): - return nist.NewP384(), nil - case nist.NewP521().String(): - return nist.NewP521(), nil } return nil, fmt.Errorf("unsupported group: %q", group) } @@ -373,7 +369,7 @@ func VerifyShare(pk *PublicKey, ctxt *Ciphertext, share *DecryptionShare) error func checkEi(pk *PublicKey, ctxt *Ciphertext, share *DecryptionShare) error { g := pk.group ui_hat := g.Point().Sub(g.Point().Mul(share.f_i, ctxt.u), g.Point().Mul(share.e_i, share.u_i)) - if share.index >= len(pk.hArray) { + if share.index < 0 || share.index >= len(pk.hArray) { return fmt.Errorf("invalid share index") } hi_hat := g.Point().Sub(g.Point().Mul(share.f_i, nil), g.Point().Mul(share.e_i, pk.hArray[share.index])) @@ -459,7 +455,8 @@ func (ctxt *Ciphertext) Decrypt(group group.Group, x_i *PrivateShare, rand ciphe } // CombineShares combines a set of decryption shares and returns the decrypted message. -// The caller has to ensure that the ciphertext is validated. +// The caller has to ensure that the ciphertext is validated, the decryption shares are valid, +// all the shares are distinct and the number of them is at least k. func (c *Ciphertext) CombineShares(group group.Group, shares []*DecryptionShare, k, n int) ([]byte, error) { if group.String() != c.group.String() { return nil, fmt.Errorf("incorrect ciphertext group: %q", c.group) diff --git a/go/tdh2/tdh2/tdh2_test.go b/go/tdh2/tdh2/tdh2_test.go index f57c510..f93e24b 100644 --- a/go/tdh2/tdh2/tdh2_test.go +++ b/go/tdh2/tdh2/tdh2_test.go @@ -20,8 +20,6 @@ import ( var supportedGroups = []string{ nist.NewP256().String(), - nist.NewP384().String(), - nist.NewP521().String(), } // unsupported implements an unsupported group @@ -565,6 +563,17 @@ func TestCheckEi(t *testing.T) { }, err: cmpopts.AnyError, }, + { + name: "negative share index", + ctxt: ctxt, + share: &DecryptionShare{ + index: -1, + u_i: ds.u_i, + e_i: ds.e_i, + f_i: ds.f_i, + }, + err: cmpopts.AnyError, + }, { name: "broken U", ctxt: ctxt, @@ -841,14 +850,6 @@ func TestParseGroup(t *testing.T) { group: nist.NewP256().String(), want: nist.NewP256(), }, - { - group: nist.NewP384().String(), - want: nist.NewP384(), - }, - { - group: nist.NewP521().String(), - want: nist.NewP521(), - }, { group: "wrong", err: cmpopts.AnyError, diff --git a/go/tdh2/tdh2easy/sym.go b/go/tdh2/tdh2easy/sym.go index ec57b81..f7fc4d9 100644 --- a/go/tdh2/tdh2easy/sym.go +++ b/go/tdh2/tdh2easy/sym.go @@ -22,6 +22,9 @@ func symEncrypt(msg, key []byte) ([]byte, []byte, error) { if err != nil { return nil, nil, fmt.Errorf("cannot use AES: %v", err) } + if uint64(len(msg)) > ((1<<32)-2)*uint64(block.BlockSize()) { + return nil, nil, fmt.Errorf("message too long") + } gcm, err := cipher.NewGCM(block) if err != nil { return nil, nil, fmt.Errorf("cannot use GCM mode: %v", err) diff --git a/go/tdh2/tdh2easy/tdh2easy.go b/go/tdh2/tdh2easy/tdh2easy.go index de0f2cd..e78e420 100644 --- a/go/tdh2/tdh2easy/tdh2easy.go +++ b/go/tdh2/tdh2easy/tdh2easy.go @@ -153,7 +153,9 @@ func VerifyShare(c *Ciphertext, pk *PublicKey, share *DecryptionShare) error { // Aggregate decrypts the TDH2-encrypted key and using it recovers the // symmetrically encrypted plaintext. It takes decryption shares and // the total number of participants as the arguments. -// Ciphertext and shares MUST be verified before calling Aggregate. +// Ciphertext and shares MUST be verified before calling Aggregate, +// all the shares have to be distinct and their number has to be +// at least k (the scheme's threshold). func Aggregate(c *Ciphertext, shares []*DecryptionShare, n int) ([]byte, error) { sh := []*tdh2.DecryptionShare{} for _, s := range shares { diff --git a/js/tdh2/tdh2.js b/js/tdh2/tdh2.js index 6181cbd..d580b2a 100644 --- a/js/tdh2/tdh2.js +++ b/js/tdh2/tdh2.js @@ -104,10 +104,13 @@ function xor(a, b) { function encrypt(pub, msg) { const ciph = new Cipher('AES-256-GCM'); + const blockSize = 16; const key = rnd.randomBytes(tdh2InputSize); const nonce = rnd.randomBytes(12); ciph.init(key, nonce); + if (msg.length > ((2**32)-2)*blockSize) + throw new Error('message too long'); const ctxt = Buffer.concat([ ciph.update(msg), ciph.final(),