From 363dc2da933fb938da7169936de7985acb5dc922 Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Wed, 4 Dec 2024 17:19:03 -0700 Subject: [PATCH] refactor: create tty copy handler (#1485) Signed-off-by: Terry Howe Signed-off-by: Billy Zha Co-authored-by: Billy Zha --- cmd/oras/internal/display/handler.go | 5 +- cmd/oras/internal/display/handler_test.go | 24 +++- cmd/oras/internal/display/status/interface.go | 2 + cmd/oras/internal/display/status/text.go | 10 ++ cmd/oras/internal/display/status/tty.go | 61 ++++++++++ .../display/status/tty_console_test.go | 106 +++++++++++++++++- cmd/oras/root/cp.go | 71 +++--------- cmd/oras/root/cp_test.go | 20 +--- 8 files changed, 226 insertions(+), 73 deletions(-) diff --git a/cmd/oras/internal/display/handler.go b/cmd/oras/internal/display/handler.go index de84a1723..ce6243d9d 100644 --- a/cmd/oras/internal/display/handler.go +++ b/cmd/oras/internal/display/handler.go @@ -215,6 +215,9 @@ func NewManifestIndexUpdateHandler(outputPath string, printer *output.Printer, p } // NewCopyHandler returns copy handlers. -func NewCopyHandler(printer *output.Printer, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) { +func NewCopyHandler(printer *output.Printer, tty *os.File, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) { + if tty != nil { + return status.NewTTYCopyHandler(tty), text.NewCopyHandler(printer) + } return status.NewTextCopyHandler(printer, fetcher), text.NewCopyHandler(printer) } diff --git a/cmd/oras/internal/display/handler_test.go b/cmd/oras/internal/display/handler_test.go index e02d1cfca..f71981e5a 100644 --- a/cmd/oras/internal/display/handler_test.go +++ b/cmd/oras/internal/display/handler_test.go @@ -16,11 +16,13 @@ limitations under the License. package display import ( + "oras.land/oras/internal/testutils" "os" + "reflect" "testing" - "oras.land/oras/internal/testutils" - + "oras.land/oras/cmd/oras/internal/display/metadata/text" + "oras.land/oras/cmd/oras/internal/display/status" "oras.land/oras/cmd/oras/internal/option" "oras.land/oras/cmd/oras/internal/output" ) @@ -50,3 +52,21 @@ func TestNewPullHandler(t *testing.T) { t.Errorf("NewPullHandler() error = %v, want nil", err) } } + +func TestNewCopyHandler(t *testing.T) { + printer := output.NewPrinter(os.Stdout, os.Stderr) + copyHandler, copyMetadataHandler := NewCopyHandler(printer, os.Stdout, nil) + if _, ok := copyHandler.(*status.TTYCopyHandler); !ok { + t.Errorf("expected *status.TTYCopyHandler actual %v", reflect.TypeOf(copyHandler)) + } + if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok { + t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler)) + } + copyHandler, copyMetadataHandler = NewCopyHandler(printer, nil, nil) + if _, ok := copyHandler.(*status.TextCopyHandler); !ok { + t.Errorf("expected *status.TextCopyHandler actual %v", reflect.TypeOf(copyHandler)) + } + if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok { + t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler)) + } +} diff --git a/cmd/oras/internal/display/status/interface.go b/cmd/oras/internal/display/status/interface.go index 0a918e42f..dafcdb408 100644 --- a/cmd/oras/internal/display/status/interface.go +++ b/cmd/oras/internal/display/status/interface.go @@ -62,6 +62,8 @@ type CopyHandler interface { PreCopy(ctx context.Context, desc ocispec.Descriptor) error PostCopy(ctx context.Context, desc ocispec.Descriptor) error OnMounted(ctx context.Context, desc ocispec.Descriptor) error + StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) + StopTracking() error } // ManifestIndexCreateHandler handles status output for manifest index create command. diff --git a/cmd/oras/internal/display/status/text.go b/cmd/oras/internal/display/status/text.go index 76b54f09d..618f93725 100644 --- a/cmd/oras/internal/display/status/text.go +++ b/cmd/oras/internal/display/status/text.go @@ -151,6 +151,16 @@ func NewTextCopyHandler(printer *output.Printer, fetcher content.Fetcher) CopyHa } } +// StartTracking starts a tracked target from a graph target. +func (ch *TextCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) { + return gt, nil +} + +// StopTracking ends the copy tracking for the target. +func (ch *TextCopyHandler) StopTracking() error { + return nil +} + // OnCopySkipped is called when an object already exists. func (ch *TextCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) diff --git a/cmd/oras/internal/display/status/tty.go b/cmd/oras/internal/display/status/tty.go index 78f639076..369a33e04 100644 --- a/cmd/oras/internal/display/status/tty.go +++ b/cmd/oras/internal/display/status/tty.go @@ -143,3 +143,64 @@ func (ph *TTYPullHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, St ph.tracked = tracked return tracked, tracked.Close, nil } + +// TTYCopyHandler handles tty status output for copy events. +type TTYCopyHandler struct { + tty *os.File + committed sync.Map + tracked track.GraphTarget +} + +// NewTTYCopyHandler returns a new handler for copy command. +func NewTTYCopyHandler(tty *os.File) CopyHandler { + return &TTYCopyHandler{ + tty: tty, + } +} + +// StartTracking returns a tracked target from a graph target. +func (ch *TTYCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) { + var err error + ch.tracked, err = track.NewTarget(gt, copyPromptCopying, copyPromptCopied, ch.tty) + if err != nil { + return nil, err + } + return ch.tracked, err +} + +// StopTracking ends the copy tracking for the target. +func (ch *TTYCopyHandler) StopTracking() error { + return ch.tracked.Close() +} + +// OnCopySkipped is called when an object already exists. +func (ch *TTYCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + return ch.tracked.Prompt(desc, copyPromptExists) +} + +// PreCopy implements PreCopy of CopyHandler. +func (ch *TTYCopyHandler) PreCopy(context.Context, ocispec.Descriptor) error { + return nil +} + +// PostCopy implements PostCopy of CopyHandler. +func (ch *TTYCopyHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + successors, err := graph.FilteredSuccessors(ctx, desc, ch.tracked, DeduplicatedFilter(&ch.committed)) + if err != nil { + return err + } + for _, successor := range successors { + if err = ch.tracked.Prompt(successor, copyPromptSkipped); err != nil { + return err + } + } + return nil +} + +// OnMounted implements OnMounted of CopyHandler. +func (ch *TTYCopyHandler) OnMounted(_ context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + return ch.tracked.Prompt(desc, copyPromptMounted) +} diff --git a/cmd/oras/internal/display/status/tty_console_test.go b/cmd/oras/internal/display/status/tty_console_test.go index 426285f62..2a0efb7f9 100644 --- a/cmd/oras/internal/display/status/tty_console_test.go +++ b/cmd/oras/internal/display/status/tty_console_test.go @@ -18,11 +18,18 @@ limitations under the License. package status import ( + "strconv" + "testing" + + "oras.land/oras-go/v2" "oras.land/oras-go/v2/content/memory" "oras.land/oras/internal/testutils" - "testing" ) +type testGraphTarget struct { + oras.GraphTarget +} + func TestTTYPushHandler_TrackTarget(t *testing.T) { // prepare pty _, slave, err := testutils.NewPty() @@ -78,3 +85,100 @@ func Test_TTYPullHandler_TrackTarget(t *testing.T) { } }) } + +func TestTTYCopyHandler_OnMounted(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + + if err = ch.OnMounted(ctx, mockFetcher.OciImage); err != nil { + t.Fatalf("OnMounted() should not return an error: %v", err) + } + + if err = ch.StopTracking(); err != nil { + t.Fatalf("StopTracking() should not return an error: %v", err) + } + + if err = testutils.MatchPty(pty, slave, "✓", "Mounted", strconv.FormatInt(mockFetcher.OciImage.Size, 10), "100.00%", mockFetcher.OciImage.Digest.String()); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_OnCopySkipped(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + + if err = ch.OnCopySkipped(ctx, mockFetcher.OciImage); err != nil { + t.Errorf("OnCopySkipped() should not return an error: %v", err) + } + + if err = ch.StopTracking(); err != nil { + t.Errorf("StopTracking() should not return an error: %v", err) + } + if err = testutils.MatchPty(pty, slave, "Exists", "oci-image", strconv.FormatInt(mockFetcher.OciImage.Size, 10), "100.00%"); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_PostCopy(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + + if ch.PostCopy(ctx, bogus) == nil { + t.Error("PostCopy() should return an error") + } + + if err = ch.StopTracking(); err != nil { + t.Errorf("StopTracking() should not return an error: %v", err) + } + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_PreCopy(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + + if err = ch.PreCopy(ctx, mockFetcher.OciImage); err != nil { + t.Errorf("PreCopy() should not return an error: %v", err) + } + + if err = ch.StopTracking(); err != nil { + t.Errorf("StopTracking() should not return an error: %v", err) + } + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/oras/root/cp.go b/cmd/oras/root/cp.go index 931613891..6e9a06d6c 100644 --- a/cmd/oras/root/cp.go +++ b/cmd/oras/root/cp.go @@ -21,9 +21,6 @@ import ( "fmt" "slices" "strings" - "sync" - - "oras.land/oras/cmd/oras/internal/display/status" "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -36,7 +33,7 @@ import ( "oras.land/oras/cmd/oras/internal/argument" "oras.land/oras/cmd/oras/internal/command" "oras.land/oras/cmd/oras/internal/display" - "oras.land/oras/cmd/oras/internal/display/status/track" + "oras.land/oras/cmd/oras/internal/display/status" oerrors "oras.land/oras/cmd/oras/internal/errors" "oras.land/oras/cmd/oras/internal/option" "oras.land/oras/internal/docker" @@ -131,9 +128,9 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error { return err } ctx = registryutil.WithScopeHint(ctx, dst, auth.ActionPull, auth.ActionPush) - copyHandler, handler := display.NewCopyHandler(opts.Printer, dst) + statusHandler, metadataHandler := display.NewCopyHandler(opts.Printer, opts.TTY, dst) - desc, err := doCopy(ctx, copyHandler, src, dst, opts) + desc, err := doCopy(ctx, statusHandler, src, dst, opts) if err != nil { return err } @@ -147,7 +144,7 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error { if len(opts.extraRefs) != 0 { tagNOpts := oras.DefaultTagNOptions tagNOpts.Concurrency = opts.concurrency - tagListener := listener.NewTaggedListener(dst, handler.OnTagged) + tagListener := listener.NewTaggedListener(dst, metadataHandler.OnTagged) if _, err = oras.TagN(ctx, tagListener, opts.To.Reference, opts.extraRefs, tagNOpts); err != nil { return err } @@ -158,22 +155,14 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error { return nil } -func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (ocispec.Descriptor, error) { +func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (desc ocispec.Descriptor, err error) { // Prepare copy options - committed := &sync.Map{} extendedCopyOptions := oras.DefaultExtendedCopyOptions extendedCopyOptions.Concurrency = opts.concurrency extendedCopyOptions.FindPredecessors = func(ctx context.Context, src content.ReadOnlyGraphStorage, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { return registry.Referrers(ctx, src, desc, "") } - const ( - promptExists = "Exists " - promptCopying = "Copying" - promptCopied = "Copied " - promptSkipped = "Skipped" - promptMounted = "Mounted" - ) srcRepo, srcIsRemote := src.(*remote.Repository) dstRepo, dstIsRemote := dst.(*remote.Repository) if srcIsRemote && dstIsRemote && srcRepo.Reference.Registry == dstRepo.Reference.Registry { @@ -181,45 +170,21 @@ func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOn return []string{srcRepo.Reference.Repository}, nil } } - if opts.TTY == nil { - // no TTY output - extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped - extendedCopyOptions.PreCopy = copyHandler.PreCopy - extendedCopyOptions.PostCopy = copyHandler.PostCopy - extendedCopyOptions.OnMounted = copyHandler.OnMounted - } else { - // TTY output - tracked, err := track.NewTarget(dst, promptCopying, promptCopied, opts.TTY) - if err != nil { - return ocispec.Descriptor{}, err - } - defer tracked.Close() - dst = tracked - extendedCopyOptions.OnCopySkipped = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return tracked.Prompt(desc, promptExists) - } - extendedCopyOptions.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - successors, err := graph.FilteredSuccessors(ctx, desc, tracked, status.DeduplicatedFilter(committed)) - if err != nil { - return err - } - for _, successor := range successors { - if err = tracked.Prompt(successor, promptSkipped); err != nil { - return err - } - } - return nil - } - extendedCopyOptions.OnMounted = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return tracked.Prompt(desc, promptMounted) - } + dst, err = copyHandler.StartTracking(dst) + if err != nil { + return desc, err } + defer func() { + stopErr := copyHandler.StopTracking() + if err == nil { + err = stopErr + } + }() + extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped + extendedCopyOptions.PreCopy = copyHandler.PreCopy + extendedCopyOptions.PostCopy = copyHandler.PostCopy + extendedCopyOptions.OnMounted = copyHandler.OnMounted - var desc ocispec.Descriptor - var err error rOpts := oras.DefaultResolveOptions rOpts.TargetPlatform = opts.Platform.Platform if opts.recursive { diff --git a/cmd/oras/root/cp_test.go b/cmd/oras/root/cp_test.go index f625f5669..1e9b19ea1 100644 --- a/cmd/oras/root/cp_test.go +++ b/cmd/oras/root/cp_test.go @@ -25,16 +25,13 @@ import ( "net/http/httptest" "net/url" "os" - "strings" "testing" - "oras.land/oras/cmd/oras/internal/display/status" - "oras.land/oras/cmd/oras/internal/output" - "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/registry/remote" + "oras.land/oras/cmd/oras/internal/display/status" "oras.land/oras/internal/testutils" ) @@ -132,10 +129,7 @@ func Test_doCopy(t *testing.T) { opts.TTY = slave opts.From.Reference = memDesc.Digest.String() dst := memory.New() - builder := &strings.Builder{} - printer := output.NewPrinter(builder, os.Stderr) - printer.Verbose = true - handler := status.NewTextCopyHandler(printer, dst) + handler := status.NewTTYCopyHandler(opts.TTY) // test _, err = doCopy(context.Background(), handler, memStore, dst, &opts) if err != nil { @@ -157,10 +151,7 @@ func Test_doCopy_skipped(t *testing.T) { var opts copyOptions opts.TTY = slave opts.From.Reference = memDesc.Digest.String() - builder := &strings.Builder{} - printer := output.NewPrinter(builder, os.Stderr) - printer.Verbose = true - handler := status.NewTextCopyHandler(printer, memStore) + handler := status.NewTTYCopyHandler(opts.TTY) // test _, err = doCopy(context.Background(), handler, memStore, memStore, &opts) @@ -194,10 +185,7 @@ func Test_doCopy_mounted(t *testing.T) { t.Fatal(err) } to.PlainHTTP = true - builder := &strings.Builder{} - printer := output.NewPrinter(builder, os.Stderr) - printer.Verbose = true - handler := status.NewTextCopyHandler(printer, to) + handler := status.NewTTYCopyHandler(opts.TTY) // test _, err = doCopy(context.Background(), handler, from, to, &opts)