Skip to content

Commit

Permalink
The mapserver service correctly prunes now.
Browse files Browse the repository at this point in the history
  • Loading branch information
juagargi committed Aug 17, 2023
1 parent 4250cc8 commit f9ed2c3
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cmd/ingest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func mainFunction() int {

// Now update the SMT Trie with the changed domains:
fmt.Println("Starting SMT update ...")
err = updater.UpdateSMT(ctx, conn, 32)
err = updater.UpdateSMT(ctx, conn)
exitIfError(err)
fmt.Println("Done SMT update.")

Expand Down
59 changes: 43 additions & 16 deletions cmd/mapserver/mapserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,49 +123,76 @@ func (s *MapServer) pruneAndUpdate(ctx context.Context) {
s.update(ctx)
}

// prune only removes the affected certificates from the certs table and adds the affected domains
// to the dirty table. Because update is always called right after prune, we don't need to first
// compute the coalesced domains for those dirty domains after prune and before update. It is
// sufficient to call CoalescePayloadsForDirtyDomains after update and it will take care of all
// dirty domains, coming from both prune and update.
func (s *MapServer) prune(ctx context.Context) {
getTime := func() string {
return time.Now().UTC().Format(time.RFC3339)
}
fmt.Printf("======== prune started at %s\n", getTime())
defer fmt.Printf("======== prune finished at %s\n\n", getTime())

// deleteme TODO
n, err := s.Updater.Conn.PruneCerts(ctx, time.Now())
err := s.Updater.Conn.PruneCerts(ctx, time.Now())
if err != nil {
s.updateErrChan <- fmt.Errorf("pruning: %w", err)
}
fmt.Printf("pruning: %d certs removed\n", n)

s.updateErrChan <- error(nil) // Always answer something.
}

func (s *MapServer) update(ctx context.Context) {
getTime := func() string {
return time.Now().UTC().Format(time.RFC3339)
}
fmt.Printf("======== update started at %s\n", getTime())
defer fmt.Printf("======== update finished at %s\n\n", getTime())

if err := s.Updater.StartFetchingRemaining(); err != nil {
if err := s.updateCerts(ctx); err != nil {
s.updateErrChan <- err
return
}
// TODO(juagargi) do policy certificates here.

fmt.Printf("coalescing certificate payloads at %s\n", getTime())
if err := s.Updater.CoalescePayloadsForDirtyDomains(ctx); err != nil {
s.updateErrChan <- err
return
}

// Update SMT.
fmt.Printf("updating SMT at %s\n", getTime())
if err := s.Updater.UpdateSMT(ctx); err != nil {
s.updateErrChan <- fmt.Errorf("updating SMT: %w", err)
return
}

s.updateErrChan <- fmt.Errorf("retrieving start and end indices: %w", err)
// Cleanup.
fmt.Printf("cleaning up at %s\n", getTime())
if err := s.Updater.Conn.CleanupDirty(ctx); err != nil {
s.updateErrChan <- fmt.Errorf("cleaning up DB: %w", err)
return
}

// Always queue answer in form of an error:
s.updateErrChan <- error(nil)
}

func (s *MapServer) updateCerts(ctx context.Context) error {
if err := s.Updater.StartFetchingRemaining(); err != nil {
return fmt.Errorf("retrieving start and end indices: %w", err)
}
defer s.Updater.StopFetching()

// Main update loop.
for s.Updater.NextBatch(ctx) {
n, err := s.Updater.UpdateNextBatch(ctx)

fmt.Printf("updated %5d certs batch at %s\n", n, getTime())
if err != nil {
// We stop the loop here, as probably requires manual inspection of the logs, etc.
fmt.Printf("error: %s\n", err)
break
return fmt.Errorf("updating next batch of x509 certificates: %w", err)
}
}
s.Updater.StopFetching()
return nil
}

// Queue answer in form of an error:
s.updateErrChan <- error(nil)
func getTime() string {
return time.Now().UTC().Format(time.RFC3339)
}
2 changes: 1 addition & 1 deletion pkg/mapserver/responder/responder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestProof(t *testing.T) {
require.NoError(t, err)

// Create/update the SMT.
err = updater.UpdateSMT(ctx, conn, 32)
err = updater.UpdateSMT(ctx, conn)
require.NoError(t, err)

// And cleanup dirty, flagging the end of the update cycle.
Expand Down
56 changes: 32 additions & 24 deletions pkg/mapserver/updater/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import (
// MapUpdater: map updater. It is responsible for updating the tree, and writing to db
type MapUpdater struct {
Fetcher logfetcher.Fetcher
smt *trie.Trie
Conn db.Conn
lastUpdatedIndex int64
// Don't use a Trie in memory. The updater just creates one when needed and disposes of it
// afterwards.
}

// NewMapUpdater: return a new map updater.
Expand All @@ -34,21 +35,13 @@ func NewMapUpdater(config *db.Configuration, url string) (*MapUpdater, error) {
return nil, fmt.Errorf("NewMapUpdater | db.Connect | %w", err)
}

// SMT
smt, err := trie.NewTrie(nil, common.SHA256Hash, dbConn)
if err != nil {
return nil, fmt.Errorf("NewMapServer | NewTrie | %w", err)
}
smt.CacheHeightLimit = 32

fetcher, err := logfetcher.NewLogFetcher(url)
if err != nil {
return nil, err
}

return &MapUpdater{
Fetcher: fetcher,
smt: smt,
Conn: dbConn,
}, nil
}
Expand Down Expand Up @@ -109,7 +102,7 @@ func (u *MapUpdater) UpdateNextBatch(ctx context.Context) (int, error) {
}

// UpdateCertsLocally: add certs (in the form of asn.1 encoded byte arrays) directly without querying log
func (mapUpdater *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [][]byte, certChainList [][][]byte) error {
func (u *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [][]byte, certChainList [][][]byte) error {
expirations := make([]*time.Time, 0, len(certList))
certs := make([]*ctx509.Certificate, 0, len(certList))
certChains := make([][]*ctx509.Certificate, 0, len(certList))
Expand All @@ -131,20 +124,28 @@ func (mapUpdater *MapUpdater) UpdateCertsLocally(ctx context.Context, certList [
certChains = append(certChains, chain)
}
certs, IDs, parentIDs, names := util.UnfoldCerts(certs, certChains)
return UpdateWithKeepExisting(ctx, mapUpdater.Conn, names, IDs, parentIDs, certs, expirations, nil)
return UpdateWithKeepExisting(ctx, u.Conn, names, IDs, parentIDs, certs, expirations, nil)
}

// UpdatePolicyCerts: update RPC and PC from url. Currently just mock PC and RPC
func (mapUpdater *MapUpdater) UpdatePolicyCerts(ctx context.Context, ctUrl string, startIdx, endIdx int64) error {
func (u *MapUpdater) UpdatePolicyCerts(ctx context.Context, ctUrl string, startIdx, endIdx int64) error {
// get PC and RPC first
rpcList, err := logfetcher.GetPCAndRPCs(ctUrl, startIdx, endIdx, 20)
if err != nil {
return fmt.Errorf("CollectCerts | GetPCAndRPC | %w", err)
}
return mapUpdater.updatePolicyCerts(ctx, rpcList)
return u.updatePolicyCerts(ctx, rpcList)
}

func (mapUpdater *MapUpdater) updateCertBatch(
func (u *MapUpdater) UpdateSMT(ctx context.Context) error {
return UpdateSMT(ctx, u.Conn)
}

func (u *MapUpdater) CoalescePayloadsForDirtyDomains(ctx context.Context) error {
return CoalescePayloadsForDirtyDomains(ctx, u.Conn)
}

func (u *MapUpdater) updateCertBatch(
ctx context.Context,
leafCerts []*ctx509.Certificate,
chains [][]*ctx509.Certificate,
Expand All @@ -163,7 +164,7 @@ func (mapUpdater *MapUpdater) updateCertBatch(
// Process whole batch.
return UpdateWithKeepExisting(
ctx,
mapUpdater.Conn,
u.Conn,
names,
certIDs,
parentIDs,
Expand All @@ -173,7 +174,7 @@ func (mapUpdater *MapUpdater) updateCertBatch(
)
}

func (mapUpdater *MapUpdater) updatePolicyCerts(
func (u *MapUpdater) updatePolicyCerts(
ctx context.Context,
rpcs []*common.PolicyCertificate,
) error {
Expand Down Expand Up @@ -316,21 +317,18 @@ func UpdateSMTfromDomains(
// UpdateSMT reads all the dirty domains (pending to update their contents in the SMT), creates
// a SMT Trie, loads it, and updates its entries with the new values.
// It finally commits the Trie and saves its root in the DB.
func UpdateSMT(ctx context.Context, conn db.Conn, cacheHeight int) error {
func UpdateSMT(ctx context.Context, conn db.Conn) error {
// Load root.
var root []byte
if rootID, err := conn.LoadRoot(ctx); err != nil {
root, err := loadRoot(ctx, conn)
if err != nil {
return err
} else if rootID != nil {
root = rootID[:]
}

// Load SMT.
smtTrie, err := trie.NewTrie(root, common.SHA256Hash, conn)
if err != nil {
err = fmt.Errorf("with root \"%s\", creating NewTrie: %w", hex.EncodeToString(root), err)
panic(err)
return fmt.Errorf("with root \"%s\", creating NewTrie: %w", hex.EncodeToString(root), err)
}
// smtTrie.CacheHeightLimit = 32

// Get the dirty domains.
domains, err := conn.RetrieveDirtyDomains(ctx)
Expand All @@ -351,6 +349,16 @@ func UpdateSMT(ctx context.Context, conn db.Conn, cacheHeight int) error {
return nil
}

func loadRoot(ctx context.Context, conn db.Conn) ([]byte, error) {
var root []byte
if rootID, err := conn.LoadRoot(ctx); err != nil {
return nil, err
} else if rootID != nil {
root = rootID[:]
}
return root, nil
}

func insertCerts(ctx context.Context, conn db.Conn, names [][]string,
ids, parentIDs []*common.SHA256Output, expirations []*time.Time, payloads [][]byte) error {

Expand Down

0 comments on commit f9ed2c3

Please sign in to comment.