diff --git a/cmd/trivy-java-db/main.go b/cmd/trivy-java-db/main.go index 889c3ed..9ad4d12 100644 --- a/cmd/trivy-java-db/main.go +++ b/cmd/trivy-java-db/main.go @@ -66,10 +66,13 @@ func init() { } func crawl(ctx context.Context) error { - c := crawler.NewCrawler(crawler.Option{ + c, err := crawler.NewCrawler(crawler.Option{ Limit: int64(limit), CacheDir: cacheDir, }) + if err != nil { + return xerrors.Errorf("unable to create new Crawler: %w", err) + } if err := c.Crawl(ctx); err != nil { return xerrors.Errorf("crawl error: %w", err) } @@ -77,18 +80,18 @@ func crawl(ctx context.Context) error { } func build() error { - if err := db.Reset(cacheDir); err != nil { - return xerrors.Errorf("db reset error: %w", err) - } - dbDir := filepath.Join(cacheDir, "db") + dbDir := db.Dir(cacheDir) slog.Info("Database", slog.String("path", dbDir)) dbc, err := db.New(dbDir) if err != nil { return xerrors.Errorf("db create error: %w", err) } - if err = dbc.Init(); err != nil { - return xerrors.Errorf("db init error: %w", err) + if !db.Exists(dbDir) { + if err = dbc.Init(); err != nil { + return xerrors.Errorf("db init error: %w", err) + } } + meta := db.NewMetadata(dbDir) b := builder.NewBuilder(dbc, meta) if err = b.Build(cacheDir); err != nil { diff --git a/pkg/crawler/crawler.go b/pkg/crawler/crawler.go index 65f2a4c..33fb1d1 100644 --- a/pkg/crawler/crawler.go +++ b/pkg/crawler/crawler.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sync/semaphore" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy-java-db/pkg/db" "github.com/aquasecurity/trivy-java-db/pkg/fileutil" "github.com/aquasecurity/trivy-java-db/pkg/types" ) @@ -30,6 +31,7 @@ const mavenRepoURL = "https://repo.maven.apache.org/maven2/" type Crawler struct { dir string http *retryablehttp.Client + dbc *db.DB rootUrl string wg sync.WaitGroup @@ -44,7 +46,7 @@ type Option struct { CacheDir string } -func NewCrawler(opt Option) Crawler { +func NewCrawler(opt Option) (Crawler, error) { client := retryablehttp.NewClient() client.RetryMax = 10 client.Logger = slog.Default() @@ -77,14 +79,25 @@ func NewCrawler(opt Option) Crawler { indexDir := filepath.Join(opt.CacheDir, "indexes") slog.Info("Index dir", slog.String("path", indexDir)) + var dbc db.DB + if db.Exists(opt.CacheDir) { + var err error + dbc, err = db.New(opt.CacheDir) + if err != nil { + return Crawler{}, xerrors.Errorf("unable to open DB: %w", err) + } + + } + return Crawler{ dir: indexDir, http: client, + dbc: &dbc, rootUrl: opt.RootUrl, urlCh: make(chan string, opt.Limit*10), limit: semaphore.NewWeighted(opt.Limit), - } + }, nil } func (c *Crawler) Crawl(ctx context.Context) error { @@ -222,7 +235,12 @@ func (c *Crawler) Visit(ctx context.Context, url string) error { } func (c *Crawler) crawlSHA1(ctx context.Context, baseURL string, meta *Metadata, dirs []string) error { - var foundVersions []Version + var foundVersions []types.Version + // Get versions from the DB (if exists) to reduce the number of requests to the server + savedVersion, err := c.versionsFromDB(meta.ArtifactID, meta.GroupID) + if err != nil { + return xerrors.Errorf("unable to get list of versions from DB: %w", err) + } // Check each version dir to find links to `*.jar.sha1` files. for _, dir := range dirs { dirURL := baseURL + dir @@ -234,34 +252,37 @@ func (c *Crawler) crawlSHA1(ctx context.Context, baseURL string, meta *Metadata, // Remove the `/` suffix to correctly compare file versions with version from directory name. dirVersion := strings.TrimSuffix(dir, "/") var dirVersionSha1 []byte - var versions []Version + var versions []types.Version + for _, sha1Url := range sha1Urls { - sha1, err := c.fetchSHA1(ctx, sha1Url) - if err != nil { - return xerrors.Errorf("unable to fetch sha1: %s", err) - } - if ver := versionFromSha1URL(meta.ArtifactID, sha1Url); ver != "" && len(sha1) != 0 { - // Save sha1 for the file where the version is equal to the version from the directory name in order to remove duplicates later - // Avoid overwriting dirVersion when inserting versions into the database (sha1 is uniq blob) - // e.g. `cudf-0.14-cuda10-1.jar.sha1` should not overwrite `cudf-0.14.jar.sha1` - // https://repo.maven.apache.org/maven2/ai/rapids/cudf/0.14/ - if ver == dirVersion { - dirVersionSha1 = sha1 - } else { - versions = append(versions, Version{ - Version: ver, - SHA1: sha1, - }) + ver := versionFromSha1URL(meta.ArtifactID, sha1Url) + sha1, ok := savedVersion[ver] + if !ok { + sha1, err = c.fetchSHA1(ctx, sha1Url) + if err != nil { + return xerrors.Errorf("unable to fetch sha1: %s", err) } } + // Save sha1 for the file where the version is equal to the version from the directory name in order to remove duplicates later + // Avoid overwriting dirVersion when inserting versions into the database (sha1 is uniq blob) + // e.g. `cudf-0.14-cuda10-1.jar.sha1` should not overwrite `cudf-0.14.jar.sha1` + // https://repo.maven.apache.org/maven2/ai/rapids/cudf/0.14/ + if ver == dirVersion { + dirVersionSha1 = sha1 + } else { + versions = append(versions, types.Version{ + Version: ver, + SHA1: sha1, + }) + } } // Remove duplicates of dirVersionSha1 - versions = lo.Filter(versions, func(v Version, _ int) bool { + versions = lo.Filter(versions, func(v types.Version, _ int) bool { return !bytes.Equal(v.SHA1, dirVersionSha1) }) if dirVersionSha1 != nil { - versions = append(versions, Version{ + versions = append(versions, types.Version{ Version: dirVersion, SHA1: dirVersionSha1, }) @@ -410,6 +431,13 @@ func (c *Crawler) httpGet(ctx context.Context, url string) (*http.Response, erro return resp, nil } +func (c *Crawler) versionsFromDB(artifactID, groupID string) (map[string][]byte, error) { + if c.dbc == nil { + return nil, nil + } + return c.dbc.SelectVersionsByArtifactIDAndGroupID(artifactID, groupID) +} + func randomSleep() { // Seed rand r := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) diff --git a/pkg/crawler/crawler_test.go b/pkg/crawler/crawler_test.go index 69a004c..c53f5f6 100644 --- a/pkg/crawler/crawler_test.go +++ b/pkg/crawler/crawler_test.go @@ -2,15 +2,21 @@ package crawler_test import ( "context" + "encoding/hex" "net/http" "net/http/httptest" "os" "path/filepath" "testing" + "github.com/aquasecurity/trivy-java-db/pkg/dbtest" + "github.com/aquasecurity/trivy-java-db/pkg/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/aquasecurity/trivy-java-db/pkg/crawler" + + _ "modernc.org/sqlite" ) func TestCrawl(t *testing.T) { @@ -18,6 +24,7 @@ func TestCrawl(t *testing.T) { name string limit int64 fileNames map[string]string + withDb bool goldenPath string filePath string wantErr string @@ -42,6 +49,27 @@ func TestCrawl(t *testing.T) { goldenPath: "testdata/happy/abbot.json.golden", filePath: "indexes/abbot/abbot.json", }, + { + name: "happy path with DB", + withDb: true, + limit: 1, + fileNames: map[string]string{ + "/maven2/": "testdata/happy/index.html", + "/maven2/abbot/": "testdata/happy/abbot.html", + "/maven2/abbot/abbot/": "testdata/happy/abbot_abbot.html", + "/maven2/abbot/abbot/maven-metadata.xml": "testdata/happy/maven-metadata.xml", + "/maven2/abbot/abbot/0.12.3/": "testdata/happy/abbot_abbot_0.12.3.html", + "/maven2/abbot/abbot/0.12.3/abbot-0.12.3.jar.sha1": "testdata/happy/abbot-0.12.3.jar.sha1", + "/maven2/abbot/abbot/0.13.0/": "testdata/happy/abbot_abbot_0.13.0.html", + "/maven2/abbot/abbot/0.13.0/abbot-0.13.0.jar.sha1": "testdata/happy/abbot-0.13.0.jar.sha1", + "/maven2/abbot/abbot/0.13.0/abbot-0.13.0-copy.jar.sha1": "testdata/happy/abbot-0.13.0-copy.jar.sha1", + "/maven2/abbot/abbot/1.4.0/": "testdata/happy/abbot_abbot_1.4.0.html", + "/maven2/abbot/abbot/1.4.0/abbot-1.4.0.jar.sha1": "testdata/happy/abbot-1.4.0.jar.sha1", + "/maven2/abbot/abbot/1.4.0/abbot-1.4.0-lite.jar.sha1": "testdata/happy/abbot-1.4.0-lite.jar.sha1", + }, + goldenPath: "testdata/happy/abbot.json.golden", + filePath: "indexes/abbot/abbot.json", + }, { name: "sad path", limit: 2, @@ -76,13 +104,24 @@ func TestCrawl(t *testing.T) { defer ts.Close() tmpDir := t.TempDir() - cl := crawler.NewCrawler(crawler.Option{ + if tt.withDb { + dbc, err := dbtest.InitDB(t, []types.Index{ + indexAbbot123, + indexAbbot130, + }) + require.NoError(t, err) + + tmpDir = dbc.Dir() + } + + cl, err := crawler.NewCrawler(crawler.Option{ RootUrl: ts.URL + "/maven2/", Limit: tt.limit, CacheDir: tmpDir, }) + require.NoError(t, err) - err := cl.Crawl(context.Background()) + err = cl.Crawl(context.Background()) if tt.wantErr != "" { assert.ErrorContains(t, err, tt.wantErr) return @@ -97,5 +136,24 @@ func TestCrawl(t *testing.T) { assert.JSONEq(t, string(want), string(got)) }) } - } + +var ( + abbot123Sha1b, _ = hex.DecodeString("51d28a27d919ce8690a40f4f335b9d591ceb16e9") + indexAbbot123 = types.Index{ + GroupID: "abbot", + ArtifactID: "abbot", + Version: "0.12.3", + SHA1: abbot123Sha1b, + ArchiveType: types.JarType, + } + + abbot130Sha1b, _ = hex.DecodeString("596d91e67631b0deb05fb685d8d1b6735f3e4f60") + indexAbbot130 = types.Index{ + GroupID: "abbot", + ArtifactID: "abbot", + Version: "0.13.0", + SHA1: abbot130Sha1b, + ArchiveType: types.JarType, + } +) diff --git a/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html b/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html index b5661f0..213b9af 100644 --- a/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html +++ b/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html @@ -17,9 +17,9 @@

abbot/abbot/0.13.0


../
-abbot-0.13.0.jar                                  2005-09-20 05:44    779426
-abbot-0.13.0.jar.md5                              2005-09-20 05:44        32
-abbot-0.13.0.jar.sha1                             2005-09-20 05:44        40
+abbot-0.13.0-copy.jar                                  2005-09-20 05:44    779426
+abbot-0.13.0-copy.jar.md5                              2005-09-20 05:44        32
+abbot-0.13.0-copy.jar.sha1                             2005-09-20 05:44        40
 abbot-0.13.0.jar                                  2005-09-20 05:44    779426      
 abbot-0.13.0.jar.md5                              2005-09-20 05:44        32      
 abbot-0.13.0.jar.sha1                             2005-09-20 05:44        40