Skip to content

Commit

Permalink
Add PAIRIDReadWriter tests (#50)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanj Sherwany <[email protected]>
  • Loading branch information
mariiatuzovska and amanjpro authored Dec 5, 2024
1 parent e3742b2 commit b9206ae
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 3 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/optable/match v1.4.0
github.com/optable/match-api/v2 v2.7.0
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
gocloud.dev v0.39.0
golang.org/x/oauth2 v0.22.0
golang.org/x/sync v0.8.0
Expand All @@ -25,6 +26,7 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
cloud.google.com/go/iam v1.1.13 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand All @@ -35,6 +37,7 @@ require (
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect
Expand All @@ -51,4 +54,5 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20240812133136-8ffd90a71988 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240812133136-8ffd90a71988 // indirect
google.golang.org/grpc v1.65.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
5 changes: 3 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
Expand Down Expand Up @@ -292,6 +292,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cli/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (c *DecryptCmd) Run(cli *CmdContext) error {
}

// no need for original salt
salt := base64.StdEncoding.EncodeToString(make([]byte, 32))
salt := base64.StdEncoding.EncodeToString(make([]byte, pair.SHA256SaltSize))

// Decrypt and write
if err := d.Decrypt(ctx, c.NumThreads, salt, advertiserKey); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions pkg/pair/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ const (
minimumIDCount = 1000

maxOperationRunTime = 4 * time.Hour

SHA256SaltSize = 32
)

var (
Expand Down
263 changes: 263 additions & 0 deletions pkg/pair/pair_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package pair

import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/csv"
"fmt"
"io"
"optable-pair-cli/pkg/keys"
"testing"

"github.com/optable/match/pkg/pair"
"github.com/stretchr/testify/require"
)

// TestPAIRIDReadWriter_HashEncrypt checks that HashEncrypt reads hashed emails
// from the CSV input and writes, row for row, the same values obtained by
// encrypting those emails directly via requireEncryptEmails.
func TestPAIRIDReadWriter_HashEncrypt(t *testing.T) {
	t.Parallel()

	// arrange
	// 1001 rows: presumably picked to sit just above the minimum id count.
	const emailCount = 1001
	var (
		ctx  = context.Background()
		salt = requireGenSalt(t)
		key  = requireGenKey(t)
	)
	emails := requireGenRandomHashedEmails(t, emailCount)
	want := requireEncryptEmails(t, emails, salt, key)

	// The input buffer carries the emails as single-column CSV rows; the
	// output buffer collects whatever the read-writer emits.
	input := new(bytes.Buffer)
	output := new(bytes.Buffer)
	requireWriteEmails(t, input, emails)

	// act
	readWriter, err := NewPAIRIDReadWriter(input, output)
	require.NoError(t, err, "must create PAIRIDReadWriter")

	err = readWriter.HashEncrypt(ctx, 1, salt, key)
	require.NoError(t, err, "must hash and encrypt emails")

	// assert
	rows, err := csv.NewReader(output).ReadAll()
	require.NoError(t, err, "must read csv data")
	require.Len(t, rows, len(want), "must contain all emails")

	for i := range rows {
		row := rows[i]
		require.Len(t, row, 1, "must contain one csv column")
		require.Equal(t, want[i], row[0], "encrypted email must match")
	}
}

// TestPAIRIDReadWriter_ReEncrypt checks that ReEncrypt, fed once-encrypted
// emails, emits exactly the set of values produced by requireReEncryptEmails,
// and that the output order is shuffled: fewer than 1% of the rows may remain
// at the position they have in the expected (unshuffled) list.
func TestPAIRIDReadWriter_ReEncrypt(t *testing.T) {
	t.Parallel()
	// arrange
	lenEmails := 10000
	ctx := context.Background()
	salt := requireGenSalt(t)
	key := requireGenKey(t)
	emails := requireGenRandomHashedEmails(t, lenEmails)
	encryptedEmails := requireEncryptEmails(t, emails, salt, key)
	twiceEncryptedEmails := requireReEncryptEmails(t, encryptedEmails, salt, key)
	r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil)

	// set once-encrypted emails in csv format for PAIRIDReadWriter to read
	requireWriteEmails(t, r, encryptedEmails)

	// in this test we check encrypted emails are encrypted correctly and shuffled
	expected := twiceEncryptedEmails

	// Count every expected value once up front. This turns the membership
	// check below into a map lookup instead of a scan over all 10000 expected
	// values per row, and lets each expected value be matched only as many
	// times as it actually occurs.
	remaining := make(map[string]int, len(expected))
	for _, e := range expected {
		remaining[e]++
	}

	// act
	rw, err := NewPAIRIDReadWriter(r, w)
	require.NoError(t, err, "must create PAIRIDReadWriter")

	err = rw.ReEncrypt(ctx, 1, salt, key)
	require.NoError(t, err, "must re-encrypt emails")

	// assert
	csvData := csv.NewReader(w)
	reEncryptedData, err := csvData.ReadAll()
	require.NoError(t, err, "must read csv data")
	require.Len(t, reEncryptedData, len(expected), "must contain all emails")

	notShuffled := 0
	for i, reEncrypted := range reEncryptedData {
		require.Len(t, reEncrypted, 1, "must contain one csv column")
		got := reEncrypted[0]

		// check how many emails stay at the same place
		if got == expected[i] {
			notShuffled++
		}

		// must find the encrypted email among the not-yet-matched expected values
		require.Positive(t, remaining[got], "re-encrypted email must match")
		remaining[got]--
	}

	require.Less(t, float64(notShuffled), float64(lenEmails)*0.01, "must shuffle more than 99% of emails")
}

// TestPAIRIDReadWriter_HashDecrypt checks that Decrypt removes one encryption
// layer: given twice-encrypted emails as input it must write the
// once-encrypted values in the same row order, i.e.
// decrypt(encrypt(encrypt(data))) = encrypt(data).
func TestPAIRIDReadWriter_HashDecrypt(t *testing.T) {
	t.Parallel()

	// arrange
	const emailCount = 1001
	var (
		ctx  = context.Background()
		salt = requireGenSalt(t)
		key  = requireGenKey(t)
	)
	emails := requireGenRandomHashedEmails(t, emailCount)
	onceEncrypted := requireEncryptEmails(t, emails, salt, key)
	twiceEncrypted := requireReEncryptEmails(t, onceEncrypted, salt, key)

	// The twice-encrypted values are what the read-writer consumes; the
	// once-encrypted values are what it is expected to produce.
	input := new(bytes.Buffer)
	output := new(bytes.Buffer)
	requireWriteEmails(t, input, twiceEncrypted)
	want := onceEncrypted

	// act
	readWriter, err := NewPAIRIDReadWriter(input, output)
	require.NoError(t, err, "must create PAIRIDReadWriter")

	err = readWriter.Decrypt(ctx, 1, salt, key)
	require.NoError(t, err, "must decrypt emails")

	// assert
	rows, err := csv.NewReader(output).ReadAll()
	require.NoError(t, err, "must read csv data")
	require.Len(t, rows, len(want), "must contain all emails")

	for i := range rows {
		row := rows[i]
		require.Len(t, row, 1, "must contain one csv column")
		require.Equal(t, want[i], row[0], "encrypted email must match")
	}
}

// TestPAIRIDReadWriter_InputBelowThreshold checks that HashEncrypt, ReEncrypt
// and Decrypt each fail with ErrInputBelowThreshold when the input holds only
// 999 rows. Every operation gets input in the form it normally consumes:
// plain hashed emails, once-encrypted emails and twice-encrypted emails
// respectively.
func TestPAIRIDReadWriter_InputBelowThreshold(t *testing.T) {
	t.Parallel()
	// arrange
	lenEmails := 999
	ctx := context.Background()
	salt := requireGenSalt(t)
	key := requireGenKey(t)
	emails := requireGenRandomHashedEmails(t, lenEmails)
	encryptedEmails := requireEncryptEmails(t, emails, salt, key)
	twiceEncryptedEmails := requireReEncryptEmails(t, encryptedEmails, salt, key)

	t.Run("HashEncrypt", func(t *testing.T) {
		t.Parallel()
		r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil)

		// set emails in csv format for PAIRIDReadWriter to read
		requireWriteEmails(t, r, emails)

		rw, err := NewPAIRIDReadWriter(r, w)
		require.NoError(t, err, "must create PAIRIDReadWriter")

		// ErrorIs fails on a nil error and still matches if the sentinel
		// is ever wrapped with additional context.
		err = rw.HashEncrypt(ctx, 1, salt, key)
		require.ErrorIs(t, err, ErrInputBelowThreshold, "must return error when input is below threshold")
	})

	t.Run("ReEncrypt", func(t *testing.T) {
		t.Parallel()
		r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil)

		// set encrypted emails in csv format for PAIRIDReadWriter to read
		requireWriteEmails(t, r, encryptedEmails)

		rw, err := NewPAIRIDReadWriter(r, w)
		require.NoError(t, err, "must create PAIRIDReadWriter")

		err = rw.ReEncrypt(ctx, 1, salt, key)
		require.ErrorIs(t, err, ErrInputBelowThreshold, "must return error when input is below threshold")
	})

	t.Run("Decrypt", func(t *testing.T) {
		t.Parallel()
		r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil)

		// set twice encrypted emails in csv format for PAIRIDReadWriter to read
		requireWriteEmails(t, r, twiceEncryptedEmails)

		rw, err := NewPAIRIDReadWriter(r, w)
		require.NoError(t, err, "must create PAIRIDReadWriter")

		err = rw.Decrypt(ctx, 1, salt, key)
		require.ErrorIs(t, err, ErrInputBelowThreshold, "must return error when input is below threshold")
	})
}

// requireGenRandomHashedEmails returns emailsCount distinct hashed emails.
// Entry i is the lowercase hex encoding of the SHA-256 digest of the synthetic
// address "<i>@example.com". Despite the name, the result is deterministic
// for a given emailsCount.
func requireGenRandomHashedEmails(t *testing.T, emailsCount int) []string {
	t.Helper()
	hems := make([]string, emailsCount)
	for i := range hems {
		// Hash every address on its own. Reusing a single hash.Hash without
		// Reset would fold all previous addresses into each digest, so the
		// values would no longer be the hash of one email.
		sum := sha256.Sum256([]byte(fmt.Sprintf("%d@example.com", i)))
		hems[i] = fmt.Sprintf("%x", sum[:])
	}
	return hems
}

// requireWriteEmails writes emails to w as CSV, one single-column record per
// email, and fails the test if any record cannot be written or if flushing
// the buffered data to w reports an error.
func requireWriteEmails(t *testing.T, w io.Writer, emails []string) {
	t.Helper()
	csvWriter := csv.NewWriter(w)
	for _, email := range emails {
		err := csvWriter.Write([]string{email})
		require.NoError(t, err)
	}
	csvWriter.Flush()
	// Flush returns nothing; errors hit while writing to w surface only here.
	require.NoError(t, csvWriter.Error())
}

// requireEncryptEmails builds a PAIR private key from salt and key and returns
// the result of calling Encrypt on every entry of emails, in input order.
// Any key-construction or encryption error fails the test immediately.
func requireEncryptEmails(t *testing.T, emails []string, salt, key string) []string {
	t.Helper()

	privateKey, err := keys.NewPAIRPrivateKey(salt, key)
	require.NoError(t, err)

	out := make([]string, 0, len(emails))
	for _, plain := range emails {
		ciphertext, encErr := privateKey.Encrypt([]byte(plain))
		require.NoError(t, encErr)
		out = append(out, string(ciphertext))
	}
	return out
}

// requireReEncryptEmails builds a PAIR private key from salt and key and
// returns the result of calling ReEncrypt on every entry of emails, in input
// order. Callers in this file pass already-encrypted values as emails.
// Any key-construction or re-encryption error fails the test immediately.
func requireReEncryptEmails(t *testing.T, emails []string, salt, key string) []string {
	t.Helper()

	privateKey, err := keys.NewPAIRPrivateKey(salt, key)
	require.NoError(t, err)

	out := make([]string, 0, len(emails))
	for _, input := range emails {
		ciphertext, encErr := privateKey.ReEncrypt([]byte(input))
		require.NoError(t, encErr)
		out = append(out, string(ciphertext))
	}
	return out
}

// requireGenSalt returns SHA256SaltSize bytes read from crypto/rand, encoded
// with standard base64. A failed read fails the test.
func requireGenSalt(t *testing.T) string {
	t.Helper()

	raw := make([]byte, SHA256SaltSize)
	_, readErr := rand.Read(raw)
	require.NoError(t, readErr)

	encoded := base64.StdEncoding.EncodeToString(raw)
	return encoded
}

// requireGenKey returns a new private key created by keys.NewPrivateKey for
// the PAIRSHA256Ristretto255 mode. A generation error fails the test.
func requireGenKey(t *testing.T) string {
	t.Helper()

	privateKey, genErr := keys.NewPrivateKey(pair.PAIRSHA256Ristretto255)
	require.NoError(t, genErr)

	return privateKey
}

0 comments on commit b9206ae

Please sign in to comment.