diff --git a/signature/signer_factory.go b/signature/signer_factory.go index 7b9e4bf..0058eda 100644 --- a/signature/signer_factory.go +++ b/signature/signer_factory.go @@ -39,8 +39,9 @@ func NewSigner(handle *keyset.Handle) (tink.Signer, error) { // wrappedSigner is an Signer implementation that uses the underlying primitive set for signing. type wrappedSigner struct { - ps *primitiveset.PrimitiveSet - logger monitoring.Logger + signer tink.Signer + signerKeyID uint32 + logger monitoring.Logger } // Asserts that wrappedSigner implements the Signer interface. @@ -90,9 +91,11 @@ func extractFullSigner(entry *primitiveset.Entry) (tink.Signer, error) { } func newWrappedSigner(ps *primitiveset.PrimitiveSet) (*wrappedSigner, error) { - if _, err := extractFullSigner(ps.Primary); err != nil { + signer, err := extractFullSigner(ps.Primary) + if err != nil { return nil, err } + // Validate that all entries are tink.Signer. for _, entries := range ps.Entries { for _, entry := range entries { if _, err := extractFullSigner(entry); err != nil { @@ -105,8 +108,9 @@ func newWrappedSigner(ps *primitiveset.PrimitiveSet) (*wrappedSigner, error) { return nil, err } return &wrappedSigner{ - ps: ps, - logger: logger, + signer: signer, + signerKeyID: ps.Primary.KeyID, + logger: logger, }, nil } @@ -128,15 +132,11 @@ func createSignerLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, error // Sign signs the given data using the primary key. func (s *wrappedSigner) Sign(data []byte) ([]byte, error) { - signer, err := extractFullSigner(s.ps.Primary) - if err != nil { - return nil, err - } - signature, err := signer.Sign(data) + signature, err := s.signer.Sign(data) if err != nil { s.logger.LogFailure() return nil, err } - s.logger.Log(s.ps.Primary.KeyID, len(data)) + s.logger.Log(s.signerKeyID, len(data)) return signature, nil }