From b9206aea5002c3ba40b0d34e3fbc9b72388721d1 Mon Sep 17 00:00:00 2001 From: Mariia <41679258+mariiatuzovska@users.noreply.github.com> Date: Thu, 5 Dec 2024 16:24:13 -0500 Subject: [PATCH] Add PAIRIDReadWriter tests (#50) Co-authored-by: Amanj Sherwany --- go.mod | 4 + go.sum | 5 +- pkg/cmd/cli/decrypt.go | 2 +- pkg/pair/pair.go | 2 + pkg/pair/pair_test.go | 263 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 273 insertions(+), 3 deletions(-) create mode 100644 pkg/pair/pair_test.go diff --git a/go.mod b/go.mod index 478330a..056cb40 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/optable/match v1.4.0 github.com/optable/match-api/v2 v2.7.0 github.com/rs/zerolog v1.33.0 + github.com/stretchr/testify v1.10.0 gocloud.dev v0.39.0 golang.org/x/oauth2 v0.22.0 golang.org/x/sync v0.8.0 @@ -25,6 +26,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect cloud.google.com/go/compute/metadata v0.5.0 // indirect cloud.google.com/go/iam v1.1.13 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -35,6 +37,7 @@ require ( github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect @@ -51,4 +54,5 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240812133136-8ffd90a71988 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240812133136-8ffd90a71988 // indirect google.golang.org/grpc v1.65.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d5db78d..c41a7d3 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,8 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= @@ -292,6 +292,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/cmd/cli/decrypt.go b/pkg/cmd/cli/decrypt.go index c40c522..221b051 100644 --- a/pkg/cmd/cli/decrypt.go +++ b/pkg/cmd/cli/decrypt.go @@ -56,7 +56,7 @@ func (c *DecryptCmd) Run(cli *CmdContext) error { } // no need for original salt - salt := base64.StdEncoding.EncodeToString(make([]byte, 32)) + salt := base64.StdEncoding.EncodeToString(make([]byte, pair.SHA256SaltSize)) // Decrypt and write if err := d.Decrypt(ctx, c.NumThreads, salt, advertiserKey); err != nil { diff --git a/pkg/pair/pair.go b/pkg/pair/pair.go index 82ece0b..42ab5f8 100644 --- a/pkg/pair/pair.go +++ b/pkg/pair/pair.go @@ -22,6 +22,8 @@ const ( minimumIDCount = 1000 maxOperationRunTime = 4 * time.Hour + + SHA256SaltSize = 32 ) var ( diff --git a/pkg/pair/pair_test.go b/pkg/pair/pair_test.go new file mode 100644 index 0000000..158fff6 --- /dev/null +++ b/pkg/pair/pair_test.go @@ -0,0 +1,263 @@ +package pair + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/csv" + "fmt" + "io" + "optable-pair-cli/pkg/keys" + "testing" + + "github.com/optable/match/pkg/pair" + "github.com/stretchr/testify/require" +) + +func TestPAIRIDReadWriter_HashEncrypt(t *testing.T) { + t.Parallel() + // arrange + lenEmails := 1001 + ctx := context.Background() + salt := requireGenSalt(t) + key := requireGenKey(t) + emails := requireGenRandomHashedEmails(t, lenEmails) + expected := requireEncryptEmails(t, emails, salt, key) + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, emails) + + // act + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.HashEncrypt(ctx, 1, salt, key) + require.NoError(t, err, "must hash and encrypt emails") + + // assert + csvData := csv.NewReader(w) + hashEncryptedData, err := csvData.ReadAll() + require.NoError(t, err, "must read csv data") + require.Len(t, hashEncryptedData, len(expected), "must contain all emails") + + for i, hashEncrypted := range hashEncryptedData { + require.Len(t, hashEncrypted, 1, "must contain one csv column") + require.Equal(t, expected[i], hashEncrypted[0], "encrypted email must match") + } +} + +func TestPAIRIDReadWriter_ReEncrypt(t *testing.T) { + t.Parallel() + // arrange + lenEmails := 10000 + ctx := context.Background() + salt := requireGenSalt(t) + key := requireGenKey(t) + emails := requireGenRandomHashedEmails(t, lenEmails) + encryptedEmails := requireEncryptEmails(t, emails, salt, key) + twiceEncryptedEmails := requireReEncryptEmails(t, encryptedEmails, salt, key) + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set twice encrypted emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, encryptedEmails) + + // in this test we check encrypted emails are encrypted correctly and shuffled + expected := twiceEncryptedEmails + + // act + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.ReEncrypt(ctx, 1, salt, key) + require.NoError(t, err, "must re-encrypt emails") + + // assert + csvData := csv.NewReader(w) + reEncryptedData, err := csvData.ReadAll() + require.NoError(t, err, "must read csv data") + require.Len(t, reEncryptedData, len(expected), "must contain all emails") + + notShuffled := 0 + for i, reEncrypted := range reEncryptedData { + require.Len(t, reEncrypted, 1, "must contain one csv column") + + // check how many emails stay at the same place + if reEncrypted[0] == expected[i] { + notShuffled++ + } + + // must find the encrypted email in the expected slice + found := false + for _, e := range expected { + if e == reEncrypted[0] { + found = true + break + } + } + require.True(t, found, "re-encrypted email must match") + } + + require.Less(t, float64(notShuffled), float64(lenEmails)*0.01, "must shuffle more than 99% of emails") +} + +func TestPAIRIDReadWriter_HashDecrypt(t *testing.T) { + t.Parallel() + // arrange + lenEmails := 1001 + ctx := context.Background() + salt := requireGenSalt(t) + key := requireGenKey(t) + emails := requireGenRandomHashedEmails(t, lenEmails) + encryptedEmails := requireEncryptEmails(t, emails, salt, key) + twiceEncryptedEmails := requireReEncryptEmails(t, encryptedEmails, salt, key) + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set twice encrypted emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, twiceEncryptedEmails) + + // in this test we check twice encrypted emails are decrypted correctly, i.e. + // decrypt(encrypt(encrypt(data))) = encrypt(data) + expected := encryptedEmails + + // act + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.Decrypt(ctx, 1, salt, key) + require.NoError(t, err, "must decrypt emails") + + // assert + csvData := csv.NewReader(w) + decryptedData, err := csvData.ReadAll() + require.NoError(t, err, "must read csv data") + require.Len(t, decryptedData, len(expected), "must contain all emails") + + for i, decrypted := range decryptedData { + require.Len(t, decrypted, 1, "must contain one csv column") + require.Equal(t, expected[i], decrypted[0], "encrypted email must match") + } +} + +func TestPAIRIDReadWriter_InputBelowThreshold(t *testing.T) { + t.Parallel() + // arrange + lenEmails := 999 + ctx := context.Background() + salt := requireGenSalt(t) + key := requireGenKey(t) + emails := requireGenRandomHashedEmails(t, lenEmails) + encryptedEmails := requireEncryptEmails(t, emails, salt, key) + twiceEncryptedEmails := requireReEncryptEmails(t, encryptedEmails, salt, key) + + t.Run("HashEncrypt", func(t *testing.T) { + t.Parallel() + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, emails) + + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.HashEncrypt(ctx, 1, salt, key) + require.Error(t, err, "must return error when input is below threshold") + require.Equal(t, ErrInputBelowThreshold, err) + }) + + t.Run("ReEncrypt", func(t *testing.T) { + t.Parallel() + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set encrypted emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, encryptedEmails) + + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.ReEncrypt(ctx, 1, salt, key) + require.Error(t, err, "must return error when input is below threshold") + require.Equal(t, ErrInputBelowThreshold, err) + }) + + t.Run("Decrypt", func(t *testing.T) { + t.Parallel() + r, w := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + + // set twice encrypted emails in csv format for PAIRIDReadWriter to read + requireWriteEmails(t, r, twiceEncryptedEmails) + + rw, err := NewPAIRIDReadWriter(r, w) + require.NoError(t, err, "must create PAIRIDReadWriter") + + err = rw.Decrypt(ctx, 1, salt, key) + require.Error(t, err, "must return error when input is below threshold") + require.Equal(t, ErrInputBelowThreshold, err) + }) +} + +func requireGenRandomHashedEmails(t *testing.T, emailsCount int) []string { + t.Helper() + shaEncoder := sha256.New() + hems := make([]string, emailsCount) + for i := range hems { + shaEncoder.Write([]byte(fmt.Sprintf("%d@gmail.com", i))) + hem := shaEncoder.Sum(nil) + hems[i] = fmt.Sprintf("%x", hem) + } + return hems +} + +func requireWriteEmails(t *testing.T, w io.Writer, emails []string) { + csvWriter := csv.NewWriter(w) + for _, email := range emails { + err := csvWriter.Write([]string{email}) + require.NoError(t, err) + } + csvWriter.Flush() +} + +func requireEncryptEmails(t *testing.T, emails []string, salt, key string) []string { + t.Helper() + pk, err := keys.NewPAIRPrivateKey(salt, key) + require.NoError(t, err) + + encryptedEmails := make([]string, len(emails)) + for i, email := range emails { + encrypted, err := pk.Encrypt([]byte(email)) + require.NoError(t, err) + encryptedEmails[i] = string(encrypted) + } + return encryptedEmails +} + +func requireReEncryptEmails(t *testing.T, emails []string, salt, key string) []string { + t.Helper() + pk, err := keys.NewPAIRPrivateKey(salt, key) + require.NoError(t, err) + + encryptedEmails := make([]string, len(emails)) + for i, email := range emails { + encrypted, err := pk.ReEncrypt([]byte(email)) + require.NoError(t, err) + encryptedEmails[i] = string(encrypted) + } + return encryptedEmails +} + +func requireGenSalt(t *testing.T) string { + t.Helper() + salt := make([]byte, SHA256SaltSize) + _, err := rand.Read(salt) + require.NoError(t, err) + return base64.StdEncoding.EncodeToString(salt) +} + +func requireGenKey(t *testing.T) string { + t.Helper() + key, err := keys.NewPrivateKey(pair.PAIRSHA256Ristretto255) + require.NoError(t, err) + return key +}