From f9ed2c3679001f13c102a256af2e1099a80ddb5d Mon Sep 17 00:00:00 2001 From: "Juan A. Garcia Pardo" Date: Thu, 17 Aug 2023 10:07:14 +0200 Subject: [PATCH] The mapserver service correctly prunes now. --- cmd/ingest/main.go | 2 +- cmd/mapserver/mapserver.go | 59 +++++++++++++++++------ pkg/mapserver/responder/responder_test.go | 2 +- pkg/mapserver/updater/updater.go | 56 ++++++++++++--------- 4 files changed, 77 insertions(+), 42 deletions(-) diff --git a/cmd/ingest/main.go b/cmd/ingest/main.go index d8edc9d6..2043fc58 100644 --- a/cmd/ingest/main.go +++ b/cmd/ingest/main.go @@ -135,7 +135,7 @@ func mainFunction() int { // Now update the SMT Trie with the changed domains: fmt.Println("Starting SMT update ...") - err = updater.UpdateSMT(ctx, conn, 32) + err = updater.UpdateSMT(ctx, conn) exitIfError(err) fmt.Println("Done SMT update.") diff --git a/cmd/mapserver/mapserver.go b/cmd/mapserver/mapserver.go index 39d9e518..469fd3ca 100644 --- a/cmd/mapserver/mapserver.go +++ b/cmd/mapserver/mapserver.go @@ -123,36 +123,63 @@ func (s *MapServer) pruneAndUpdate(ctx context.Context) { s.update(ctx) } +// prune only removes the affected certificates from the certs table and adds the affected domains +// to the dirty table. Because update is always called right after prune, we don't need to first +// compute the coalesced domains for those dirty domains after prune and before update. It is +// sufficient to call CoalescePayloadsForDirtyDomains after update and it will take care of all +// dirty domains, coming from both prune and update. func (s *MapServer) prune(ctx context.Context) { - getTime := func() string { - return time.Now().UTC().Format(time.RFC3339) - } fmt.Printf("======== prune started at %s\n", getTime()) defer fmt.Printf("======== prune finished at %s\n\n", getTime()) - // deleteme TODO - n, err := s.Updater.Conn.PruneCerts(ctx, time.Now()) + err := s.Updater.Conn.PruneCerts(ctx, time.Now()) if err != nil { s.updateErrChan <- fmt.Errorf("pruning: %w", err) } - fmt.Printf("pruning: %d certs removed\n", n) s.updateErrChan <- error(nil) // Always answer something. } func (s *MapServer) update(ctx context.Context) { - getTime := func() string { - return time.Now().UTC().Format(time.RFC3339) - } fmt.Printf("======== update started at %s\n", getTime()) defer fmt.Printf("======== update finished at %s\n\n", getTime()) - if err := s.Updater.StartFetchingRemaining(); err != nil { + if err := s.updateCerts(ctx); err != nil { + s.updateErrChan <- err + return + } + // TODO(juagargi) do policy certificates here. + + fmt.Printf("coalescing certificate payloads at %s\n", getTime()) + if err := s.Updater.CoalescePayloadsForDirtyDomains(ctx); err != nil { + s.updateErrChan <- err + return + } + + // Update SMT. + fmt.Printf("updating SMT at %s\n", getTime()) + if err := s.Updater.UpdateSMT(ctx); err != nil { + s.updateErrChan <- fmt.Errorf("updating SMT: %w", err) + return + } - s.updateErrChan <- fmt.Errorf("retrieving start and end indices: %w", err) + // Cleanup. + fmt.Printf("cleaning up at %s\n", getTime()) + if err := s.Updater.Conn.CleanupDirty(ctx); err != nil { + s.updateErrChan <- fmt.Errorf("cleaning up DB: %w", err) return } + // Always queue answer in form of an error: + s.updateErrChan <- error(nil) +} + +func (s *MapServer) updateCerts(ctx context.Context) error { + if err := s.Updater.StartFetchingRemaining(); err != nil { + return fmt.Errorf("retrieving start and end indices: %w", err) + } + defer s.Updater.StopFetching() + // Main update loop. for s.Updater.NextBatch(ctx) { n, err := s.Updater.UpdateNextBatch(ctx) @@ -160,12 +187,12 @@ func (s *MapServer) update(ctx context.Context) { fmt.Printf("updated %5d certs batch at %s\n", n, getTime()) if err != nil { // We stop the loop here, as probably requires manual inspection of the logs, etc. - fmt.Printf("error: %s\n", err) - break + return fmt.Errorf("updating next batch of x509 certificates: %w", err) } } - s.Updater.StopFetching() + return nil +} - // Queue answer in form of an error: - s.updateErrChan <- error(nil) +func getTime() string { + return time.Now().UTC().Format(time.RFC3339) } diff --git a/pkg/mapserver/responder/responder_test.go b/pkg/mapserver/responder/responder_test.go index 7fb8b7f3..0801fea9 100644 --- a/pkg/mapserver/responder/responder_test.go +++ b/pkg/mapserver/responder/responder_test.go @@ -102,7 +102,7 @@ func TestProof(t *testing.T) { require.NoError(t, err) // Create/update the SMT. - err = updater.UpdateSMT(ctx, conn, 32) + err = updater.UpdateSMT(ctx, conn) require.NoError(t, err) // And cleanup dirty, flagging the end of the update cycle. diff --git a/pkg/mapserver/updater/updater.go b/pkg/mapserver/updater/updater.go index 6205db55..065f9f83 100644 --- a/pkg/mapserver/updater/updater.go +++ b/pkg/mapserver/updater/updater.go @@ -21,9 +21,10 @@ import ( // MapUpdater: map updater. It is responsible for updating the tree, and writing to db type MapUpdater struct { Fetcher logfetcher.Fetcher - smt *trie.Trie Conn db.Conn lastUpdatedIndex int64 + // Don't use a Trie in memory. The updater just creates one when needed and disposes of it + // afterwards. } // NewMapUpdater: return a new map updater. @@ -34,13 +35,6 @@ func NewMapUpdater(config *db.Configuration, url string) (*MapUpdater, error) { return nil, fmt.Errorf("NewMapUpdater | db.Connect | %w", err) } - // SMT - smt, err := trie.NewTrie(nil, common.SHA256Hash, dbConn) - if err != nil { - return nil, fmt.Errorf("NewMapServer | NewTrie | %w", err) - } - smt.CacheHeightLimit = 32 - fetcher, err := logfetcher.NewLogFetcher(url) if err != nil { return nil, err @@ -48,7 +42,6 @@ func NewMapUpdater(config *db.Configuration, url string) (*MapUpdater, error) { return &MapUpdater{ Fetcher: fetcher, - smt: smt, Conn: dbConn, }, nil } @@ -109,7 +102,7 @@ func (u *MapUpdater) UpdateNextBatch(ctx context.Context) (int, error) { } // UpdateCertsLocally: add certs (in the form of asn.1 encoded byte arrays) directly without querying log -func (mapUpdater *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [][]byte, certChainList [][][]byte) error { +func (u *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [][]byte, certChainList [][][]byte) error { expirations := make([]*time.Time, 0, len(certList)) certs := make([]*ctx509.Certificate, 0, len(certList)) certChains := make([][]*ctx509.Certificate, 0, len(certList)) @@ -131,20 +124,28 @@ func (mapUpdater *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [ certChains = append(certChains, chain) } certs, IDs, parentIDs, names := util.UnfoldCerts(certs, certChains) - return UpdateWithKeepExisting(ctx, mapUpdater.Conn, names, IDs, parentIDs, certs, expirations, nil) + return UpdateWithKeepExisting(ctx, u.Conn, names, IDs, parentIDs, certs, expirations, nil) } // UpdatePolicyCerts: update RPC and PC from url. Currently just mock PC and RPC -func (mapUpdater *MapUpdater) UpdatePolicyCerts(ctx context.Context, ctUrl string, startIdx, endIdx int64) error { +func (u *MapUpdater) UpdatePolicyCerts(ctx context.Context, ctUrl string, startIdx, endIdx int64) error { // get PC and RPC first rpcList, err := logfetcher.GetPCAndRPCs(ctUrl, startIdx, endIdx, 20) if err != nil { return fmt.Errorf("CollectCerts | GetPCAndRPC | %w", err) } - return mapUpdater.updatePolicyCerts(ctx, rpcList) + return u.updatePolicyCerts(ctx, rpcList) } -func (mapUpdater *MapUpdater) updateCertBatch( +func (u *MapUpdater) UpdateSMT(ctx context.Context) error { + return UpdateSMT(ctx, u.Conn) +} + +func (u *MapUpdater) CoalescePayloadsForDirtyDomains(ctx context.Context) error { + return CoalescePayloadsForDirtyDomains(ctx, u.Conn) +} + +func (u *MapUpdater) updateCertBatch( ctx context.Context, leafCerts []*ctx509.Certificate, chains [][]*ctx509.Certificate, @@ -163,7 +164,7 @@ func (mapUpdater *MapUpdater) updateCertBatch( // Process whole batch. return UpdateWithKeepExisting( ctx, - mapUpdater.Conn, + u.Conn, names, certIDs, parentIDs, @@ -173,7 +174,7 @@ func (mapUpdater *MapUpdater) updateCertBatch( ) } -func (mapUpdater *MapUpdater) updatePolicyCerts( +func (u *MapUpdater) updatePolicyCerts( ctx context.Context, rpcs []*common.PolicyCertificate, ) error { @@ -316,21 +317,18 @@ func UpdateSMTfromDomains( // UpdateSMT reads all the dirty domains (pending to update their contents in the SMT), creates // a SMT Trie, loads it, and updates its entries with the new values. // It finally commits the Trie and saves its root in the DB. -func UpdateSMT(ctx context.Context, conn db.Conn, cacheHeight int) error { +func UpdateSMT(ctx context.Context, conn db.Conn) error { // Load root. - var root []byte - if rootID, err := conn.LoadRoot(ctx); err != nil { + root, err := loadRoot(ctx, conn) + if err != nil { return err - } else if rootID != nil { - root = rootID[:] } - // Load SMT. smtTrie, err := trie.NewTrie(root, common.SHA256Hash, conn) if err != nil { - err = fmt.Errorf("with root \"%s\", creating NewTrie: %w", hex.EncodeToString(root), err) - panic(err) + return fmt.Errorf("with root \"%s\", creating NewTrie: %w", hex.EncodeToString(root), err) } + // smtTrie.CacheHeightLimit = 32 // Get the dirty domains. domains, err := conn.RetrieveDirtyDomains(ctx) @@ -351,6 +349,16 @@ func UpdateSMT(ctx context.Context, conn db.Conn, cacheHeight int) error { return nil } +func loadRoot(ctx context.Context, conn db.Conn) ([]byte, error) { + var root []byte + if rootID, err := conn.LoadRoot(ctx); err != nil { + return nil, err + } else if rootID != nil { + root = rootID[:] + } + return root, nil +} + func insertCerts(ctx context.Context, conn db.Conn, names [][]string, ids, parentIDs []*common.SHA256Output, expirations []*time.Time, payloads [][]byte) error {