Skip to content

Commit

Permalink
feat(maven): only fetch snapshots and releases when enabled (#1436)
Browse files Browse the repository at this point in the history
#1280

In Maven pom.xml, download policies can be specified for repositories to
decide whether to fetch releases/snapshots or not, which is currently
not supported in MavenRegistryAPIClient. This causes us making many
unnecessary requests.

This PR defines a new data structure to hold information for a Maven
registry including the download policies. For the default registry, we
assume it is only allowed to download releases (it can be a TODO to
support specifying this via CLI flags).
  • Loading branch information
cuixq authored Dec 19, 2024
1 parent 36bf2ee commit faad98e
Show file tree
Hide file tree
Showing 15 changed files with 122 additions and 64 deletions.
8 changes: 7 additions & 1 deletion cmd/osv-scanner/fix/noninteractive.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/google/osv-scanner/internal/remediation"
"github.com/google/osv-scanner/internal/resolution"
"github.com/google/osv-scanner/internal/resolution/client"
"github.com/google/osv-scanner/internal/resolution/datasource"
lf "github.com/google/osv-scanner/internal/resolution/lockfile"
"github.com/google/osv-scanner/internal/resolution/manifest"
"github.com/google/osv-scanner/pkg/lockfile"
Expand Down Expand Up @@ -304,7 +305,12 @@ func autoOverride(ctx context.Context, r reporter.Reporter, opts osvFixOptions,
if ok {
registries := make([]client.Registry, len(specific.Repositories))
for i, repo := range specific.Repositories {
registries[i] = client.Registry{URL: string(repo.URL)}
registries[i] = datasource.MavenRegistry{
URL: string(repo.URL),
ID: string(repo.ID),
ReleasesEnabled: repo.Releases.Enabled.Boolean(),
SnapshotsEnabled: repo.Snapshots.Enabled.Boolean(),
}
}
if err := opts.Client.DependencyClient.AddRegistries(registries); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ go 1.22.7

require (
deps.dev/api/v3 v3.0.0-20241010035105-b3ba03369df1
deps.dev/util/maven v0.0.0-20241010035105-b3ba03369df1
deps.dev/util/resolve v0.0.0-20241010035105-b3ba03369df1
deps.dev/util/maven v0.0.0-20241218001045-3890182485f3
deps.dev/util/resolve v0.0.0-20241218001045-3890182485f3
deps.dev/util/semver v0.0.0-20241010035105-b3ba03369df1
github.com/BurntSushi/toml v1.4.0
github.com/CycloneDX/cyclonedx-go v0.9.1
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
deps.dev/api/v3 v3.0.0-20241010035105-b3ba03369df1 h1:qvrLinmQrkOLmguTE9FpRfC/e2iud/eVMWigXXTdrdA=
deps.dev/api/v3 v3.0.0-20241010035105-b3ba03369df1/go.mod h1:DyBY3wNVqRCwvb4tLvz6LL/FupH3FMflEROyQAv2Vi0=
deps.dev/util/maven v0.0.0-20241010035105-b3ba03369df1 h1:PWgfyz6h15n4dbcSAL/3lSiXl8foQZCzUWUqEjNLNvI=
deps.dev/util/maven v0.0.0-20241010035105-b3ba03369df1/go.mod h1:SBW3EribdkZYk6zxi5oVn/ZECvi4ixb7EGgEWfSimNk=
deps.dev/util/resolve v0.0.0-20241010035105-b3ba03369df1 h1:nHefSxxfjdmo+zn/8fEcfSUkTXi+LKnBNvul21ZI9qw=
deps.dev/util/resolve v0.0.0-20241010035105-b3ba03369df1/go.mod h1:XXi6yRYqhtxw5DvGX/mbG6fHSLn8OgoPowNd8EAxDgk=
deps.dev/util/maven v0.0.0-20241218001045-3890182485f3 h1:2Zjbnw7OgDGr3vM7Epwxgv2cMyeBps4X9AHF1SD82ao=
deps.dev/util/maven v0.0.0-20241218001045-3890182485f3/go.mod h1:SBW3EribdkZYk6zxi5oVn/ZECvi4ixb7EGgEWfSimNk=
deps.dev/util/resolve v0.0.0-20241218001045-3890182485f3 h1:9mtMNCV9XDN689ukjYBSKjqQXaqwC666BOP6fV2plBQ=
deps.dev/util/resolve v0.0.0-20241218001045-3890182485f3/go.mod h1:XXi6yRYqhtxw5DvGX/mbG6fHSLn8OgoPowNd8EAxDgk=
deps.dev/util/semver v0.0.0-20241010035105-b3ba03369df1 h1:t4P0dCCNIrV84B5d7kOIAzji+HrO303Nrw9BB4ktBy0=
deps.dev/util/semver v0.0.0-20241010035105-b3ba03369df1/go.mod h1:jkcH+k02gWHBiZ7G4OnUOkSZ6WDq54Pt5DrOA8FN8Uo=
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
Expand Down
9 changes: 7 additions & 2 deletions internal/manifest/maven.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ func (e MavenResolverExtractor) Extract(f lockfile.DepFile) ([]lockfile.PackageD
return nil, fmt.Errorf("failed to merge profiles: %w", err)
}
for _, repo := range project.Repositories {
if err := e.MavenRegistryAPIClient.AddRegistry(string(repo.URL)); err != nil {
if err := e.MavenRegistryAPIClient.AddRegistry(datasource.MavenRegistry{
URL: string(repo.URL),
ID: string(repo.ID),
ReleasesEnabled: repo.Releases.Enabled.Boolean(),
SnapshotsEnabled: repo.Snapshots.Enabled.Boolean(),
}); err != nil {
return nil, fmt.Errorf("failed to add registry %s: %w", repo.URL, err)
}
}
Expand All @@ -57,7 +62,7 @@ func (e MavenResolverExtractor) Extract(f lockfile.DepFile) ([]lockfile.PackageD
if registries := e.MavenRegistryAPIClient.GetRegistries(); len(registries) > 0 {
clientRegs := make([]client.Registry, len(registries))
for i, reg := range registries {
clientRegs[i] = client.Registry{URL: reg}
clientRegs[i] = reg
}
if err := e.DependencyClient.AddRegistries(clientRegs); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/manifest/maven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestParseMavenWithResolver_WithParent(t *testing.T) {
`))

resolutionClient := clienttest.NewMockResolutionClient(t, "fixtures/universe/basic-universe.yaml")
client, _ := datasource.NewMavenRegistryAPIClient(srv.URL)
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true})
packages, err := manifest.ParseMavenWithResolver(resolutionClient, client, "fixtures/maven/with-parent.xml")
if err != nil {
t.Errorf("Got unexpected error: %v", err)
Expand Down
4 changes: 1 addition & 3 deletions internal/resolution/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ type DependencyClient interface {
AddRegistries(registries []Registry) error
}

type Registry struct {
URL string
}
type Registry interface{}

// PreFetch loads cache, then makes and caches likely queries needed for resolving a package with a list of requirements
func PreFetch(ctx context.Context, c DependencyClient, requirements []resolve.RequirementVersion, manifestPath string) {
Expand Down
9 changes: 7 additions & 2 deletions internal/resolution/client/maven_registry_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"encoding/gob"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -21,7 +22,7 @@ type MavenRegistryClient struct {
}

func NewMavenRegistryClient(registry string) (*MavenRegistryClient, error) {
client, err := datasource.NewMavenRegistryAPIClient(registry)
client, err := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: registry, ReleasesEnabled: true})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -147,7 +148,11 @@ func (c *MavenRegistryClient) MatchingVersions(ctx context.Context, vk resolve.V

func (c *MavenRegistryClient) AddRegistries(registries []Registry) error {
for _, reg := range registries {
if err := c.api.AddRegistry(reg.URL); err != nil {
specific, ok := reg.(datasource.MavenRegistry)
if !ok {
return errors.New("invalid Maven registry information")
}
if err := c.api.AddRegistry(specific); err != nil {
return err
}
}
Expand Down
82 changes: 49 additions & 33 deletions internal/resolution/datasource/maven_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ const MavenCentral = "https://repo.maven.apache.org/maven2"
var errAPIFailed = errors.New("API query failed")

type MavenRegistryAPIClient struct {
defaultRegistry string // Base URL of the default registry that we are making requests
// TODO: disable fetching snapshot if specified in pom.xml
registries []string // URLs of the registries to fetch projects
defaultRegistry MavenRegistry // The default registry that we are making requests
registries []MavenRegistry // Additional registries specified to fetch projects

// Cache fields
mu *sync.Mutex
Expand All @@ -39,14 +38,29 @@ type response struct {
Body []byte
}

func NewMavenRegistryAPIClient(registry string) (*MavenRegistryAPIClient, error) {
if registry == "" {
registry = MavenCentral
} else if _, err := url.Parse(registry); err != nil {
return nil, fmt.Errorf("invalid Maven registry %s: %w", registry, err)
type MavenRegistry struct {
URL string
Parsed *url.URL

// Information from pom.xml
ID string
ReleasesEnabled bool
SnapshotsEnabled bool
}

func NewMavenRegistryAPIClient(registry MavenRegistry) (*MavenRegistryAPIClient, error) {
if registry.URL == "" {
registry.URL = MavenCentral
registry.ID = "central"
}
u, err := url.Parse(registry.URL)
if err != nil {
return nil, fmt.Errorf("invalid Maven registry %s: %w", registry.URL, err)
}
registry.Parsed = u

return &MavenRegistryAPIClient{
// We assume only downloading releases is allowed on the default registry.
defaultRegistry: registry,
mu: &sync.Mutex{},
responses: NewRequestCache[string, response](),
Expand All @@ -64,20 +78,25 @@ func (m *MavenRegistryAPIClient) WithoutRegistries() *MavenRegistryAPIClient {
}

// AddRegistry adds the given registry to the list of registries if it has not been added.
func (m *MavenRegistryAPIClient) AddRegistry(registry string) error {
if slices.Contains(m.registries, registry) {
return nil
func (m *MavenRegistryAPIClient) AddRegistry(registry MavenRegistry) error {
for _, reg := range m.registries {
if reg.ID == registry.ID {
return nil
}
}

if _, err := url.Parse(registry); err != nil {
u, err := url.Parse(registry.URL)
if err != nil {
return err
}

registry.Parsed = u
m.registries = append(m.registries, registry)

return nil
}

func (m *MavenRegistryAPIClient) GetRegistries() []string {
func (m *MavenRegistryAPIClient) GetRegistries() (registries []MavenRegistry) {
return m.registries
}

Expand All @@ -89,7 +108,10 @@ func (m *MavenRegistryAPIClient) GetRegistries() []string {
func (m *MavenRegistryAPIClient) GetProject(ctx context.Context, groupID, artifactID, version string) (maven.Project, error) {
if !strings.HasSuffix(version, "-SNAPSHOT") {
for _, registry := range append(m.registries, m.defaultRegistry) {
project, err := m.getProject(ctx, registry, groupID, artifactID, version, "")
if !registry.ReleasesEnabled {
continue
}
project, err := m.getProject(ctx, registry.Parsed, groupID, artifactID, version, "")
if err == nil {
return project, nil
}
Expand All @@ -99,8 +121,11 @@ func (m *MavenRegistryAPIClient) GetProject(ctx context.Context, groupID, artifa
}

for _, registry := range append(m.registries, m.defaultRegistry) {
// Fetch version metadata for snapshot versions.
metadata, err := m.getVersionMetadata(ctx, registry, groupID, artifactID, version)
// Fetch version metadata for snapshot versions from the registries enabling that.
if !registry.SnapshotsEnabled {
continue
}
metadata, err := m.getVersionMetadata(ctx, registry.Parsed, groupID, artifactID, version)
if err != nil {
continue
}
Expand All @@ -114,7 +139,7 @@ func (m *MavenRegistryAPIClient) GetProject(ctx context.Context, groupID, artifa
}
}

project, err := m.getProject(ctx, registry, groupID, artifactID, version, snapshot)
project, err := m.getProject(ctx, registry.Parsed, groupID, artifactID, version, snapshot)
if err == nil {
return project, nil
}
Expand All @@ -128,7 +153,7 @@ func (m *MavenRegistryAPIClient) GetProject(ctx context.Context, groupID, artifa
func (m *MavenRegistryAPIClient) GetVersions(ctx context.Context, groupID, artifactID string) ([]maven.String, error) {
var versions []maven.String
for _, registry := range append(m.registries, m.defaultRegistry) {
metadata, err := m.getArtifactMetadata(ctx, registry, groupID, artifactID)
metadata, err := m.getArtifactMetadata(ctx, registry.Parsed, groupID, artifactID)
if err != nil {
continue
}
Expand All @@ -141,14 +166,11 @@ func (m *MavenRegistryAPIClient) GetVersions(ctx context.Context, groupID, artif

// getProject fetches a pom.xml specified by groupID, artifactID and version and parses it to maven.Project.
// For snapshot versions, the exact version value is specified by snapshot.
func (m *MavenRegistryAPIClient) getProject(ctx context.Context, registry, groupID, artifactID, version, snapshot string) (maven.Project, error) {
func (m *MavenRegistryAPIClient) getProject(ctx context.Context, registry *url.URL, groupID, artifactID, version, snapshot string) (maven.Project, error) {
if snapshot == "" {
snapshot = version
}
u, err := url.JoinPath(registry, strings.ReplaceAll(groupID, ".", "/"), artifactID, version, fmt.Sprintf("%s-%s.pom", artifactID, snapshot))
if err != nil {
return maven.Project{}, fmt.Errorf("failed to join path: %w", err)
}
u := registry.JoinPath(strings.ReplaceAll(groupID, ".", "/"), artifactID, version, fmt.Sprintf("%s-%s.pom", artifactID, snapshot)).String()

var project maven.Project
if err := m.get(ctx, u, &project); err != nil {
Expand All @@ -159,11 +181,8 @@ func (m *MavenRegistryAPIClient) getProject(ctx context.Context, registry, group
}

// getVersionMetadata fetches a version level maven-metadata.xml and parses it to maven.Metadata.
func (m *MavenRegistryAPIClient) getVersionMetadata(ctx context.Context, registry, groupID, artifactID, version string) (maven.Metadata, error) {
u, err := url.JoinPath(registry, strings.ReplaceAll(groupID, ".", "/"), artifactID, version, "maven-metadata.xml")
if err != nil {
return maven.Metadata{}, fmt.Errorf("failed to join path: %w", err)
}
func (m *MavenRegistryAPIClient) getVersionMetadata(ctx context.Context, registry *url.URL, groupID, artifactID, version string) (maven.Metadata, error) {
u := registry.JoinPath(strings.ReplaceAll(groupID, ".", "/"), artifactID, version, "maven-metadata.xml").String()

var metadata maven.Metadata
if err := m.get(ctx, u, &metadata); err != nil {
Expand All @@ -174,11 +193,8 @@ func (m *MavenRegistryAPIClient) getVersionMetadata(ctx context.Context, registr
}

// GetArtifactMetadata fetches an artifact level maven-metadata.xml and parses it to maven.Metadata.
func (m *MavenRegistryAPIClient) getArtifactMetadata(ctx context.Context, registry, groupID, artifactID string) (maven.Metadata, error) {
u, err := url.JoinPath(registry, strings.ReplaceAll(groupID, ".", "/"), artifactID, "maven-metadata.xml")
if err != nil {
return maven.Metadata{}, fmt.Errorf("failed to join path: %w", err)
}
func (m *MavenRegistryAPIClient) getArtifactMetadata(ctx context.Context, registry *url.URL, groupID, artifactID string) (maven.Metadata, error) {
u := registry.JoinPath(strings.ReplaceAll(groupID, ".", "/"), artifactID, "maven-metadata.xml").String()

var metadata maven.Metadata
if err := m.get(ctx, u, &metadata); err != nil {
Expand Down
27 changes: 19 additions & 8 deletions internal/resolution/datasource/maven_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package datasource

import (
"context"
"net/url"
"reflect"
"testing"

Expand All @@ -13,7 +14,7 @@ func TestGetProject(t *testing.T) {
t.Parallel()

srv := testutility.NewMockHTTPServer(t)
client, _ := NewMavenRegistryAPIClient(srv.URL)
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: srv.URL, ReleasesEnabled: true})
srv.SetResponse(t, "org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", []byte(`
<project>
<groupId>org.example</groupId>
Expand Down Expand Up @@ -42,7 +43,7 @@ func TestGetProjectSnapshot(t *testing.T) {
t.Parallel()

srv := testutility.NewMockHTTPServer(t)
client, _ := NewMavenRegistryAPIClient(srv.URL)
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: srv.URL, SnapshotsEnabled: true})
srv.SetResponse(t, "org/example/x.y.z/3.3.1-SNAPSHOT/maven-metadata.xml", []byte(`
<metadata>
<groupId>org.example</groupId>
Expand Down Expand Up @@ -96,7 +97,7 @@ func TestGetArtifactMetadata(t *testing.T) {
t.Parallel()

srv := testutility.NewMockHTTPServer(t)
client, _ := NewMavenRegistryAPIClient(srv.URL)
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: srv.URL, ReleasesEnabled: true})
srv.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
<metadata>
<groupId>org.example</groupId>
Expand All @@ -113,7 +114,12 @@ func TestGetArtifactMetadata(t *testing.T) {
</metadata>
`))

got, err := client.getArtifactMetadata(context.Background(), srv.URL, "org.example", "x.y.z")
u, err := url.Parse(srv.URL)
if err != nil {
t.Fatalf("failed to get parse URL %s: %v", srv.URL, err)
}

got, err := client.getArtifactMetadata(context.Background(), u, "org.example", "x.y.z")
if err != nil {
t.Fatalf("failed to get artifact metadata for %s:%s: %v", "org.example", "x.y.z", err)
}
Expand All @@ -139,7 +145,7 @@ func TestGetVersionMetadata(t *testing.T) {
t.Parallel()

srv := testutility.NewMockHTTPServer(t)
client, _ := NewMavenRegistryAPIClient(srv.URL)
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: srv.URL, SnapshotsEnabled: true})
srv.SetResponse(t, "org/example/x.y.z/3.3.1-SNAPSHOT/maven-metadata.xml", []byte(`
<metadata>
<groupId>org.example</groupId>
Expand All @@ -166,7 +172,12 @@ func TestGetVersionMetadata(t *testing.T) {
</metadata>
`))

got, err := client.getVersionMetadata(context.Background(), srv.URL, "org.example", "x.y.z", "3.3.1-SNAPSHOT")
u, err := url.Parse(srv.URL)
if err != nil {
t.Fatalf("failed to get parse URL %s: %v", srv.URL, err)
}

got, err := client.getVersionMetadata(context.Background(), u, "org.example", "x.y.z", "3.3.1-SNAPSHOT")
if err != nil {
t.Fatalf("failed to get metadata for %s:%s verion %s: %v", "org.example", "x.y.z", "3.3.1-SNAPSHOT", err)
}
Expand Down Expand Up @@ -202,7 +213,7 @@ func TestMultipleRegistry(t *testing.T) {
t.Parallel()

dft := testutility.NewMockHTTPServer(t)
client, _ := NewMavenRegistryAPIClient(dft.URL)
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: dft.URL, ReleasesEnabled: true})
dft.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
<metadata>
<groupId>org.example</groupId>
Expand Down Expand Up @@ -233,7 +244,7 @@ func TestMultipleRegistry(t *testing.T) {
`))

srv := testutility.NewMockHTTPServer(t)
if err := client.AddRegistry(srv.URL); err != nil {
if err := client.AddRegistry(MavenRegistry{URL: srv.URL, ReleasesEnabled: true}); err != nil {
t.Fatalf("failed to add registry %s: %v", srv.URL, err)
}
srv.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
Expand Down
Loading

0 comments on commit faad98e

Please sign in to comment.