diff --git a/basculehash/bcrypt.go b/basculehash/bcrypt.go index 1e9b790..0e247ec 100644 --- a/basculehash/bcrypt.go +++ b/basculehash/bcrypt.go @@ -4,8 +4,6 @@ package basculehash import ( - "io" - "golang.org/x/crypto/bcrypt" ) @@ -20,19 +18,16 @@ type Bcrypt struct { Cost int } +var _ Hasher = Bcrypt{} +var _ Comparer = Bcrypt{} + // Hash executes the bcrypt algorithm and write the output to dst. -func (b Bcrypt) Hash(dst io.Writer, plaintext []byte) (n int, err error) { +func (b Bcrypt) Hash(plaintext []byte) (Digest, error) { hashed, err := bcrypt.GenerateFromPassword(plaintext, b.Cost) - if err == nil { - n, err = dst.Write(hashed) - } - - return + return Digest(hashed), err } // Matches attempts to match a plaintext against its bcrypt hashed value. -func (b Bcrypt) Matches(plaintext, hash []byte) (ok bool, err error) { - err = bcrypt.CompareHashAndPassword(hash, plaintext) - ok = (err == nil) - return +func (b Bcrypt) Matches(plaintext []byte, hash Digest) error { + return bcrypt.CompareHashAndPassword(hash, plaintext) } diff --git a/basculehash/bcrypt_test.go b/basculehash/bcrypt_test.go index 3ff1939..c8d11f5 100644 --- a/basculehash/bcrypt_test.go +++ b/basculehash/bcrypt_test.go @@ -4,68 +4,34 @@ package basculehash import ( - "bytes" "fmt" - "strings" "testing" "github.com/stretchr/testify/suite" "golang.org/x/crypto/bcrypt" ) -const bcryptPlaintext string = "bcrypt plaintext" - type BcryptTestSuite struct { - suite.Suite -} - -// goodHash returns a hash that is expected to be successful. -// The plaintext() is hashed with the given cost. -func (suite *BcryptTestSuite) goodHash(cost int) []byte { - var ( - b bytes.Buffer - hasher = Bcrypt{Cost: cost} - _, err = hasher.Hash(&b, []byte(bcryptPlaintext)) - ) - - suite.Require().NoError(err) - return b.Bytes() + TestSuite } func (suite *BcryptTestSuite) TestHash() { suite.Run("DefaultCost", func() { - var ( - o strings.Builder - hasher = Bcrypt{} - - n, err = hasher.Hash(&o, []byte(bcryptPlaintext)) + suite.goodHash( + Bcrypt{}.Hash(suite.plaintext), ) - - suite.NoError(err) - suite.Equal(o.Len(), n) }) suite.Run("CustomCost", func() { - var ( - o strings.Builder - hasher = Bcrypt{Cost: 12} - - n, err = hasher.Hash(&o, []byte(bcryptPlaintext)) + suite.goodHash( + Bcrypt{Cost: 12}.Hash(suite.plaintext), ) - - suite.NoError(err) - suite.Equal(o.Len(), n) }) suite.Run("CostTooHigh", func() { - var ( - o strings.Builder - hasher = Bcrypt{Cost: bcrypt.MaxCost + 100} - - _, err = hasher.Hash(&o, []byte(bcryptPlaintext)) + suite.badHash( + Bcrypt{Cost: bcrypt.MaxCost + 100}.Hash(suite.plaintext), ) - - suite.Error(err) }) } @@ -74,13 +40,15 @@ func (suite *BcryptTestSuite) TestMatches() { for _, cost := range []int{0 /* default */, 4, 8} { suite.Run(fmt.Sprintf("cost=%d", cost), func() { var ( - hashed = suite.goodHash(cost) - hasher = Bcrypt{Cost: cost} - ok, err = hasher.Matches([]byte(bcryptPlaintext), hashed) + hasher = Bcrypt{Cost: cost} + hashed = suite.goodHash( + hasher.Hash(suite.plaintext), + ) ) - suite.True(ok) - suite.NoError(err) + suite.NoError( + hasher.Matches(suite.plaintext, hashed), + ) }) } }) @@ -89,13 +57,15 @@ func (suite *BcryptTestSuite) TestMatches() { for _, cost := range []int{0 /* default */, 4, 8} { suite.Run(fmt.Sprintf("cost=%d", cost), func() { var ( - hashed = suite.goodHash(cost) - hasher = Bcrypt{Cost: cost} - ok, err = hasher.Matches([]byte("a different plaintext"), hashed) + hasher = Bcrypt{Cost: cost} + hashed = suite.goodHash( + hasher.Hash(suite.plaintext), + ) ) - suite.False(ok) - suite.Error(err) + suite.Error( + hasher.Matches([]byte("a different plaintext"), hashed), + ) }) } }) diff --git a/basculehash/comparer.go b/basculehash/comparer.go deleted file mode 100644 index 7e27e55..0000000 --- a/basculehash/comparer.go +++ /dev/null @@ -1,19 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package basculehash - -// Comparer is a strategy for comparing plaintext values with a -// hash value from a Hasher. -type Comparer interface { - // Matches tests if the given plaintext matches the given hash. - // For example, this method can test if a password matches the - // one-way hashed password from a config file or database. - // - // If this method returns true, the error will always be nil. - // If this method returns false, the error may be non-nil to - // indicate that the match failed due to a problem, such as - // the hash not being parseable. Client code that is just - // interested in a yes/no answer can disregard the error return. - Matches(plaintext, hash []byte) (bool, error) -} diff --git a/basculehash/credentials.go b/basculehash/credentials.go new file mode 100644 index 0000000..4a1dc6a --- /dev/null +++ b/basculehash/credentials.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import "context" + +// Credentials is a source of principals and their associated digests. A +// credentials instance may be in-memory or a remote system. +type Credentials interface { + // Get returns the Digest associated with the given Principal. + // This method returns false if the principal did not exist. + Get(ctx context.Context, principal string) (d Digest, exists bool) + + // Set associates a principal with a Digest. If the principal already + // exists, its digest is replaced. + Set(ctx context.Context, principal string, d Digest) + + // Delete removes one or more principals from this set. + Delete(ctx context.Context, principals ...string) + + // Update performs a bulk update of these credentials. Any existing + // principals are replaced. + Update(ctx context.Context, p Principals) +} diff --git a/basculehash/credentials_test.go b/basculehash/credentials_test.go new file mode 100644 index 0000000..f8527ad --- /dev/null +++ b/basculehash/credentials_test.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "context" + + "golang.org/x/crypto/bcrypt" +) + +// CredentialsTestSuite runs a standard battery of tests against +// a Credentials implementation. +// +// Tests of UnmarshalJSON need to be done in tests of concrete types +// due to the way unmarshalling works in golang. +type CredentialsTestSuite[C Credentials] struct { + TestSuite + + // Implementations should supply SetupTest and SetupSubTest + // methods that populate this member. Don't forget to call + // TestSuite.SetupTest and TestSuite.SetupSubTest! + credentials C + + testCtx context.Context + hasher Hasher +} + +// SetupSuite initializes a hasher and comparer to use when verifying +// and creating digests. +func (suite *CredentialsTestSuite[C]) SetupSuite() { + suite.testCtx = context.Background() + suite.hasher = Bcrypt{Cost: bcrypt.MinCost} +} + +// exists asserts that a given principal exists with the given Digest. +func (suite *CredentialsTestSuite[C]) exists(principal string, expected Digest) { + d, ok := suite.credentials.Get(suite.testCtx, principal) + suite.Require().True(ok) + suite.Require().Equal(expected, d) +} + +// notExists asserts that the given principal did not exist. +func (suite *CredentialsTestSuite[C]) notExists(principal string) { + d, ok := suite.credentials.Get(suite.testCtx, principal) + suite.Require().False(ok) + suite.Require().Empty(d) +} + +// defaultHash creates a distinct hash of the suite plaintext for testing. +func (suite *CredentialsTestSuite[C]) defaultHash() Digest { + return suite.goodHash( + suite.hasher.Hash( + suite.plaintext, + ), + ) +} + +func (suite *CredentialsTestSuite[C]) TestGetSetDelete() { + suite.T().Log("delete from empty") + suite.credentials.Delete(suite.testCtx, "joe") + + suite.T().Log("add") + joeDigest := suite.defaultHash() + suite.credentials.Set(suite.testCtx, "joe", joeDigest) + suite.exists("joe", joeDigest) + + suite.T().Log("add another") + fredDigest := suite.defaultHash() + suite.credentials.Set(suite.testCtx, "fred", fredDigest) + suite.exists("joe", joeDigest) + suite.exists("fred", fredDigest) + + suite.T().Log("replace") + newJoeDigest := suite.defaultHash() + suite.Require().NotEqual(newJoeDigest, joeDigest) // hashes should always generate salt to make them distinct + suite.credentials.Set(suite.testCtx, "joe", newJoeDigest) + suite.exists("joe", newJoeDigest) + suite.exists("fred", fredDigest) + + suite.T().Log("delete a principal") + suite.credentials.Delete(suite.testCtx, "fred") + suite.notExists("fred") + suite.exists("joe", newJoeDigest) +} + +func (suite *CredentialsTestSuite[C]) TestUpdate() { + suite.credentials.Update(suite.testCtx, nil) + + joeDigest := suite.defaultHash() + fredDigest := suite.defaultHash() + suite.credentials.Update(suite.testCtx, Principals{ + "joe": joeDigest, + "fred": fredDigest, + }) + + suite.exists("joe", joeDigest) + suite.exists("fred", fredDigest) + + joeDigest = suite.defaultHash() + moeDigest := suite.defaultHash() + suite.credentials.Update(suite.testCtx, Principals{ + "joe": joeDigest, + "moe": moeDigest, + }) + + suite.exists("joe", joeDigest) + suite.exists("fred", fredDigest) + suite.exists("moe", moeDigest) +} diff --git a/basculehash/digest.go b/basculehash/digest.go new file mode 100644 index 0000000..34160c8 --- /dev/null +++ b/basculehash/digest.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import "io" + +// Digest is the result of applying a Hasher to plaintext. +// A digest must be valid UTF-8, preferably using the format +// described by https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md. +type Digest []byte + +// Copy returns a distinct copy of this digest. +func (d Digest) Copy() Digest { + clone := make(Digest, len(d)) + copy(clone, d) + return clone +} + +// String returns this Digest as is, but cast as a string. +func (d Digest) String() string { + return string(d) +} + +// MarshalText simply returns this Digest as a byte slice. This method ensures +// that the digest is written as is instead of encoded as base64 or some other +// encoding. +func (d Digest) MarshalText() ([]byte, error) { + return []byte(d), nil +} + +// UnmarshalText uses the given text as is. +func (d *Digest) UnmarshalText(text []byte) error { + *d = text + return nil +} + +// WriteTo writes this digest to the given writer. +func (d Digest) WriteTo(dst io.Writer) (int64, error) { + c, err := dst.Write(d) + return int64(c), err +} diff --git a/basculehash/digest_test.go b/basculehash/digest_test.go new file mode 100644 index 0000000..cac42ad --- /dev/null +++ b/basculehash/digest_test.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/suite" +) + +type DigestTestSuite struct { + TestSuite + + digest Digest +} + +func (suite *DigestTestSuite) SetupTest() { + suite.TestSuite.SetupTest() + suite.digest = suite.goodHash( + Default().Hash(suite.plaintext), + ) +} + +func (suite *DigestTestSuite) TestCopy() { + clone := suite.digest.Copy() + suite.Equal(suite.digest, clone) + suite.NotSame(suite.digest, clone) +} + +func (suite *DigestTestSuite) TestString() { + suite.Equal( + suite.digest, + Digest(suite.digest.String()), + ) +} + +func (suite *DigestTestSuite) TestMarshalText() { + text, err := suite.digest.MarshalText() + suite.Require().NoError(err) + + var clone Digest + err = clone.UnmarshalText(text) + suite.Require().NoError(err) + suite.Equal(suite.digest, clone) +} + +func (suite *DigestTestSuite) TestWriteTo() { + var o bytes.Buffer + n, err := suite.digest.WriteTo(&o) + suite.Equal(int64(len(suite.digest)), n) + suite.Require().NoError(err) + + suite.Equal( + suite.digest, + Digest(o.Bytes()), + ) +} + +func TestDigest(t *testing.T) { + suite.Run(t, new(DigestTestSuite)) +} diff --git a/basculehash/hasher.go b/basculehash/hasher.go index baeebe7..cc1f3e8 100644 --- a/basculehash/hasher.go +++ b/basculehash/hasher.go @@ -3,16 +3,45 @@ package basculehash -import ( - "io" -) - // Hasher is a strategy for one-way hashing. +// +// Comparer is the interface for comparing hash digests with plaintext. +// A given Comparer will correspond to the format written by a Hasher. type Hasher interface { - // Hash writes the hash of a plaintext to a writer. The number of - // bytes written along with any error is returned. + // Hash returns a digest of the given plaintext. The returned Digest + // must be recognizable to a Comparer in order to be validated. + // + // If this method returns a nil error, it MUST return a valid Digest. + // If this method returns an error, the Digest is not guaranteed to have + // any particular value and should be discarded. + // + // The format of the digest must be ASCII. The recommended format is + // the PHC format documented at: // - // The format of the written hash must be ASCII. The recommended - // format is the modular crypt format, which bcrypt uses. - Hash(dst io.Writer, plaintext []byte) (int, error) + // https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md + Hash(plaintext []byte) (Digest, error) +} + +// Comparer is a strategy for comparing plaintext values with a +// hash digest from a Hasher. +type Comparer interface { + // Matches tests if the given plaintext matches the given hash. + // For example, this method can test if a password matches the + // one-way hashed password from a config file or database. + Matches(plaintext []byte, d Digest) error +} + +// HasherComparer provides both hashing and corresponding comparison. +// This is the typical interface that a hashing algorithm will implement. +type HasherComparer interface { + Hasher + Comparer +} + +var defaultHash HasherComparer = Bcrypt{} + +// Default returns the default algorithm to use for comparing +// hashed passwords. +func Default() HasherComparer { + return defaultHash } diff --git a/basculehash/principals.go b/basculehash/principals.go new file mode 100644 index 0000000..5bbb061 --- /dev/null +++ b/basculehash/principals.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "context" +) + +// Principals is a Credentials implementation that is a simple map +// of principals to digests. This type is not safe for concurrent +// usage. +// +// This type is appropriate if the set of credentials is either immutable +// or protected from concurrent updates by some other means. +type Principals map[string]Digest + +var _ Credentials = Principals{} + +// Get returns the Digest associated with the principal. This method +// returns false if the principal did not exist. +func (p Principals) Get(_ context.Context, principal string) (d Digest, exists bool) { + d, exists = p[principal] + return +} + +// Set adds or replaces the given principal and its associated digest. +func (p Principals) Set(_ context.Context, principal string, d Digest) { + p[principal] = d.Copy() +} + +// Delete removes the given principal(s) from this set. +func (p Principals) Delete(_ context.Context, principals ...string) { + for _, toDelete := range principals { + delete(p, toDelete) + } +} + +// Update performs a bulk update of credentials. Each digest is copied +// before storing in this instance. +func (p Principals) Update(_ context.Context, more Principals) { + for principal, digest := range more { + p[principal] = digest.Copy() + } +} diff --git a/basculehash/principals_test.go b/basculehash/principals_test.go new file mode 100644 index 0000000..fdf154a --- /dev/null +++ b/basculehash/principals_test.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type PrincipalsTestSuite struct { + CredentialsTestSuite[Principals] +} + +func (suite *PrincipalsTestSuite) SetupSubTest() { + suite.SetupTest() +} + +func (suite *PrincipalsTestSuite) SetupTest() { + suite.CredentialsTestSuite.SetupTest() + suite.credentials = Principals{} +} + +func TestPrincipals(t *testing.T) { + suite.Run(t, new(PrincipalsTestSuite)) +} diff --git a/basculehash/store.go b/basculehash/store.go new file mode 100644 index 0000000..d8a9a3f --- /dev/null +++ b/basculehash/store.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "context" + "encoding/json" + "sync" +) + +// Store is an in-memory, threadsafe Credentials implementation. +// A Store instance is safe for concurrent reads and writes. +// Instances of this type must not be copied after creation. +// +// The zero value of this type is valid and ready to use. +type Store struct { + lock sync.RWMutex + principals Principals +} + +var _ Credentials = (*Store)(nil) + +// Get returns the Digest associated with the principal. +func (s *Store) Get(ctx context.Context, principal string) (d Digest, exists bool) { + s.lock.RLock() + d, exists = s.principals.Get(ctx, principal) + s.lock.RUnlock() + return +} + +// Set adds or updates a principal's password. +func (s *Store) Set(ctx context.Context, principal string, d Digest) { + clone := d.Copy() + s.lock.Lock() + + if s.principals == nil { + s.principals = make(Principals) + } + + s.principals.Set(ctx, principal, clone) + + s.lock.Unlock() +} + +// Delete removes the principal(s) from this Store. +func (s *Store) Delete(_ context.Context, principals ...string) { + s.lock.Lock() + + for _, toDelete := range principals { + delete(s.principals, toDelete) + } + + s.lock.Unlock() +} + +// Update performs a bulk update to this Store. +func (s *Store) Update(_ context.Context, more Principals) { + names := make([]string, 0, len(more)) + digests := make([]Digest, 0, len(more)) + for principal, digest := range more { + names = append(names, principal) + digests = append(digests, digest.Copy()) + } + + s.lock.Lock() + + if s.principals == nil { + s.principals = make(Principals) + } + + for i := 0; i < len(names); i++ { + s.principals[names[i]] = digests[i] // a copy was already made + } + + s.lock.Unlock() +} + +// MarshalJSON writes the current state of this Store to JSON. +func (s *Store) MarshalJSON() (data []byte, err error) { + s.lock.RLock() + data, err = json.Marshal(s.principals) + s.lock.RUnlock() + return +} + +// UnmarshalJSON unmarshals data and replaces the current set of principals. +// If unmarshalling returned an error, this Store's state remains unchanged. +func (s *Store) UnmarshalJSON(data []byte) (err error) { + s.lock.Lock() + + var unmarshaled Principals + if err = json.Unmarshal(data, &unmarshaled); err == nil { + s.principals = unmarshaled + } + + s.lock.Unlock() + + return +} diff --git a/basculehash/store_test.go b/basculehash/store_test.go new file mode 100644 index 0000000..ef440bf --- /dev/null +++ b/basculehash/store_test.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/suite" +) + +type StoreTestSuite struct { + CredentialsTestSuite[*Store] +} + +func (suite *StoreTestSuite) SetupSubTest() { + suite.SetupTest() +} + +func (suite *StoreTestSuite) SetupTest() { + suite.CredentialsTestSuite.SetupTest() + suite.credentials = new(Store) +} + +func (suite *StoreTestSuite) TestMarshalJSON() { + var ( + joeDigest = suite.defaultHash() + fredDigest = suite.defaultHash() + + expectedJSON = fmt.Sprintf( + `{ + "joe": "%s", + "fred": "%s" + }`, + joeDigest, + fredDigest, + ) + ) + + suite.credentials.Set(suite.testCtx, "joe", joeDigest) + suite.credentials.Set(suite.testCtx, "fred", fredDigest) + actualJSON, err := json.Marshal(suite.credentials) + + suite.Require().NoError(err) + suite.JSONEq(expectedJSON, string(actualJSON)) +} + +func (suite *StoreTestSuite) TestUnmarshalJSON() { + var ( + joeDigest = suite.defaultHash() + fredDigest = suite.defaultHash() + + jsonValue = fmt.Sprintf( + `{ + "joe": "%s", + "fred": "%s" + }`, + joeDigest, + fredDigest, + ) + ) + + err := json.Unmarshal([]byte(jsonValue), suite.credentials) + suite.Require().NoError(err) + suite.exists("joe", joeDigest) + suite.exists("fred", fredDigest) +} + +func TestStore(t *testing.T) { + suite.Run(t, new(StoreTestSuite)) +} diff --git a/basculehash/testSuite_test.go b/basculehash/testSuite_test.go new file mode 100644 index 0000000..475d720 --- /dev/null +++ b/basculehash/testSuite_test.go @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "github.com/stretchr/testify/suite" +) + +// TestSuite has common infrastructure for hashing test suites. +type TestSuite struct { + suite.Suite + + plaintext []byte +} + +func (suite *TestSuite) SetupSubTest() { + suite.SetupTest() +} + +func (suite *TestSuite) SetupTest() { + suite.plaintext = []byte("here is some plaintext") +} + +// goodHash asserts that a hasher did create a digest successfully, +// and returns that Digest. +func (suite *TestSuite) goodHash(d Digest, err error) Digest { + suite.Require().NoError(err) + suite.Require().NotEmpty(d) + return d +} + +// badHash asserts that the hash fails. The digest and error are returned +// for any future asserts. +func (suite *TestSuite) badHash(d Digest, err error) (Digest, error) { + suite.Require().Error(err) + return d, err // hashers are not required to return empty digests on error +} diff --git a/basculehash/validator.go b/basculehash/validator.go new file mode 100644 index 0000000..4c70d8b --- /dev/null +++ b/basculehash/validator.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "context" + "errors" + + "github.com/xmidt-org/bascule" +) + +type matcherValidator[S any] struct { + cmp Comparer + creds Credentials +} + +func (mv *matcherValidator[S]) Validate(ctx context.Context, _ S, t bascule.Token) (next bascule.Token, err error) { + next = t + password, ok := bascule.GetPassword(t) + if !ok { + return + } + + if digest, exists := mv.creds.Get(ctx, t.Principal()); exists { + err = mv.cmp.Matches([]byte(password), digest) + if err != nil { + err = errors.Join(bascule.ErrBadCredentials, err) + } + } else { + err = bascule.ErrBadCredentials + } + + return +} + +// NewValidator returns a bascule.Validator that always uses the same hash +// Comparer. The source S is unused, but conforms to the Validator interface. +func NewValidator[S any](cmp Comparer, creds Credentials) bascule.Validator[S] { + if cmp == nil { + cmp = Default() + } + + return &matcherValidator[S]{ + cmp: cmp, + creds: creds, + } +} diff --git a/basculehash/validator_test.go b/basculehash/validator_test.go new file mode 100644 index 0000000..31cf802 --- /dev/null +++ b/basculehash/validator_test.go @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehash + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/xmidt-org/bascule" + "golang.org/x/crypto/bcrypt" +) + +type validatorTestToken struct { + principal, password string +} + +func (t validatorTestToken) Principal() string { return t.principal } + +func (t validatorTestToken) Password() string { return t.password } + +type ValidatorTestSuite struct { + TestSuite + + testCtx context.Context + request *http.Request +} + +func (suite *ValidatorTestSuite) SetupSubTest() { + suite.SetupTest() +} + +func (suite *ValidatorTestSuite) SetupTest() { + suite.TestSuite.SetupTest() + suite.testCtx = context.Background() + suite.request = httptest.NewRequest("GET", "/", nil) +} + +// newDefaultToken creates a password token using this suite's default plaintext. +func (suite *ValidatorTestSuite) newDefaultToken(principal string) bascule.Token { + return validatorTestToken{ + principal: principal, + password: string(suite.plaintext), + } +} + +// newCredentials builds a standard set of credentials using the given hasher. +func (suite *ValidatorTestSuite) newCredentials(h Hasher) Credentials { + return Principals{ + "joe": suite.goodHash(h.Hash(suite.plaintext)), + "fred": suite.goodHash(h.Hash(suite.plaintext)), + } +} + +func (suite *ValidatorTestSuite) newValidator(cmp Comparer, creds Credentials) bascule.Validator[*http.Request] { + v := NewValidator[*http.Request](cmp, creds) + suite.Require().NotNil(v) + return v +} + +func (suite *ValidatorTestSuite) testValidate(cmp Comparer, h Hasher) { + v := suite.newValidator(cmp, suite.newCredentials(h)) + + suite.Run("NonPasswordToken", func() { + t := bascule.StubToken("joe") + next, err := v.Validate(suite.testCtx, suite.request, t) + suite.Equal(t, next) + suite.NoError(err) + }) + + suite.Run("NoSuchPrincipal", func() { + t := suite.newDefaultToken("nosuch") + next, err := v.Validate(suite.testCtx, suite.request, t) + suite.Equal(t, next) + suite.ErrorIs(err, bascule.ErrBadCredentials) + }) + + suite.Run("BadPassword", func() { + t := validatorTestToken{principal: "joe", password: "bad"} + next, err := v.Validate(suite.testCtx, suite.request, t) + suite.Equal(t, next) + suite.ErrorIs(err, bascule.ErrBadCredentials) + }) + + suite.Run("Success", func() { + t := suite.newDefaultToken("joe") + next, err := v.Validate(suite.testCtx, suite.request, t) + suite.Equal(t, next) + suite.NoError(err) + }) +} + +func (suite *ValidatorTestSuite) TestValidate() { + suite.Run("DefaultComparer", func() { + suite.testValidate(nil, Default()) + }) + + suite.Run("CustomComparer", func() { + hc := Bcrypt{Cost: bcrypt.MinCost} + suite.testValidate(hc, hc) + }) +} + +func TestValidator(t *testing.T) { + suite.Run(t, new(ValidatorTestSuite)) +} diff --git a/cmd/hash/cli.go b/cmd/hash/cli.go index 62be675..694428b 100644 --- a/cmd/hash/cli.go +++ b/cmd/hash/cli.go @@ -45,7 +45,11 @@ func (cmd *Bcrypt) Run(kong *kong.Kong) error { Cost: cmd.Cost, } - _, err := hasher.Hash(kong.Stdout, []byte(cmd.Plaintext)) + digest, err := hasher.Hash([]byte(cmd.Plaintext)) + if err == nil { + digest.WriteTo(kong.Stdout) + } + return err }