diff --git a/cmd/trivy-java-db/main.go b/cmd/trivy-java-db/main.go
index 889c3ed..9ad4d12 100644
--- a/cmd/trivy-java-db/main.go
+++ b/cmd/trivy-java-db/main.go
@@ -66,10 +66,13 @@ func init() {
}
func crawl(ctx context.Context) error {
- c := crawler.NewCrawler(crawler.Option{
+ c, err := crawler.NewCrawler(crawler.Option{
Limit: int64(limit),
CacheDir: cacheDir,
})
+ if err != nil {
+ return xerrors.Errorf("unable to create new Crawler: %w", err)
+ }
if err := c.Crawl(ctx); err != nil {
return xerrors.Errorf("crawl error: %w", err)
}
@@ -77,18 +80,18 @@ func crawl(ctx context.Context) error {
}
func build() error {
- if err := db.Reset(cacheDir); err != nil {
- return xerrors.Errorf("db reset error: %w", err)
- }
- dbDir := filepath.Join(cacheDir, "db")
+ dbDir := db.Dir(cacheDir)
slog.Info("Database", slog.String("path", dbDir))
dbc, err := db.New(dbDir)
if err != nil {
return xerrors.Errorf("db create error: %w", err)
}
- if err = dbc.Init(); err != nil {
- return xerrors.Errorf("db init error: %w", err)
+ if !db.Exists(dbDir) {
+ if err = dbc.Init(); err != nil {
+ return xerrors.Errorf("db init error: %w", err)
+ }
}
+
meta := db.NewMetadata(dbDir)
b := builder.NewBuilder(dbc, meta)
if err = b.Build(cacheDir); err != nil {
diff --git a/pkg/crawler/crawler.go b/pkg/crawler/crawler.go
index 65f2a4c..33fb1d1 100644
--- a/pkg/crawler/crawler.go
+++ b/pkg/crawler/crawler.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sync/semaphore"
"golang.org/x/xerrors"
+ "github.com/aquasecurity/trivy-java-db/pkg/db"
"github.com/aquasecurity/trivy-java-db/pkg/fileutil"
"github.com/aquasecurity/trivy-java-db/pkg/types"
)
@@ -30,6 +31,7 @@ const mavenRepoURL = "https://repo.maven.apache.org/maven2/"
type Crawler struct {
dir string
http *retryablehttp.Client
+ dbc *db.DB
rootUrl string
wg sync.WaitGroup
@@ -44,7 +46,7 @@ type Option struct {
CacheDir string
}
-func NewCrawler(opt Option) Crawler {
+func NewCrawler(opt Option) (Crawler, error) {
client := retryablehttp.NewClient()
client.RetryMax = 10
client.Logger = slog.Default()
@@ -77,14 +79,25 @@ func NewCrawler(opt Option) Crawler {
indexDir := filepath.Join(opt.CacheDir, "indexes")
slog.Info("Index dir", slog.String("path", indexDir))
+ // dbc stays nil when no pre-existing DB is found; versionsFromDB checks for nil.
+ var dbc *db.DB
+ if db.Exists(opt.CacheDir) {
+ database, err := db.New(opt.CacheDir)
+ if err != nil {
+ return Crawler{}, xerrors.Errorf("unable to open DB: %w", err)
+ }
+ dbc = &database
+ }
+
return Crawler{
dir: indexDir,
http: client,
+ dbc: dbc,
rootUrl: opt.RootUrl,
urlCh: make(chan string, opt.Limit*10),
limit: semaphore.NewWeighted(opt.Limit),
- }
+ }, nil
}
func (c *Crawler) Crawl(ctx context.Context) error {
@@ -222,7 +235,12 @@ func (c *Crawler) Visit(ctx context.Context, url string) error {
}
func (c *Crawler) crawlSHA1(ctx context.Context, baseURL string, meta *Metadata, dirs []string) error {
- var foundVersions []Version
+ var foundVersions []types.Version
+ // Get versions from the DB (if exists) to reduce the number of requests to the server
+ savedVersion, err := c.versionsFromDB(meta.ArtifactID, meta.GroupID)
+ if err != nil {
+ return xerrors.Errorf("unable to get list of versions from DB: %w", err)
+ }
// Check each version dir to find links to `*.jar.sha1` files.
for _, dir := range dirs {
dirURL := baseURL + dir
@@ -234,34 +252,37 @@ func (c *Crawler) crawlSHA1(ctx context.Context, baseURL string, meta *Metadata,
// Remove the `/` suffix to correctly compare file versions with version from directory name.
dirVersion := strings.TrimSuffix(dir, "/")
var dirVersionSha1 []byte
- var versions []Version
+ var versions []types.Version
+
for _, sha1Url := range sha1Urls {
- sha1, err := c.fetchSHA1(ctx, sha1Url)
- if err != nil {
- return xerrors.Errorf("unable to fetch sha1: %s", err)
- }
- if ver := versionFromSha1URL(meta.ArtifactID, sha1Url); ver != "" && len(sha1) != 0 {
- // Save sha1 for the file where the version is equal to the version from the directory name in order to remove duplicates later
- // Avoid overwriting dirVersion when inserting versions into the database (sha1 is uniq blob)
- // e.g. `cudf-0.14-cuda10-1.jar.sha1` should not overwrite `cudf-0.14.jar.sha1`
- // https://repo.maven.apache.org/maven2/ai/rapids/cudf/0.14/
- if ver == dirVersion {
- dirVersionSha1 = sha1
- } else {
- versions = append(versions, Version{
- Version: ver,
- SHA1: sha1,
- })
+ ver := versionFromSha1URL(meta.ArtifactID, sha1Url)
+ sha1, ok := savedVersion[ver]
+ if !ok {
+ sha1, err = c.fetchSHA1(ctx, sha1Url)
+ if err != nil {
+ return xerrors.Errorf("unable to fetch sha1: %s", err)
}
}
+ // Save sha1 for the file where the version is equal to the version from the directory name in order to remove duplicates later
+ // Avoid overwriting dirVersion when inserting versions into the database (sha1 is uniq blob)
+ // e.g. `cudf-0.14-cuda10-1.jar.sha1` should not overwrite `cudf-0.14.jar.sha1`
+ // https://repo.maven.apache.org/maven2/ai/rapids/cudf/0.14/
+ if ver == "" || len(sha1) == 0 {
+ continue
+ }
+ if ver == dirVersion {
+ dirVersionSha1 = sha1
+ } else {
+ versions = append(versions, types.Version{Version: ver, SHA1: sha1})
+ }
}
// Remove duplicates of dirVersionSha1
- versions = lo.Filter(versions, func(v Version, _ int) bool {
+ versions = lo.Filter(versions, func(v types.Version, _ int) bool {
return !bytes.Equal(v.SHA1, dirVersionSha1)
})
if dirVersionSha1 != nil {
- versions = append(versions, Version{
+ versions = append(versions, types.Version{
Version: dirVersion,
SHA1: dirVersionSha1,
})
@@ -410,6 +431,13 @@ func (c *Crawler) httpGet(ctx context.Context, url string) (*http.Response, erro
return resp, nil
}
+func (c *Crawler) versionsFromDB(artifactID, groupID string) (map[string][]byte, error) {
+ if c.dbc == nil {
+ return nil, nil
+ }
+ return c.dbc.SelectVersionsByArtifactIDAndGroupID(artifactID, groupID)
+}
+
func randomSleep() {
// Seed rand
r := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
diff --git a/pkg/crawler/crawler_test.go b/pkg/crawler/crawler_test.go
index 69a004c..c53f5f6 100644
--- a/pkg/crawler/crawler_test.go
+++ b/pkg/crawler/crawler_test.go
@@ -2,15 +2,21 @@ package crawler_test
import (
"context"
+ "encoding/hex"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
+ "github.com/aquasecurity/trivy-java-db/pkg/dbtest"
+ "github.com/aquasecurity/trivy-java-db/pkg/types"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy-java-db/pkg/crawler"
+
+ _ "modernc.org/sqlite"
)
func TestCrawl(t *testing.T) {
@@ -18,6 +24,7 @@ func TestCrawl(t *testing.T) {
name string
limit int64
fileNames map[string]string
+ withDb bool
goldenPath string
filePath string
wantErr string
@@ -42,6 +49,27 @@ func TestCrawl(t *testing.T) {
goldenPath: "testdata/happy/abbot.json.golden",
filePath: "indexes/abbot/abbot.json",
},
+ {
+ name: "happy path with DB",
+ withDb: true,
+ limit: 1,
+ fileNames: map[string]string{
+ "/maven2/": "testdata/happy/index.html",
+ "/maven2/abbot/": "testdata/happy/abbot.html",
+ "/maven2/abbot/abbot/": "testdata/happy/abbot_abbot.html",
+ "/maven2/abbot/abbot/maven-metadata.xml": "testdata/happy/maven-metadata.xml",
+ "/maven2/abbot/abbot/0.12.3/": "testdata/happy/abbot_abbot_0.12.3.html",
+ "/maven2/abbot/abbot/0.12.3/abbot-0.12.3.jar.sha1": "testdata/happy/abbot-0.12.3.jar.sha1",
+ "/maven2/abbot/abbot/0.13.0/": "testdata/happy/abbot_abbot_0.13.0.html",
+ "/maven2/abbot/abbot/0.13.0/abbot-0.13.0.jar.sha1": "testdata/happy/abbot-0.13.0.jar.sha1",
+ "/maven2/abbot/abbot/0.13.0/abbot-0.13.0-copy.jar.sha1": "testdata/happy/abbot-0.13.0-copy.jar.sha1",
+ "/maven2/abbot/abbot/1.4.0/": "testdata/happy/abbot_abbot_1.4.0.html",
+ "/maven2/abbot/abbot/1.4.0/abbot-1.4.0.jar.sha1": "testdata/happy/abbot-1.4.0.jar.sha1",
+ "/maven2/abbot/abbot/1.4.0/abbot-1.4.0-lite.jar.sha1": "testdata/happy/abbot-1.4.0-lite.jar.sha1",
+ },
+ goldenPath: "testdata/happy/abbot.json.golden",
+ filePath: "indexes/abbot/abbot.json",
+ },
{
name: "sad path",
limit: 2,
@@ -76,13 +104,24 @@ func TestCrawl(t *testing.T) {
defer ts.Close()
tmpDir := t.TempDir()
- cl := crawler.NewCrawler(crawler.Option{
+ if tt.withDb {
+ dbc, err := dbtest.InitDB(t, []types.Index{
+ indexAbbot123,
+ indexAbbot130,
+ })
+ require.NoError(t, err)
+
+ tmpDir = dbc.Dir()
+ }
+
+ cl, err := crawler.NewCrawler(crawler.Option{
RootUrl: ts.URL + "/maven2/",
Limit: tt.limit,
CacheDir: tmpDir,
})
+ require.NoError(t, err)
- err := cl.Crawl(context.Background())
+ err = cl.Crawl(context.Background())
if tt.wantErr != "" {
assert.ErrorContains(t, err, tt.wantErr)
return
@@ -97,5 +136,24 @@ func TestCrawl(t *testing.T) {
assert.JSONEq(t, string(want), string(got))
})
}
-
}
+
+var (
+ abbot123Sha1b, _ = hex.DecodeString("51d28a27d919ce8690a40f4f335b9d591ceb16e9")
+ indexAbbot123 = types.Index{
+ GroupID: "abbot",
+ ArtifactID: "abbot",
+ Version: "0.12.3",
+ SHA1: abbot123Sha1b,
+ ArchiveType: types.JarType,
+ }
+
+ abbot130Sha1b, _ = hex.DecodeString("596d91e67631b0deb05fb685d8d1b6735f3e4f60")
+ indexAbbot130 = types.Index{
+ GroupID: "abbot",
+ ArtifactID: "abbot",
+ Version: "0.13.0",
+ SHA1: abbot130Sha1b,
+ ArchiveType: types.JarType,
+ }
+)
diff --git a/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html b/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html
index b5661f0..213b9af 100644
--- a/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html
+++ b/pkg/crawler/testdata/happy/abbot_abbot_0.13.0.html
@@ -17,9 +17,9 @@
abbot/abbot/0.13.0
../
-abbot-0.13.0.jar 2005-09-20 05:44 779426
-abbot-0.13.0.jar.md5 2005-09-20 05:44 32
-abbot-0.13.0.jar.sha1 2005-09-20 05:44 40
+abbot-0.13.0-copy.jar 2005-09-20 05:44 779426
+abbot-0.13.0-copy.jar.md5 2005-09-20 05:44 32
+abbot-0.13.0-copy.jar.sha1 2005-09-20 05:44 40
abbot-0.13.0.jar 2005-09-20 05:44 779426
abbot-0.13.0.jar.md5 2005-09-20 05:44 32
abbot-0.13.0.jar.sha1 2005-09-20 05:44 40