diff --git a/cmd/oras/internal/display/status/progress/manager.go b/cmd/oras/internal/display/status/progress/manager.go index 28e61d5e0..07d0e1044 100644 --- a/cmd/oras/internal/display/status/progress/manager.go +++ b/cmd/oras/internal/display/status/progress/manager.go @@ -23,6 +23,7 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras/cmd/oras/internal/display/status/console" + "oras.land/oras/internal/progress" ) const ( @@ -34,13 +35,6 @@ const ( var errManagerStopped = errors.New("progress output manager has already been stopped") -// Manager is progress view master -type Manager interface { - Add() (*Messenger, error) - SendAndStop(desc ocispec.Descriptor, prompt string) error - Close() error -} - type manager struct { status []*status statusLock sync.RWMutex @@ -48,10 +42,11 @@ type manager struct { updating sync.WaitGroup renderDone chan struct{} renderClosed chan struct{} + prompt map[progress.State]string } // NewManager initialized a new progress manager. -func NewManager(tty *os.File) (Manager, error) { +func NewManager(tty *os.File, prompt map[progress.State]string) (progress.Manager, error) { c, err := console.NewConsole(tty) if err != nil { return nil, err @@ -60,6 +55,7 @@ func NewManager(tty *os.File) (Manager, error) { console: c, renderDone: make(chan struct{}), renderClosed: make(chan struct{}), + prompt: prompt, } m.start() return m, nil @@ -103,13 +99,13 @@ func (m *manager) render() { } } -// Add appends a new status with 2-line space for rendering. -func (m *manager) Add() (*Messenger, error) { +// Track appends a new status with 2-line space for rendering. +func (m *manager) Track(desc ocispec.Descriptor) (progress.Tracker, error) { if m.closed() { return nil, errManagerStopped } - s := newStatus() + s := newStatus(desc) m.statusLock.Lock() m.status = append(m.status, s) m.statusLock.Unlock() @@ -119,18 +115,7 @@ func (m *manager) Add() (*Messenger, error) { return m.statusChan(s), nil } -// SendAndStop send message for descriptor and stop timing. -func (m *manager) SendAndStop(desc ocispec.Descriptor, prompt string) error { - messenger, err := m.Add() - if err != nil { - return err - } - messenger.Send(prompt, desc, desc.Size) - messenger.Stop() - return nil -} - -func (m *manager) statusChan(s *status) *Messenger { +func (m *manager) statusChan(s *status) progress.Tracker { ch := make(chan *status, BufferSize) m.updating.Add(1) go func() { @@ -139,7 +124,10 @@ func (m *manager) statusChan(s *status) *Messenger { s.update(newStatus) } }() - return &Messenger{ch: ch} + return &Messenger{ + ch: ch, + prompt: m.prompt, + } } // Close stops all status and waits for updating and rendering. diff --git a/cmd/oras/internal/display/status/progress/manager_test.go b/cmd/oras/internal/display/status/progress/manager_test.go index 43d0f2104..c8c2f9551 100644 --- a/cmd/oras/internal/display/status/progress/manager_test.go +++ b/cmd/oras/internal/display/status/progress/manager_test.go @@ -21,6 +21,7 @@ import ( "fmt" "testing" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras/cmd/oras/internal/display/status/console" "oras.land/oras/internal/testutils" ) @@ -41,7 +42,7 @@ func Test_manager_render(t *testing.T) { } height, _ := m.console.GetHeightWidth() for i := 0; i < height; i++ { - if _, err := m.Add(); err != nil { + if _, err := m.Track(ocispec.Descriptor{}); err != nil { t.Fatal(err) } } diff --git a/cmd/oras/internal/display/status/progress/messenger.go b/cmd/oras/internal/display/status/progress/messenger.go index 9f0188b5a..f4da6cc2b 100644 --- a/cmd/oras/internal/display/status/progress/messenger.go +++ b/cmd/oras/internal/display/status/progress/messenger.go @@ -15,33 +15,48 @@ limitations under the License. package progress -import ( - ocispec "github.com/opencontainers/image-spec/specs-go/v1" - "oras.land/oras/cmd/oras/internal/display/status/progress/humanize" - "time" -) +import "oras.land/oras/internal/progress" // Messenger is progress message channel. type Messenger struct { ch chan *status closed bool + prompt map[progress.State]string } -// Start initializes the messenger. -func (sm *Messenger) Start() { - if sm.ch == nil { +func (m *Messenger) Update(status progress.Status) error { + if status.State == progress.StateInitialized { + m.start() + } + m.send(m.prompt[status.State], status.Offset) + return nil +} + +func (m *Messenger) Fail(err error) error { + m.ch <- fail(err) + return nil +} + +func (m *Messenger) Close() error { + m.stop() + return nil +} + +// start initializes the messenger. +func (m *Messenger) start() { + if m.ch == nil { return } - sm.ch <- startTiming() + m.ch <- startTiming() } -// Send a status message for the specified descriptor. -func (sm *Messenger) Send(prompt string, descriptor ocispec.Descriptor, offset int64) { +// send a status message for the specified descriptor. +func (m *Messenger) send(prompt string, offset int64) { for { select { - case sm.ch <- newStatusMessage(prompt, descriptor, offset): + case m.ch <- newStatusMessage(prompt, offset): return - case <-sm.ch: + case <-m.ch: // purge the channel until successfully pushed default: // ch is nil @@ -50,46 +65,12 @@ func (sm *Messenger) Send(prompt string, descriptor ocispec.Descriptor, offset i } } -// Stop the messenger after sending a end message. -func (sm *Messenger) Stop() { - if sm.closed { +// stop the messenger after sending a end message. +func (m *Messenger) stop() { + if m.closed { return } - sm.ch <- endTiming() - close(sm.ch) - sm.closed = true -} - -// newStatus generates a base empty status. -func newStatus() *status { - return &status{ - offset: -1, - total: humanize.ToBytes(0), - speedWindow: newSpeedWindow(framePerSecond), - } -} - -// newStatusMessage generates a status for messaging. -func newStatusMessage(prompt string, descriptor ocispec.Descriptor, offset int64) *status { - return &status{ - prompt: prompt, - descriptor: descriptor, - offset: offset, - } -} - -// startTiming creates start timing message. -func startTiming() *status { - return &status{ - offset: -1, - startTime: time.Now(), - } -} - -// endTiming creates end timing message. -func endTiming() *status { - return &status{ - offset: -1, - endTime: time.Now(), - } + m.ch <- endTiming() + close(m.ch) + m.closed = true } diff --git a/cmd/oras/internal/display/status/progress/messenger_test.go b/cmd/oras/internal/display/status/progress/messenger_test.go index a8b782e55..e41ea25db 100644 --- a/cmd/oras/internal/display/status/progress/messenger_test.go +++ b/cmd/oras/internal/display/status/progress/messenger_test.go @@ -16,16 +16,18 @@ limitations under the License. package progress import ( - v1 "github.com/opencontainers/image-spec/specs-go/v1" "testing" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" ) func Test_Messenger(t *testing.T) { var msg *status ch := make(chan *status, BufferSize) + messenger := &Messenger{ch: ch} - messenger.Start() + messenger.start() select { case msg = <-ch: if msg.offset != -1 { @@ -35,12 +37,11 @@ func Test_Messenger(t *testing.T) { t.Error("Expected start message") } - desc := v1.Descriptor{ - Digest: "mouse", - Size: 100, + desc := ocispec.Descriptor{ + Size: 100, } expected := int64(50) - messenger.Send("Reading", desc, expected) + messenger.send("Reading", expected) select { case msg = <-ch: if msg.offset != expected { @@ -53,8 +54,8 @@ func Test_Messenger(t *testing.T) { t.Error("Expected status message") } - messenger.Send("Reading", desc, expected) - messenger.Send("Read", desc, desc.Size) + messenger.send("Reading", expected) + messenger.send("Read", desc.Size) select { case msg = <-ch: if msg.offset != desc.Size { @@ -73,7 +74,7 @@ func Test_Messenger(t *testing.T) { } expected = int64(-1) - messenger.Stop() + messenger.stop() select { case msg = <-ch: if msg.offset != expected { @@ -83,7 +84,7 @@ func Test_Messenger(t *testing.T) { t.Error("Expected END status message") } - messenger.Stop() + messenger.stop() select { case msg = <-ch: if msg != nil { diff --git a/cmd/oras/internal/display/status/progress/status.go b/cmd/oras/internal/display/status/progress/status.go index feb653834..57b46dadf 100644 --- a/cmd/oras/internal/display/status/progress/status.go +++ b/cmd/oras/internal/display/status/progress/status.go @@ -39,11 +39,13 @@ var ( spinnerColor = aec.LightYellowF doneMarkColor = aec.LightGreenF progressColor = aec.LightBlueB + failureColor = aec.LightRedF ) // status is used as message to update progress view. type status struct { done bool // done is true when the end time is set + err error prompt string descriptor ocispec.Descriptor offset int64 @@ -56,6 +58,47 @@ type status struct { lock sync.Mutex } +// newStatus generates a base empty status. +func newStatus(desc ocispec.Descriptor) *status { + return &status{ + descriptor: desc, + offset: -1, + total: humanize.ToBytes(desc.Size), + speedWindow: newSpeedWindow(framePerSecond), + } +} + +// newStatusMessage generates a status for messaging. +func newStatusMessage(prompt string, offset int64) *status { + return &status{ + prompt: prompt, + offset: offset, + } +} + +// startTiming creates start timing message. +func startTiming() *status { + return &status{ + offset: -1, + startTime: time.Now(), + } +} + +// endTiming creates end timing message. +func endTiming() *status { + return &status{ + offset: -1, + endTime: time.Now(), + } +} + +func fail(err error) *status { + return &status{ + err: err, + offset: -1, + } +} + func (s *status) isZero() bool { return s.offset < 0 && s.startTime.IsZero() && s.endTime.IsZero() } @@ -106,9 +149,13 @@ func (s *status) String(width int) (string, string) { lenBar := int(percent * barLength) bar := fmt.Sprintf("[%s%s]", progressColor.Apply(strings.Repeat(" ", lenBar)), strings.Repeat(".", barLength-lenBar)) speed := s.calculateSpeed() - left = fmt.Sprintf("%s %s(%*s/s) %s %s", - spinnerColor.Apply(string(s.mark.symbol())), - bar, speedLength, speed, s.prompt, name) + var mark string + if s.err == nil { + mark = spinnerColor.Apply(string(s.mark.symbol())) + } else { + mark = failureColor.Apply("✗") + } + left = fmt.Sprintf("%s %s(%*s/s) %s %s", mark, bar, speedLength, speed, s.prompt, name) // bar + wrapper(2) + space(1) + speed + "/s"(2) + wrapper(2) = len(bar) + len(speed) + 7 lenLeft = barLength + speedLength + 7 } else { @@ -165,12 +212,11 @@ func (s *status) update(n *status) { s.lock.Lock() defer s.lock.Unlock() + if n.err != nil { + s.err = n.err + } if n.offset >= 0 { s.offset = n.offset - if n.descriptor.Size != s.descriptor.Size { - s.total = humanize.ToBytes(n.descriptor.Size) - } - s.descriptor = n.descriptor } if n.prompt != "" { s.prompt = n.prompt @@ -181,6 +227,8 @@ func (s *status) update(n *status) { } if !n.endTime.IsZero() { s.endTime = n.endTime - s.done = true + if s.err == nil { + s.done = true + } } } diff --git a/cmd/oras/internal/display/status/progress/status_test.go b/cmd/oras/internal/display/status/progress/status_test.go index 8542d651d..9e16a09b5 100644 --- a/cmd/oras/internal/display/status/progress/status_test.go +++ b/cmd/oras/internal/display/status/progress/status_test.go @@ -18,6 +18,7 @@ limitations under the License. package progress import ( + "context" "testing" "time" @@ -29,22 +30,20 @@ import ( func Test_status_String(t *testing.T) { // zero status and progress - s := newStatus() + s := newStatus(ocispec.Descriptor{ + MediaType: "application/vnd.oci.empty.oras.test.v1+json", + Size: 2, + Digest: "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + }) if status, digest := s.String(console.MinWidth); status != zeroStatus || digest != zeroDigest { t.Errorf("status.String() = %v, %v, want %v, %v", status, digest, zeroStatus, zeroDigest) } // not done s.update(&status{ - prompt: "test", - descriptor: ocispec.Descriptor{ - MediaType: "application/vnd.oci.empty.oras.test.v1+json", - Size: 2, - Digest: "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", - }, + prompt: "test", startTime: time.Now().Add(-time.Minute), offset: 0, - total: humanize.ToBytes(2), }) // full name statusStr, digestStr := s.String(120) @@ -70,22 +69,20 @@ func Test_status_String(t *testing.T) { func Test_status_String_zeroWidth(t *testing.T) { // zero status and progress - s := newStatus() + s := newStatus(ocispec.Descriptor{ + MediaType: "application/vnd.oci.empty.oras.test.v1+json", + Size: 0, + Digest: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }) if status, digest := s.String(console.MinWidth); status != zeroStatus || digest != zeroDigest { t.Errorf("status.String() = %v, %v, want %v, %v", status, digest, zeroStatus, zeroDigest) } // not done s.update(&status{ - prompt: "test", - descriptor: ocispec.Descriptor{ - MediaType: "application/vnd.oci.empty.oras.test.v1+json", - Size: 0, - Digest: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - }, + prompt: "test", startTime: time.Now().Add(-time.Minute), offset: 0, - total: humanize.ToBytes(0), }) // not done statusStr, digestStr := s.String(120) @@ -103,9 +100,36 @@ func Test_status_String_zeroWidth(t *testing.T) { t.Error(err) } } + +func Test_status_String_failure(t *testing.T) { + // zero status and progress + s := newStatus(ocispec.Descriptor{ + MediaType: "application/vnd.oci.empty.oras.test.v1+json", + Size: 2, + Digest: "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + }) + if status, digest := s.String(console.MinWidth); status != zeroStatus || digest != zeroDigest { + t.Errorf("status.String() = %v, %v, want %v, %v", status, digest, zeroStatus, zeroDigest) + } + + // done with failure + s.update(&status{ + err: context.Canceled, + prompt: "test", + descriptor: s.descriptor, + offset: 1, + startTime: time.Now().Add(-time.Minute), + endTime: time.Now(), + }) + statusStr, digestStr := s.String(120) + if err := testutils.OrderedMatch(statusStr+digestStr, "✗", s.prompt, s.descriptor.MediaType, "1.00/2 B", "50.00%", s.descriptor.Digest.String()); err != nil { + t.Error(err) + } +} + func Test_status_durationString(t *testing.T) { // zero duration - s := newStatus() + s := newStatus(ocispec.Descriptor{}) if d := s.durationString(); d != zeroDuration { t.Errorf("status.durationString() = %v, want %v", d, zeroDuration) } diff --git a/cmd/oras/internal/display/status/track/reader.go b/cmd/oras/internal/display/status/track/reader.go index 93919381f..1f2070a00 100644 --- a/cmd/oras/internal/display/status/track/reader.go +++ b/cmd/oras/internal/display/status/track/reader.go @@ -20,79 +20,47 @@ import ( "os" ocispec "github.com/opencontainers/image-spec/specs-go/v1" - "oras.land/oras/cmd/oras/internal/display/status/progress" + sprogress "oras.land/oras/cmd/oras/internal/display/status/progress" + "oras.land/oras/internal/progress" ) type reader struct { - base io.Reader - offset int64 - actionPrompt string - donePrompt string - descriptor ocispec.Descriptor - manager progress.Manager - messenger *progress.Messenger + io.Reader + progress.Tracker + + manager progress.Manager } // NewReader returns a new reader with tracked progress. func NewReader(r io.Reader, descriptor ocispec.Descriptor, actionPrompt string, donePrompt string, tty *os.File) (*reader, error) { - manager, err := progress.NewManager(tty) + prompt := map[progress.State]string{ + progress.StateInitialized: actionPrompt, + progress.StateTransmitting: actionPrompt, + progress.StateTransmitted: donePrompt, + } + + manager, err := sprogress.NewManager(tty, prompt) if err != nil { return nil, err } - return managedReader(r, descriptor, manager, actionPrompt, donePrompt) + return managedReader(r, descriptor, manager) } -func managedReader(r io.Reader, descriptor ocispec.Descriptor, manager progress.Manager, actionPrompt string, donePrompt string) (*reader, error) { - messenger, err := manager.Add() +func managedReader(r io.Reader, descriptor ocispec.Descriptor, manager progress.Manager) (*reader, error) { + tracker, err := manager.Track(descriptor) if err != nil { return nil, err } return &reader{ - base: r, - descriptor: descriptor, - actionPrompt: actionPrompt, - donePrompt: donePrompt, - manager: manager, - messenger: messenger, + Reader: progress.TrackReader(tracker, r), + Tracker: tracker, + manager: manager, }, nil } // StopManager stops the messenger channel and related manager. func (r *reader) StopManager() { - r.Close() + _ = r.Tracker.Close() _ = r.manager.Close() } - -// Done sends message to mark the tracked progress as complete. -func (r *reader) Done() { - r.messenger.Send(r.donePrompt, r.descriptor, r.descriptor.Size) - r.messenger.Stop() -} - -// Close closes the update channel. -func (r *reader) Close() { - r.messenger.Stop() -} - -// Start sends the start timing to the messenger channel. -func (r *reader) Start() { - r.messenger.Start() -} - -// Read reads from the underlying reader and updates the progress. -func (r *reader) Read(p []byte) (int, error) { - n, err := r.base.Read(p) - if err != nil && err != io.EOF { - return n, err - } - - r.offset = r.offset + int64(n) - if err == io.EOF { - if r.offset != r.descriptor.Size { - return n, io.ErrUnexpectedEOF - } - } - r.messenger.Send(r.actionPrompt, r.descriptor, r.offset) - return n, err -} diff --git a/cmd/oras/internal/display/status/track/target.go b/cmd/oras/internal/display/status/track/target.go index dce64201b..50dd3747d 100644 --- a/cmd/oras/internal/display/status/track/target.go +++ b/cmd/oras/internal/display/status/track/target.go @@ -25,21 +25,20 @@ import ( "oras.land/oras-go/v2" "oras.land/oras-go/v2/errdef" "oras.land/oras-go/v2/registry" - "oras.land/oras/cmd/oras/internal/display/status/progress" + sprogress "oras.land/oras/cmd/oras/internal/display/status/progress" + "oras.land/oras/internal/progress" ) // GraphTarget is a tracked oras.GraphTarget. type GraphTarget interface { oras.GraphTarget io.Closer - Prompt(desc ocispec.Descriptor, prompt string) error + Report(desc ocispec.Descriptor, state progress.State) error } type graphTarget struct { oras.GraphTarget - manager progress.Manager - actionPrompt string - donePrompt string + manager progress.Manager } type referenceGraphTarget struct { @@ -47,16 +46,14 @@ type referenceGraphTarget struct { } // NewTarget creates a new tracked Target. -func NewTarget(t oras.GraphTarget, actionPrompt, donePrompt string, tty *os.File) (GraphTarget, error) { - manager, err := progress.NewManager(tty) +func NewTarget(t oras.GraphTarget, prompt map[progress.State]string, tty *os.File) (GraphTarget, error) { + manager, err := sprogress.NewManager(tty, prompt) if err != nil { return nil, err } gt := &graphTarget{ - GraphTarget: t, - manager: manager, - actionPrompt: actionPrompt, - donePrompt: donePrompt, + GraphTarget: t, + manager: manager, } if _, ok := t.(registry.ReferencePusher); ok { @@ -76,37 +73,41 @@ func (t *graphTarget) Mount(ctx context.Context, desc ocispec.Descriptor, fromRe // Push pushes the content to the base oras.GraphTarget with tracking. func (t *graphTarget) Push(ctx context.Context, expected ocispec.Descriptor, content io.Reader) error { - r, err := managedReader(content, expected, t.manager, t.actionPrompt, t.donePrompt) + r, err := managedReader(content, expected, t.manager) if err != nil { return err } defer r.Close() - r.Start() + if err := progress.Start(r); err != nil { + return err + } if err := t.GraphTarget.Push(ctx, expected, r); err != nil { if errors.Is(err, errdef.ErrAlreadyExists) { // allowed error types in oras-go oci and memory store - r.Done() + if err := progress.Done(r); err != nil { + return err + } } return err } - r.Done() - return nil + return progress.Done(r) } // PushReference pushes the content to the base oras.GraphTarget with tracking. func (rgt *referenceGraphTarget) PushReference(ctx context.Context, expected ocispec.Descriptor, content io.Reader, reference string) error { - r, err := managedReader(content, expected, rgt.manager, rgt.actionPrompt, rgt.donePrompt) + r, err := managedReader(content, expected, rgt.manager) if err != nil { return err } defer r.Close() - r.Start() + if err := progress.Start(r); err != nil { + return err + } err = rgt.GraphTarget.(registry.ReferencePusher).PushReference(ctx, expected, r, reference) if err != nil { return err } - r.Done() - return nil + return progress.Done(r) } // Close closes the tracking manager. @@ -114,7 +115,17 @@ func (t *graphTarget) Close() error { return t.manager.Close() } -// Prompt prompts the user with the provided prompt and descriptor. -func (t *graphTarget) Prompt(desc ocispec.Descriptor, prompt string) error { - return t.manager.SendAndStop(desc, prompt) +// Report prompts the user with the provided state and descriptor. +func (t *graphTarget) Report(desc ocispec.Descriptor, state progress.State) error { + tracker, err := t.manager.Track(desc) + if err != nil { + return err + } + if err = tracker.Update(progress.Status{ + State: state, + Offset: desc.Size, + }); err != nil { + return err + } + return tracker.Close() } diff --git a/cmd/oras/internal/display/status/track/target_test.go b/cmd/oras/internal/display/status/track/target_test.go index 0630c8a9f..3e5fe62e3 100644 --- a/cmd/oras/internal/display/status/track/target_test.go +++ b/cmd/oras/internal/display/status/track/target_test.go @@ -30,6 +30,7 @@ import ( "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/errdef" "oras.land/oras-go/v2/registry/remote" + "oras.land/oras/internal/progress" "oras.land/oras/internal/testutils" ) @@ -62,9 +63,11 @@ func Test_referenceGraphTarget_PushReference(t *testing.T) { } // test tag := "tagged" - actionPrompt := "action" donePrompt := "done" - target, err := NewTarget(&testReferenceGraphTarget{src}, actionPrompt, donePrompt, device) + prompt := map[progress.State]string{ + progress.StateTransmitted: donePrompt, + } + target, err := NewTarget(&testReferenceGraphTarget{src}, prompt, device) if err != nil { t.Fatal(err) } @@ -108,9 +111,11 @@ func Test_graphTarget_Push_alreadyExists(t *testing.T) { t.Fatal("Failed to prepare test environment:", err) } // test - actionPrompt := "action" donePrompt := "done" - target, err := NewTarget(src, actionPrompt, donePrompt, device) + prompt := map[progress.State]string{ + progress.StateTransmitted: donePrompt, + } + target, err := NewTarget(src, prompt, device) if err != nil { t.Fatal(err) } diff --git a/cmd/oras/internal/display/status/tty.go b/cmd/oras/internal/display/status/tty.go index 369a33e04..adf1b8f39 100644 --- a/cmd/oras/internal/display/status/tty.go +++ b/cmd/oras/internal/display/status/tty.go @@ -26,6 +26,7 @@ import ( "oras.land/oras-go/v2" "oras.land/oras-go/v2/content" "oras.land/oras/cmd/oras/internal/display/status/track" + "oras.land/oras/internal/progress" ) // TTYPushHandler handles TTY status output for push command. @@ -57,7 +58,13 @@ func (ph *TTYPushHandler) OnEmptyArtifact() error { // TrackTarget returns a tracked target. func (ph *TTYPushHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, StopTrackTargetFunc, error) { - tracked, err := track.NewTarget(gt, PushPromptUploading, PushPromptUploaded, ph.tty) + prompt := map[progress.State]string{ + progress.StateInitialized: PushPromptUploading, + progress.StateTransmitting: PushPromptUploading, + progress.StateTransmitted: PushPromptUploaded, + progress.StateExists: PushPromptExists, + } + tracked, err := track.NewTarget(gt, prompt, ph.tty) if err != nil { return nil, nil, err } @@ -68,7 +75,7 @@ func (ph *TTYPushHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, St // OnCopySkipped is called when an object already exists. func (ph *TTYPushHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { ph.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return ph.tracked.Prompt(desc, PushPromptExists) + return ph.tracked.Report(desc, progress.StateExists) } // PreCopy implements PreCopy of CopyHandler. @@ -84,7 +91,7 @@ func (ph *TTYPushHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) return err } for _, successor := range successors { - if err = ph.tracked.Prompt(successor, PushPromptSkipped); err != nil { + if err = ph.tracked.Report(successor, progress.StateSkipped); err != nil { return err } } @@ -126,17 +133,24 @@ func (ph *TTYPullHandler) OnNodeProcessing(_ ocispec.Descriptor) error { // OnNodeRestored implements PullHandler. func (ph *TTYPullHandler) OnNodeRestored(desc ocispec.Descriptor) error { - return ph.tracked.Prompt(desc, PullPromptRestored) + return ph.tracked.Report(desc, progress.StateMounted) } // OnNodeSkipped implements PullHandler. func (ph *TTYPullHandler) OnNodeSkipped(desc ocispec.Descriptor) error { - return ph.tracked.Prompt(desc, PullPromptSkipped) + return ph.tracked.Report(desc, progress.StateSkipped) } // TrackTarget returns a tracked target. func (ph *TTYPullHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, StopTrackTargetFunc, error) { - tracked, err := track.NewTarget(gt, PullPromptDownloading, PullPromptPulled, ph.tty) + prompt := map[progress.State]string{ + progress.StateInitialized: PullPromptDownloading, + progress.StateTransmitting: PullPromptDownloading, + progress.StateTransmitted: PullPromptPulled, + progress.StateSkipped: PullPromptSkipped, + progress.StateMounted: PullPromptRestored, + } + tracked, err := track.NewTarget(gt, prompt, ph.tty) if err != nil { return nil, nil, err } @@ -160,8 +174,16 @@ func NewTTYCopyHandler(tty *os.File) CopyHandler { // StartTracking returns a tracked target from a graph target. func (ch *TTYCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) { + prompt := map[progress.State]string{ + progress.StateInitialized: copyPromptCopying, + progress.StateTransmitting: copyPromptCopying, + progress.StateTransmitted: copyPromptCopied, + progress.StateExists: copyPromptExists, + progress.StateSkipped: copyPromptSkipped, + progress.StateMounted: copyPromptMounted, + } var err error - ch.tracked, err = track.NewTarget(gt, copyPromptCopying, copyPromptCopied, ch.tty) + ch.tracked, err = track.NewTarget(gt, prompt, ch.tty) if err != nil { return nil, err } @@ -176,7 +198,7 @@ func (ch *TTYCopyHandler) StopTracking() error { // OnCopySkipped is called when an object already exists. func (ch *TTYCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return ch.tracked.Prompt(desc, copyPromptExists) + return ch.tracked.Report(desc, progress.StateExists) } // PreCopy implements PreCopy of CopyHandler. @@ -192,7 +214,7 @@ func (ch *TTYCopyHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) return err } for _, successor := range successors { - if err = ch.tracked.Prompt(successor, copyPromptSkipped); err != nil { + if err = ch.tracked.Report(successor, progress.StateSkipped); err != nil { return err } } @@ -202,5 +224,5 @@ func (ch *TTYCopyHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) // OnMounted implements OnMounted of CopyHandler. func (ch *TTYCopyHandler) OnMounted(_ context.Context, desc ocispec.Descriptor) error { ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return ch.tracked.Prompt(desc, copyPromptMounted) + return ch.tracked.Report(desc, progress.StateMounted) } diff --git a/cmd/oras/root/blob/fetch.go b/cmd/oras/root/blob/fetch.go index 44694c428..a41af6ecc 100644 --- a/cmd/oras/root/blob/fetch.go +++ b/cmd/oras/root/blob/fetch.go @@ -31,6 +31,7 @@ import ( "oras.land/oras/cmd/oras/internal/display/status/track" oerrors "oras.land/oras/cmd/oras/internal/errors" "oras.land/oras/cmd/oras/internal/option" + "oras.land/oras/internal/progress" ) type fetchBlobOptions struct { @@ -176,11 +177,15 @@ func (opts *fetchBlobOptions) doFetch(ctx context.Context, src oras.ReadOnlyTarg return ocispec.Descriptor{}, err } defer trackedReader.StopManager() - trackedReader.Start() + if err := progress.Start(trackedReader); err != nil { + return ocispec.Descriptor{}, err + } if _, err = io.Copy(writer, trackedReader); err != nil { return ocispec.Descriptor{}, err } - trackedReader.Done() + if err := progress.Done(trackedReader); err != nil { + return ocispec.Descriptor{}, err + } } if err := vr.Verify(); err != nil { return ocispec.Descriptor{}, err diff --git a/cmd/oras/root/blob/push.go b/cmd/oras/root/blob/push.go index 883de2fe5..b5bb4fac6 100644 --- a/cmd/oras/root/blob/push.go +++ b/cmd/oras/root/blob/push.go @@ -31,6 +31,7 @@ import ( "oras.land/oras/cmd/oras/internal/option" "oras.land/oras/cmd/oras/internal/output" "oras.land/oras/internal/file" + "oras.land/oras/internal/progress" ) type pushBlobOptions struct { @@ -164,11 +165,11 @@ func (opts *pushBlobOptions) doPush(ctx context.Context, printer *output.Printer return err } defer trackedReader.StopManager() - trackedReader.Start() - r = trackedReader - if err := t.Push(ctx, desc, r); err != nil { + if err := progress.Start(trackedReader); err != nil { return err } - trackedReader.Done() - return nil + if err := t.Push(ctx, desc, trackedReader); err != nil { + return err + } + return progress.Done(trackedReader) } diff --git a/internal/progress/example_test.go b/internal/progress/example_test.go new file mode 100644 index 000000000..22da22bfd --- /dev/null +++ b/internal/progress/example_test.go @@ -0,0 +1,83 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress_test + +import ( + "crypto/rand" + "fmt" + "io" + + "oras.land/oras/internal/progress" +) + +// ExampleTrackReader demonstrates how to track the transmission progress of a +// reader. +func ExampleTrackReader() { + // Set up a progress tracker. + total := int64(11) + tracker := progress.TrackerFunc(func(status progress.Status, err error) error { + if err != nil { + fmt.Printf("Error: %v\n", err) + return nil + } + switch status.State { + case progress.StateInitialized: + fmt.Println("Start reading content") + case progress.StateTransmitting: + fmt.Printf("Progress: %d/%d bytes\n", status.Offset, total) + case progress.StateTransmitted: + fmt.Println("Finish reading content") + default: + // Ignore other states. + } + return nil + }) + // Close takes no effect for TrackerFunc but should be called for general + // Tracker implementations. + defer tracker.Close() + + // Wrap a reader of a random content generator with the progress tracker. + r := io.LimitReader(rand.Reader, total) + rc := progress.TrackReader(tracker, r) + + // Start tracking the transmission. + if err := progress.Start(tracker); err != nil { + panic(err) + } + + // Read from the random content generator and discard the content, while + // tracking the progress. + // Note: io.Discard is wrapped with a io.MultiWriter for dropping + // the io.ReadFrom interface for demonstration purposes. + buf := make([]byte, 3) + w := io.MultiWriter(io.Discard) + if _, err := io.CopyBuffer(w, rc, buf); err != nil { + panic(err) + } + + // Finish tracking the transmission. + if err := progress.Done(tracker); err != nil { + panic(err) + } + + // Output: + // Start reading content + // Progress: 3/11 bytes + // Progress: 6/11 bytes + // Progress: 9/11 bytes + // Progress: 11/11 bytes + // Finish reading content +} diff --git a/internal/progress/manager.go b/internal/progress/manager.go new file mode 100644 index 000000000..c44b79d34 --- /dev/null +++ b/internal/progress/manager.go @@ -0,0 +1,31 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package progress tracks the status of descriptors being processed. +package progress + +import ( + "io" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Manager tracks the progress of multiple descriptors. +type Manager interface { + io.Closer + + // Track starts tracking the progress of a descriptor. + Track(desc ocispec.Descriptor) (Tracker, error) +} diff --git a/internal/progress/status.go b/internal/progress/status.go new file mode 100644 index 000000000..e6c4d1cb9 --- /dev/null +++ b/internal/progress/status.go @@ -0,0 +1,40 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +// State represents the state of a descriptor. +type State int + +// Registered states. +const ( + StateUnknown State = iota // unknown state + StateInitialized // progress initialized + StateTransmitting // transmitting content + StateTransmitted // content transmitted + StateExists // content exists + StateSkipped // content skipped + StateMounted // content mounted +) + +// Status represents the status of a descriptor. +type Status struct { + // State represents the state of the descriptor. + State State + + // Offset represents the current offset of the descriptor. + // Offset is discarded if set to a negative value. + Offset int64 +} diff --git a/internal/progress/tracker.go b/internal/progress/tracker.go new file mode 100644 index 000000000..0431a767e --- /dev/null +++ b/internal/progress/tracker.go @@ -0,0 +1,166 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +import "io" + +// Tracker updates the status of a descriptor. +type Tracker interface { + io.Closer + + // Update updates the status of the descriptor. + Update(status Status) error + + // Fail marks the descriptor as failed. + // Fail should return nil on successful failure marking. + Fail(err error) error +} + +// TrackerFunc is an adapter to allow the use of ordinary functions as Trackers. +// If f is a function with the appropriate signature, TrackerFunc(f) is a +// [Tracker] that calls f. +type TrackerFunc func(status Status, err error) error + +// Close closes the tracker. +func (f TrackerFunc) Close() error { + return nil +} + +// Update updates the status of the descriptor. +func (f TrackerFunc) Update(status Status) error { + return f(status, nil) +} + +// Fail marks the descriptor as failed. +func (f TrackerFunc) Fail(err error) error { + return f(Status{}, err) +} + +// Start starts tracking the transmission. +func Start(t Tracker) error { + return t.Update(Status{ + State: StateInitialized, + Offset: -1, + }) +} + +// Done marks the transmission as complete. +// Done should be called after the transmission is complete. +// Note: Reading all content from the reader does not imply the transmission is +// complete. +func Done(t Tracker) error { + return t.Update(Status{ + State: StateTransmitted, + Offset: -1, + }) +} + +// TrackReader bind a reader with a tracker. +func TrackReader(t Tracker, r io.Reader) io.Reader { + rt := readTracker{ + base: r, + tracker: t, + } + if _, ok := r.(io.WriterTo); ok { + return &readTrackerWriteTo{rt} + } + return &rt +} + +// readTracker tracks the transmission based on the read operation. +type readTracker struct { + base io.Reader + tracker Tracker + offset int64 +} + +// Read reads from the base reader and updates the status. +// On partial read, the tracker treats it as two reads: a successful read with +// status update and a failed read with failure report. +func (rt *readTracker) Read(p []byte) (int, error) { + n, err := rt.base.Read(p) + rt.offset += int64(n) + if n > 0 { + if updateErr := rt.tracker.Update(Status{ + State: StateTransmitting, + Offset: rt.offset, + }); updateErr != nil { + err = updateErr + } + } + if err != nil && err != io.EOF { + if failErr := rt.tracker.Fail(err); failErr != nil { + return n, failErr + } + } + return n, err +} + +// readTrackerWriteTo is readTracker with WriteTo support. +type readTrackerWriteTo struct { + readTracker +} + +// WriteTo writes to the base writer and updates the status. +// On partial write, the tracker treats it as two writes: a successful write +// with status update and a failed write with failure report. +func (rt *readTrackerWriteTo) WriteTo(w io.Writer) (int64, error) { + wt := &writeTracker{ + base: w, + tracker: rt.tracker, + offset: rt.offset, + } + n, err := rt.base.(io.WriterTo).WriteTo(wt) + rt.offset = wt.offset + if err != nil && wt.trackerErr == nil { + if failErr := rt.tracker.Fail(err); failErr != nil { + return n, failErr + } + } + return n, err +} + +// writeTracker tracks the transmission based on the write operation. +type writeTracker struct { + base io.Writer + tracker Tracker + offset int64 + trackerErr error +} + +// Write writes to the base writer and updates the status. +// On partial write, the tracker treats it as two writes: a successful write +// with status update and a failed write with failure report. +func (wt *writeTracker) Write(p []byte) (int, error) { + n, err := wt.base.Write(p) + wt.offset += int64(n) + if n > 0 { + if updateErr := wt.tracker.Update(Status{ + State: StateTransmitting, + Offset: wt.offset, + }); updateErr != nil { + wt.trackerErr = updateErr + err = updateErr + } + } + if err != nil { + if failErr := wt.tracker.Fail(err); failErr != nil { + wt.trackerErr = failErr + return n, failErr + } + } + return n, err +} diff --git a/internal/progress/tracker_test.go b/internal/progress/tracker_test.go new file mode 100644 index 000000000..f4190cc0c --- /dev/null +++ b/internal/progress/tracker_test.go @@ -0,0 +1,414 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestTrackerFunc_Close(t *testing.T) { + var f TrackerFunc + if err := f.Close(); err != nil { + t.Errorf("TrackerFunc.Close() error = %v, wantErr false", err) + } +} + +func TestTrackerFunc_Update(t *testing.T) { + wantStatus := Status{ + State: StateTransmitted, + Offset: 42, + } + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return wantErr + }) + + if err := tracker.Update(wantStatus); err != wantErr { + t.Errorf("TrackerFunc.Update() error = %v, want %v", err, wantErr) + } + + wantErr = errors.New("fail to track") + if err := tracker.Update(wantStatus); err != wantErr { + t.Errorf("TrackerFunc.Update() error = %v, want %v", err, wantErr) + } +} + +func TestTrackerFunc_Fail(t *testing.T) { + reportErr := errors.New("fail to process") + var wantStatus Status + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != reportErr { + t.Errorf("TrackerFunc err = %v, want %v", err, reportErr) + } + return wantErr + }) + + if err := tracker.Fail(reportErr); err != wantErr { + t.Errorf("TrackerFunc.Fail() error = %v, want %v", err, wantErr) + } + + wantErr = errors.New("fail to track") + if err := tracker.Fail(reportErr); err != wantErr { + t.Errorf("TrackerFunc.Fail() error = %v, want %v", err, wantErr) + } +} + +func TestStart(t *testing.T) { + tests := []struct { + name string + t Tracker + wantErr bool + }{ + { + name: "successful report initialization", + t: TrackerFunc(func(status Status, err error) error { + if status.State != StateInitialized { + t.Errorf("expected state to be StateInitialized, got %v", status.State) + } + return nil + }), + }, + { + name: "fail to report initialization", + t: TrackerFunc(func(status Status, err error) error { + return errors.New("fail to track") + }), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Start(tt.t); (err != nil) != tt.wantErr { + t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDone(t *testing.T) { + tests := []struct { + name string + t Tracker + wantErr bool + }{ + { + name: "successful report initialization", + t: TrackerFunc(func(status Status, err error) error { + if status.State != StateTransmitted { + t.Errorf("expected state to be StateTransmitted, got %v", status.State) + } + return nil + }), + }, + { + name: "fail to report initialization", + t: TrackerFunc(func(status Status, err error) error { + return errors.New("fail to track") + }), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Done(tt.t); (err != nil) != tt.wantErr { + t.Errorf("Done() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTrackReader(t *testing.T) { + const bufSize = 6 + content := []byte("hello world") + t.Run("track io.Reader", func(t *testing.T) { + var wantStatus Status + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return nil + }) + var reader io.Reader = bytes.NewReader(content) + reader = io.LimitReader(reader, int64(len(content))) // remove the io.WriterTo interface + gotReader := TrackReader(tracker, reader) + if _, ok := gotReader.(*readTracker); !ok { + t.Fatalf("TrackReader() = %v, want *readTracker", gotReader) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + n, err = gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := len(content) - bufSize; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + buf = buf[:n] + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("track io.Reader + io.WriterTo", func(t *testing.T) { + var wantStatus Status + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return nil + }) + var reader io.Reader = bytes.NewReader(content) + gotReader := TrackReader(tracker, reader) + if _, ok := gotReader.(*readTrackerWriteTo); !ok { + t.Fatalf("TrackReader() = %v, want *readTrackerWriteTo", gotReader) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := len(content) - bufSize; wn != int64(want) { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("empty io.Reader", func(t *testing.T) { + tracker := TrackerFunc(func(status Status, err error) error { + t.Errorf("TrackerFunc should not be called for empty read") + return nil + }) + gotReader := TrackReader(tracker, bytes.NewReader(nil)) + + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if want := io.EOF; err != want { + t.Fatalf("TrackReader() error = %v, want %v", err, want) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("report failure", func(t *testing.T) { + var wantStatus Status + wantErr := errors.New("fail to track") + trackerMockStage := 0 + tracker := TrackerFunc(func(status Status, err error) error { + defer func() { + trackerMockStage++ + }() + switch trackerMockStage { + case 0: + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return wantErr + case 1: + var emptyStatus Status + if wantStatus := emptyStatus; status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != wantErr { + t.Errorf("TrackerFunc err = %v, want %v", err, wantErr) + } + return nil + default: + t.Errorf("TrackerFunc should not be called") + return nil + } + }) + gotReader := TrackReader(tracker, bytes.NewReader(content)) + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + trackerMockStage = 0 + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := len(content) - bufSize; wn != int64(want) { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("process failure", func(t *testing.T) { + reportErr := io.ErrClosedPipe + var wantStatus Status + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != reportErr { + t.Errorf("TrackerFunc err = %v, want %v", err, reportErr) + } + return wantErr + }) + pipeReader, pipeWriter := io.Pipe() + pipeReader.Close() + pipeWriter.Close() + gotReader := TrackReader(tracker, pipeReader) + + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != reportErr { + t.Fatalf("TrackReader() error = %v, want %v", err, reportErr) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + wantErr = errors.New("fail to track") + n, err = gotReader.Read(buf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + gotReader = TrackReader(tracker, io.MultiReader(pipeReader)) // wrap io.WriteTo + wantErr = nil + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != reportErr { + t.Fatalf("TrackReader() error = %v, want %v", err, reportErr) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + gotReader = TrackReader(tracker, io.MultiReader(pipeReader)) // wrap io.WriteTo + wantErr = errors.New("fail to track") + wn, err = gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) +} diff --git a/internal/testutils/prompt.go b/internal/testutils/prompt.go index f40763d94..45bda6203 100644 --- a/internal/testutils/prompt.go +++ b/internal/testutils/prompt.go @@ -20,6 +20,7 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2" + "oras.land/oras/internal/progress" ) // PromptDiscarder mocks trackable GraphTarget with discarded prompt. @@ -29,7 +30,7 @@ type PromptDiscarder struct { } // Prompt discards the prompt. -func (p *PromptDiscarder) Prompt(ocispec.Descriptor, string) error { +func (p *PromptDiscarder) Report(ocispec.Descriptor, progress.State) error { return nil } @@ -48,6 +49,6 @@ func NewErrorPrompt(err error) *ErrorPrompt { } // Prompt mocks an errored prompt. -func (e *ErrorPrompt) Prompt(ocispec.Descriptor, string) error { +func (e *ErrorPrompt) Report(ocispec.Descriptor, progress.State) error { return e.wanted }