Skip to content

Commit

Permalink
FS-1215 Add default download option for unsupported firmware (#161)
Browse files Browse the repository at this point in the history
* FS-1215 Add default download option for unsupported firmware
  • Loading branch information
coffeefreak101 authored Feb 20, 2024
1 parent 1825239 commit e956de2
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 20 deletions.
13 changes: 11 additions & 2 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package app

import (
"context"
"net/http"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -101,8 +102,12 @@ func New(ctx context.Context, inventoryKind types.InventoryKind, cfgFile, logLev
ghClient := github.NewGitHubClient(ctx, app.Config.GithubOpenBmcToken)
downloader = github.NewGitHubDownloader(app.Logger, ghClient)
default:
app.Logger.Error("Vendor not supported: " + vendor)
continue
if app.Config.DefaultDownloadURL == "" {
app.Logger.Error("Vendor not supported: " + vendor)
continue
}

downloader = vendors.NewSourceOverrideDownloader(app.Logger, http.DefaultClient, app.Config.DefaultDownloadURL)
}

syncer := vendors.NewSyncer(dstFs, tmpFs, downloader, inventoryClient, firmwares, app.Logger)
Expand Down Expand Up @@ -223,6 +228,10 @@ func (a *App) envVarAppOverrides() error {
a.Config.GithubOpenBmcToken = a.v.GetString("github.openbmc.token")
}

if a.v.GetString("default.download.url") != "" {
a.Config.DefaultDownloadURL = a.v.GetString("default.download.url")
}

return nil
}

Expand Down
6 changes: 5 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (

"github.com/pkg/errors"

"github.com/metal-toolbox/firmware-syncer/pkg/types"
serverservice "go.hollow.sh/serverservice/pkg/api/v1"

"github.com/metal-toolbox/firmware-syncer/pkg/types"
)

var (
Expand Down Expand Up @@ -49,6 +50,9 @@ type Configuration struct {

// GithubOpenBmcToken defines the token used to access internal openbmc repository
GithubOpenBmcToken string `mapstructure:"github_openbmc_token"`

// DefaultDownloadURL defines where unsupported firmware will be downloaded from
DefaultDownloadURL string `mapstructure:"default_download_url"`
}

// ServerserviceOptions defines configuration for the Serverservice client.
Expand Down
42 changes: 26 additions & 16 deletions internal/inventory/serverservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package inventory
import (
"context"
"net/url"
"path"
"slices"
"sort"
"strings"
Expand Down Expand Up @@ -81,32 +80,26 @@ func newClientWithOAuth(ctx context.Context, cfg *config.ServerserviceOptions) (
return client, nil
}

func makeFirmwarePath(fw *serverservice.ComponentFirmwareVersion) string {
return path.Join(fw.Vendor, fw.Filename)
}

// Publish adds firmware data to Hollow's ServerService
func (s *serverService) Publish(ctx context.Context, newFirmware *serverservice.ComponentFirmwareVersion) error {
artifactsURL, err := url.JoinPath(s.artifactsURL, makeFirmwarePath(newFirmware))
if err != nil {
return err
}
func (s *serverService) addRepositoryURL(fw *serverservice.ComponentFirmwareVersion) (err error) {
fw.RepositoryURL, err = url.JoinPath(s.artifactsURL, fw.Vendor, fw.Filename)

newFirmware.RepositoryURL = artifactsURL
return err
}

func (s *serverService) getCurrentFirmware(ctx context.Context, newFirmware *serverservice.ComponentFirmwareVersion) (*serverservice.ComponentFirmwareVersion, error) {
params := serverservice.ComponentFirmwareVersionListParams{
Checksum: newFirmware.Checksum,
}

firmwares, _, err := s.client.ListServerComponentFirmware(ctx, &params)
if err != nil {
return errors.Wrap(ErrServerServiceQuery, "ListServerComponentFirmware: "+err.Error())
return nil, errors.Wrap(ErrServerServiceQuery, "ListServerComponentFirmware: "+err.Error())
}

firmwareCount := len(firmwares)

if firmwareCount == 0 {
return s.createFirmware(ctx, newFirmware)
return nil, nil
}

if firmwareCount != 1 {
Expand All @@ -122,10 +115,27 @@ func (s *serverService) Publish(ctx context.Context, newFirmware *serverservice.
WithField("version", newFirmware.Version).
Error("Multiple firmware IDs found with checksum")

return errors.Wrap(ErrServerServiceDuplicateFirmware, strings.Join(uuids, ","))
return nil, errors.Wrap(ErrServerServiceDuplicateFirmware, strings.Join(uuids, ","))
}

return &firmwares[0], nil
}

// Publish adds firmware data to Hollow's ServerService
func (s *serverService) Publish(ctx context.Context, newFirmware *serverservice.ComponentFirmwareVersion) error {
if err := s.addRepositoryURL(newFirmware); err != nil {
return err
}

currentFirmware, err := s.getCurrentFirmware(ctx, newFirmware)
if err != nil {
return err
}

if currentFirmware == nil {
return s.createFirmware(ctx, newFirmware)
}

currentFirmware := &firmwares[0]
newFirmware.UUID = currentFirmware.UUID
newFirmware.Model = mergeModels(currentFirmware.Model, newFirmware.Model)

Expand Down
69 changes: 69 additions & 0 deletions internal/vendors/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -46,6 +47,9 @@ var (
ErrDirEmpty = errors.New("directory empty")
ErrModTimeFile = errors.New("error retrieving file mod time")
ErrCreatingTmpDir = errors.New("error creating tmp dir")

ErrUnexpectedStatusCode = errors.New("unexpected status code")
ErrDownloadingFile = errors.New("failed to download file")
)

//go:generate mockgen -source=downloader.go -destination=mocks/downloader.go Downloader
Expand Down Expand Up @@ -367,3 +371,68 @@ func (s *S3Downloader) Download(ctx context.Context, downloadDir string, firmwar

return path.Join(downloadDir, firmware.Filename), nil
}

// SourceOverrideDownloader is meant to download firmware from an alternate source
// than the firmware's UpstreamURL.
type SourceOverrideDownloader struct {
logger *logrus.Logger
client serverservice.Doer
baseURL string
}

// NewSourceOverrideDownloader creates a SourceOverrideDownloader.
func NewSourceOverrideDownloader(logger *logrus.Logger, client serverservice.Doer, sourceURL string) Downloader {
if !strings.HasSuffix(sourceURL, "/") {
sourceURL += "/"
}

return &SourceOverrideDownloader{
logger,
client,
sourceURL,
}
}

// Download will download the given firmware into the given downloadDir,
// and return the full path to the downloaded file.
// The file will be downloaded from the sourceURL provided to the SourceOverrideDownloader
// instead of the firmware's UpstreamURL.
func (d *SourceOverrideDownloader) Download(ctx context.Context, downloadDir string, firmware *serverservice.ComponentFirmwareVersion) (string, error) {
filePath := filepath.Join(downloadDir, firmware.Filename)

firmwareURL, err := url.JoinPath(d.baseURL, firmware.Filename)
if err != nil {
return "", errors.Wrap(ErrSourceURL, err.Error())
}

d.logger.WithField("url", firmwareURL).
WithField("firmware", firmware.Filename).
WithField("vendor", firmware.Vendor).
Info("Downloading firmware")

file, err := os.Create(filePath)
if err != nil {
return "", errors.Wrap(ErrCreatingTmpDir, err.Error())
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, firmwareURL, http.NoBody)
if err != nil {
return "", errors.Wrap(ErrSourceURL, err.Error())
}

resp, err := d.client.Do(req)
if err != nil {
return "", errors.Wrap(ErrDownloadingFile, err.Error())
}
defer resp.Body.Close()

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", errors.Wrap(ErrUnexpectedStatusCode, fmt.Sprintf("status code %d", resp.StatusCode))
}

if _, err = io.Copy(file, resp.Body); err != nil {
return "", errors.Wrap(ErrCopy, err.Error())
}

return filePath, nil
}
146 changes: 145 additions & 1 deletion internal/vendors/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@ package vendors
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path"
"testing"

"github.com/metal-toolbox/firmware-syncer/internal/config"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
serverservice "go.hollow.sh/serverservice/pkg/api/v1"
"go.uber.org/mock/gomock"

"github.com/metal-toolbox/firmware-syncer/internal/config"
mock_vendors "github.com/metal-toolbox/firmware-syncer/internal/vendors/mocks"
)

func Test_InitLocalFs(t *testing.T) {
Expand Down Expand Up @@ -187,3 +197,137 @@ func Test_SplitURLPath(t *testing.T) {
assert.Equal(t, tt.urlPart, urlPart)
}
}

//go:generate mockgen -source=downloader_test.go -destination=mocks/httpDoer.go HTTPDoer

// HTTPDoer interface is meant to help generate the mock
type HTTPDoer interface {
serverservice.Doer
}

// URL Matcher

type requestURLMatcher struct {
expectedURL string
}

func matchesURL(expectedURL string) *requestURLMatcher {
return &requestURLMatcher{expectedURL: expectedURL}
}

func (v *requestURLMatcher) Matches(i interface{}) bool {
request, ok := i.(*http.Request)
if !ok {
return false
}

return request.URL.String() == v.expectedURL
}

func (v *requestURLMatcher) String() string {
return fmt.Sprintf("expected URL %s", v.expectedURL)
}

// ReadCloser Error

type readCloserErr struct{}

func (r *readCloserErr) Read(_ []byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}

func (r *readCloserErr) Close() error {
return nil
}

// SourceOverrideDownloader Test

func Test_SourceOverrideDownloader(t *testing.T) {
ctx := context.Background()
logger := logrus.New()

testCases := []struct {
name string
statusCode int
withBadURL bool
withClientError bool
withCopyError bool
expectedError error
}{
{
name: "success",
},
{
name: "bad url",
withBadURL: true,
expectedError: ErrSourceURL,
},
{
name: "client error",
withClientError: true,
expectedError: ErrDownloadingFile,
},
{
name: "bad status code",
statusCode: 500,
expectedError: ErrUnexpectedStatusCode,
},
{
name: "copy error",
withCopyError: true,
expectedError: ErrCopy,
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp(os.TempDir(), "test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)

statusCode := 200
if tt.statusCode != 0 {
statusCode = tt.statusCode
}

var body io.ReadCloser = &http.NoBody
if tt.withCopyError {
body = &readCloserErr{}
}

fakeResponse := &http.Response{Body: body, StatusCode: statusCode}

ctrl := gomock.NewController(t)
client := mock_vendors.NewMockHTTPDoer(ctrl)

var clientError error
if tt.withClientError {
clientError = io.ErrUnexpectedEOF
}

fakeURL := "https://foo"
firmwareName := "firmware.bin"

if tt.withBadURL {
fakeURL = "!@#$%^&*()_+-="
} else {
client.EXPECT().Do(matchesURL("https://foo/firmware.bin")).Return(fakeResponse, clientError)
}

fakeFirmware := &serverservice.ComponentFirmwareVersion{Filename: firmwareName}
downloader := NewSourceOverrideDownloader(logger, client, fakeURL)
firmwarePath, err := downloader.Download(ctx, tmpDir, fakeFirmware)

if tt.expectedError != nil {
assert.ErrorContains(t, err, tt.expectedError.Error())
return
}

assert.NoError(t, err)
assert.Equal(t, path.Join(tmpDir, firmwareName), firmwarePath)
assert.FileExists(t, firmwarePath)
})
}
}
Loading

0 comments on commit e956de2

Please sign in to comment.