Skip to content

Commit

Permalink
Add Prune certs to DB.
Browse files Browse the repository at this point in the history
Use it in the map server update cycle.
  • Loading branch information
juagargi committed Aug 11, 2023
1 parent 0bea89c commit ee8d1a0
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 20 deletions.
24 changes: 20 additions & 4 deletions cmd/mapserver/mapserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,16 @@ func (s *MapServer) PruneAndUpdate(ctx context.Context) error {
}

func (s *MapServer) pruneAndUpdate(ctx context.Context) {
// prune() and update() both send an answer to the updateErrChan. Refrain from updating if
// pruning failed.
s.prune(ctx)
err := <-s.updateErrChan
if err != nil {
s.updateErrChan <- err
return
}

// update() will send its own response to the updateErrChan.
s.update(ctx)
}

Expand All @@ -119,16 +128,25 @@ func (s *MapServer) prune(ctx context.Context) {
return time.Now().UTC().Format(time.RFC3339)
}
fmt.Printf("======== prune started at %s\n", getTime())
defer fmt.Printf("======== prune finished at %s\n\n", getTime())

// deleteme TODO
fmt.Printf("======== prune finished at %s\n\n", getTime())
n, err := s.Updater.Conn.PruneCerts(ctx, time.Now())
if err != nil {
s.updateErrChan <- fmt.Errorf("pruning: %w", err)
}
fmt.Printf("pruning: %d certs removed\n", n)

s.updateErrChan <- error(nil) // Always answer something.
}

func (s *MapServer) update(ctx context.Context) {
getTime := func() string {
return time.Now().UTC().Format(time.RFC3339)
}

fmt.Printf("======== update started at %s\n", getTime())
defer fmt.Printf("======== update finished at %s\n\n", getTime())

if err := s.Updater.StartFetchingRemaining(); err != nil {

s.updateErrChan <- fmt.Errorf("retrieving start and end indices: %w", err)
Expand All @@ -148,8 +166,6 @@ func (s *MapServer) update(ctx context.Context) {
}
s.Updater.StopFetching()

fmt.Printf("======== update finished at %s\n\n", getTime())

// Queue answer in form of an error:
s.updateErrChan <- error(nil)
}
4 changes: 4 additions & 0 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type certs interface {
// UpdateLastCertIndexWritten updates the index of the last certificate written into the DB.
// The url specifies the CT log server from which this index comes from.
UpdateLastCertIndexWritten(ctx context.Context, url string, index int64) error

// PruneCerts removes all certificates that are no longer valid according to the paramter.
// I.e. any certificate whose NotAfter date is equal or before the parameter.
PruneCerts(ctx context.Context, now time.Time) (int64, error)
}

type policies interface {
Expand Down
20 changes: 20 additions & 0 deletions pkg/db/mysql/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ func (c *mysqlDB) UpdateLastCertIndexWritten(ctx context.Context, url string, in
return err
}

// PruneCerts removes all certificates that are no longer valid according to the paramter.
// I.e. any certificate whose NotAfter date is equal or before the parameter.
func (c *mysqlDB) PruneCerts(ctx context.Context, now time.Time) (int64, error) {
return c.pruneCerts(ctx, now)
}

// checkCertsExist should not be called with larger than ~1000 elements, the query being used
// may fail with a message like:
// Error 1436 (HY000): Thread stack overrun: 1028624 bytes used of a 1048576 byte stack,
Expand Down Expand Up @@ -205,3 +211,17 @@ func (c *mysqlDB) checkCertsExist(ctx context.Context, ids []*common.SHA256Outpu

return nil
}

func (c *mysqlDB) pruneCerts(ctx context.Context, now time.Time) (int64, error) {
// A certificate is valid if its NotAfter is greater or equal than now.
// We thus look for certificates with expiration less than now.

// Simply remove all expired certificates.

str := "DELETE FROM certs WHERE expiration < ?"
res, err := c.db.ExecContext(ctx, str, now)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
1 change: 1 addition & 0 deletions pkg/db/mysql/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func WithDefaults() db.ConfigurationModFunction {
"interpolateParams": "true", // 1 round trip per query
"collation": "binary",
"maxAllowedPacket": "1073741824", // 1G (cannot use "1G" as the driver uses Atoi)
"parseTime": "true", // driver parses DATETIME into time.Time
}
for k, v := range defaults {
c.Values[k] = v
Expand Down
69 changes: 68 additions & 1 deletion pkg/db/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,78 @@ func TestLastCertIndexWritten(t *testing.T) {
n, err = conn.LastCertIndexWritten(ctx, "doesnt exist")
require.NoError(t, err)
require.Equal(t, int64(-1), n)

}

func TestPruneCerts(t *testing.T) {
ctx, cancelF := context.WithTimeout(context.Background(), 3*time.Second)
defer cancelF()

// Configure a test DB.
config, removeF := testdb.ConfigureTestDB(t)
// defer removeF()
_ = removeF

// Connect to the DB.
conn := testdb.Connect(t, config)
defer conn.Close()

// Prepare test data.
// a.com 's chain will not expire.
// b.com 's chain will expire at its leaf.
// c.com 's chain will expire only its root.
expiredTime := util.TimeFromSecs(100)
now := expiredTime.Add(time.Hour)
leafNames := []string{
"a.com",
"b.com",
"c.com",
}
certs, certIDs, parentIDs, names := testCertHierarchyForLeafs(t, leafNames)
// Check that no certificate is expired yet.
for i, cert := range certs {
require.False(t, now.After(cert.NotAfter),
"failed test data precondition at %d, with value %s",
i, cert.NotAfter)
}
// Check that ensure that the data setup has not changed (leaves at the expected index, etc).
require.Equal(t, len(leafNames)*4, len(certs))
require.Equal(t, len(certs), len(certIDs))
require.Equal(t, len(certs), len(parentIDs))
require.Equal(t, len(certs), len(names))
// Modify b.com: only the 2 leaf certificates.
c := certs[4*1+2] // first chain of b.com
require.Equal(t, "b.com", c.Subject.CommonName) // assert that the test data is still correct.
c.NotAfter = expiredTime
c = certs[4*1+3] // second chain of b.com
require.Equal(t, "b.com", c.Subject.CommonName) // assert that the test data is still correct.
c.NotAfter = expiredTime
// Modify c.com: only the single root of its two chains.
c = certs[4*2] // root of both chains for c.com
require.Equal(t, "c0.com", c.Subject.CommonName) // assert that the test data is still correct.
c.NotAfter = expiredTime

// Ingest data into DB.
err := updater.UpdateWithKeepExisting(ctx, conn, names, certIDs, parentIDs,
certs, util.ExtractExpirations(certs), nil)
require.NoError(t, err)
// Coalescing of payloads.
err = updater.CoalescePayloadsForDirtyDomains(ctx, conn)
require.NoError(t, err)

// Now test that prune removes some of them.
n, err := conn.PruneCerts(ctx, now)
require.NoError(t, err)
// We have two leafs in b.com + root of two chains (but ONE root)
require.Equal(t, int64((1+1)+(1)), n)

// deleteme TODO now test that we can still query
}

// testCertHierarchyForLeafs returns a hierarchy per leaf certificate. Each certificate is composed
// of two mock chains, like: leaf->c1.com->c0.com, leaf->c0.com , created using the function
// BuildTestRandomCertHierarchy.
// BuildTestRandomCertHierarchy. That function always returns four certificates, in this order:
// c0.com,c1.com, leaf->c1->c0, leaf->c0
func testCertHierarchyForLeafs(t tests.T, leaves []string) (certs []*ctx509.Certificate,
certIDs, parentCertIDs []*common.SHA256Output, certNames [][]string) {

Expand Down
29 changes: 24 additions & 5 deletions pkg/mapserver/updater/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (u *MapUpdater) UpdateNextBatch(ctx context.Context) (int, error) {
if err != nil {
return 0, fmt.Errorf("fetcher: %w", err)
}
n, err := len(certs), u.updateCerts(ctx, certs, chains)
n, err := len(certs), u.updateCertBatch(ctx, certs, chains)
if err == nil {
// Store the last index obtained from the fetcher as updated.
u.lastUpdatedIndex += int64(n)
Expand Down Expand Up @@ -144,14 +144,33 @@ func (mapUpdater *MapUpdater) UpdatePolicyCerts(ctx context.Context, ctUrl strin
return mapUpdater.updatePolicyCerts(ctx, rpcList)
}

func (mapUpdater *MapUpdater) updateCerts(
func (mapUpdater *MapUpdater) updateCertBatch(
ctx context.Context,
certs []*ctx509.Certificate,
leafCerts []*ctx509.Certificate,
chains [][]*ctx509.Certificate,
) error {

// TODO(juagargi)
return nil
if len(leafCerts) != len(chains) {
return fmt.Errorf("inconsistent certs and chains count: %d and %d respectively",
len(leafCerts), len(chains))
}

certs, certIDs, parentIDs, names := util.UnfoldCerts(leafCerts, chains)

// Extract expirations.
expirations := util.ExtractExpirations(certs)

// Process whole batch.
return UpdateWithKeepExisting(
ctx,
mapUpdater.Conn,
names,
certIDs,
parentIDs,
certs,
expirations,
nil, // no policies in this call
)
}

func (mapUpdater *MapUpdater) updatePolicyCerts(
Expand Down
19 changes: 9 additions & 10 deletions pkg/mapserver/updater/updater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package updater
import (
"context"
"encoding/hex"
"fmt"
"math/rand"
"testing"
"time"
Expand Down Expand Up @@ -40,10 +39,6 @@ func TestUpdateWithKeepExisting(t *testing.T) {
"leaf.certs.com",
"example.certs.com",
}
// Add many more leaf certificates for the test.
for i := 0; i < 20000; i++ {
leafCerts = append(leafCerts, fmt.Sprintf("leaf-%d.auto.certs.com", i+1))
}

// Create a random certificate test hierarchy for each leaf.
var certs []*ctx509.Certificate
Expand All @@ -62,11 +57,9 @@ func TestUpdateWithKeepExisting(t *testing.T) {
pols := random.BuildTestRandomPolicyHierarchy(t, "a-domain-name.thing")

// Update with certificates and policies.
t0 := time.Now()
err := UpdateWithKeepExisting(ctx, conn, certNames, certIDs, parentCertIDs,
certs, util.ExtractExpirations(certs), pols)
require.NoError(t, err)
t.Logf("time needed to update %d certificates: %s", len(certIDs), time.Since(t0))

// Coalescing of payloads.
err = CoalescePayloadsForDirtyDomains(ctx, conn)
Expand All @@ -78,6 +71,7 @@ func TestUpdateWithKeepExisting(t *testing.T) {
// t.Logf("%s: %s", leaf, hex.EncodeToString(domainID[:]))
gotCertIDsID, gotCertIDs, err := conn.RetrieveDomainCertificatesIDs(ctx, domainID)
require.NoError(t, err)
// Expect as many IDs as total certs per leaf ( #certs / #leafs ). Each ID is 32 bytes:
expectedSize := common.SHA256Size * len(certs) / len(leafCerts)
require.Len(t, gotCertIDs, expectedSize, "bad length, should be %d but it's %d",
expectedSize, len(gotCertIDs))
Expand Down Expand Up @@ -188,7 +182,11 @@ func TestMapUpdaterStartFetching(t *testing.T) {
if (onReturnNextBatchCalls * batchSize) > fetcher.size {
n = fetcher.size % batchSize
}
return make([]*ctx509.Certificate, n),
randomCerts := make([]*ctx509.Certificate, n)
for i := range randomCerts {
randomCerts[i] = random.RandomX509Cert(t, t.Name())
}
return randomCerts,
make([][]*ctx509.Certificate, n),
nil
},
Expand Down Expand Up @@ -314,9 +312,10 @@ func TestMapUpdaterStartFetchingRemaining(t *testing.T) {
return fetcher.size-onReturnNextBatchCalls > 0
},
onReturnNextBatch: func() ([]*ctx509.Certificate, [][]*ctx509.Certificate, error) {
// Return one cert and chain.
// Return one cert and chain with no parents.
onReturnNextBatchCalls++
return make([]*ctx509.Certificate, 1), make([][]*ctx509.Certificate, 1), nil
return []*ctx509.Certificate{random.RandomX509Cert(t, "a.com")},
make([][]*ctx509.Certificate, 1), nil
},
}
updater.Fetcher = fetcher
Expand Down

0 comments on commit ee8d1a0

Please sign in to comment.