From 10b0f2bf822a61b1acc05bad7111af76c85e9526 Mon Sep 17 00:00:00 2001 From: Terry Howe Date: Mon, 19 Aug 2024 17:36:12 -0600 Subject: [PATCH] refactor: create tty copy handler Signed-off-by: Terry Howe --- cmd/oras/internal/display/handler.go | 5 +- cmd/oras/internal/display/handler_test.go | 23 ++++- cmd/oras/internal/display/status/interface.go | 2 + cmd/oras/internal/display/status/text.go | 9 ++ cmd/oras/internal/display/status/tty.go | 59 ++++++++++++ .../display/status/tty_console_test.go | 93 +++++++++++++++++++ cmd/oras/root/cp.go | 62 +++---------- cmd/oras/root/cp_test.go | 45 +++++---- 8 files changed, 221 insertions(+), 77 deletions(-) diff --git a/cmd/oras/internal/display/handler.go b/cmd/oras/internal/display/handler.go index ccad542aa..a114a9fe5 100644 --- a/cmd/oras/internal/display/handler.go +++ b/cmd/oras/internal/display/handler.go @@ -180,6 +180,9 @@ func NewManifestIndexCreateHandler(printer *output.Printer) metadata.ManifestInd } // NewCopyHandler returns copy handlers. -func NewCopyHandler(printer *output.Printer, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) { +func NewCopyHandler(printer *output.Printer, tty *os.File, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) { + if tty != nil { + return status.NewTTYCopyHandler(tty), text.NewCopyHandler(printer) + } return status.NewTextCopyHandler(printer, fetcher), text.NewCopyHandler(printer) } diff --git a/cmd/oras/internal/display/handler_test.go b/cmd/oras/internal/display/handler_test.go index f1016e6f0..dfac80737 100644 --- a/cmd/oras/internal/display/handler_test.go +++ b/cmd/oras/internal/display/handler_test.go @@ -16,12 +16,15 @@ limitations under the License. package display import ( - "oras.land/oras/internal/testutils" "os" + "reflect" "testing" + "oras.land/oras/cmd/oras/internal/display/metadata/text" + "oras.land/oras/cmd/oras/internal/display/status" "oras.land/oras/cmd/oras/internal/option" "oras.land/oras/cmd/oras/internal/output" + "oras.land/oras/internal/testutils" ) func TestNewPushHandler(t *testing.T) { @@ -49,3 +52,21 @@ func TestNewPullHandler(t *testing.T) { t.Errorf("NewPullHandler() error = %v, want nil", err) } } + +func TestNewCopyHandler(t *testing.T) { + printer := output.NewPrinter(os.Stdout, os.Stderr, false) + copyHandler, copyMetadataHandler := NewCopyHandler(printer, os.Stdout, nil) + if _, ok := copyHandler.(*status.TTYCopyHandler); !ok { + t.Errorf("expected *status.TTYCopyHandler actual %v", reflect.TypeOf(copyHandler)) + } + if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok { + t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler)) + } + copyHandler, copyMetadataHandler = NewCopyHandler(printer, nil, nil) + if _, ok := copyHandler.(*status.TextCopyHandler); !ok { + t.Errorf("expected *status.TextCopyHandler actual %v", reflect.TypeOf(copyHandler)) + } + if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok { + t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler)) + } +} diff --git a/cmd/oras/internal/display/status/interface.go b/cmd/oras/internal/display/status/interface.go index c2f0bd8b8..791dcb620 100644 --- a/cmd/oras/internal/display/status/interface.go +++ b/cmd/oras/internal/display/status/interface.go @@ -60,4 +60,6 @@ type CopyHandler interface { PreCopy(ctx context.Context, desc ocispec.Descriptor) error PostCopy(ctx context.Context, desc ocispec.Descriptor) error OnMounted(ctx context.Context, desc ocispec.Descriptor) error + StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) + StopTracking() } diff --git a/cmd/oras/internal/display/status/text.go b/cmd/oras/internal/display/status/text.go index c4277199f..dd0498f68 100644 --- a/cmd/oras/internal/display/status/text.go +++ b/cmd/oras/internal/display/status/text.go @@ -148,6 +148,15 @@ func NewTextCopyHandler(printer *output.Printer, fetcher content.Fetcher) CopyHa } } +// StartTracking starts a tracked target from a graph target. +func (ch *TextCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) { + return gt, nil +} + +// StopTracking ends the copy tracking for the target. +func (ch *TextCopyHandler) StopTracking() { +} + // OnCopySkipped is called when an object already exists. func (ch *TextCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) diff --git a/cmd/oras/internal/display/status/tty.go b/cmd/oras/internal/display/status/tty.go index 78f639076..f94eae61e 100644 --- a/cmd/oras/internal/display/status/tty.go +++ b/cmd/oras/internal/display/status/tty.go @@ -143,3 +143,62 @@ func (ph *TTYPullHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, St ph.tracked = tracked return tracked, tracked.Close, nil } + +// TTYCopyHandler handles tty status output for copy events. +type TTYCopyHandler struct { + tty *os.File + committed *sync.Map + tracked track.GraphTarget +} + +// NewTTYCopyHandler returns a new handler for copy command. +func NewTTYCopyHandler(tty *os.File) CopyHandler { + return &TTYCopyHandler{ + tty: tty, + committed: &sync.Map{}, + } +} + +// StartTracking returns a tracked target from a graph target. +func (ch *TTYCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) { + tracked, err := track.NewTarget(gt, copyPromptCopying, copyPromptCopied, ch.tty) + ch.tracked = tracked + return tracked, err +} + +// StopTracking ends the copy tracking for the target. +func (ch *TTYCopyHandler) StopTracking() { + _ = ch.tracked.Close() +} + +// OnCopySkipped is called when an object already exists. +func (ch *TTYCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + return ch.tracked.Prompt(desc, copyPromptExists) +} + +// PreCopy implements PreCopy of CopyHandler. +func (ch *TTYCopyHandler) PreCopy(_ context.Context, _ ocispec.Descriptor) error { + return nil +} + +// PostCopy implements PostCopy of CopyHandler. +func (ch *TTYCopyHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + successors, err := graph.FilteredSuccessors(ctx, desc, ch.tracked, DeduplicatedFilter(ch.committed)) + if err != nil { + return err + } + for _, successor := range successors { + if err = ch.tracked.Prompt(successor, copyPromptSkipped); err != nil { + return err + } + } + return nil +} + +// OnMounted implements OnMounted of CopyHandler. +func (ch *TTYCopyHandler) OnMounted(_ context.Context, desc ocispec.Descriptor) error { + ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) + return ch.tracked.Prompt(desc, copyPromptMounted) +} diff --git a/cmd/oras/internal/display/status/tty_console_test.go b/cmd/oras/internal/display/status/tty_console_test.go index 426285f62..f85237cce 100644 --- a/cmd/oras/internal/display/status/tty_console_test.go +++ b/cmd/oras/internal/display/status/tty_console_test.go @@ -18,11 +18,16 @@ limitations under the License. package status import ( + "oras.land/oras-go/v2" "oras.land/oras-go/v2/content/memory" "oras.land/oras/internal/testutils" "testing" ) +type testGraphTarget struct { + oras.GraphTarget +} + func TestTTYPushHandler_TrackTarget(t *testing.T) { // prepare pty _, slave, err := testutils.NewPty() @@ -78,3 +83,91 @@ func Test_TTYPullHandler_TrackTarget(t *testing.T) { } }) } + +func TestTTYCopyHandler_OnMounted(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + defer ch.StopTracking() + + if err = ch.OnMounted(ctx, mockFetcher.OciImage); err != nil { + t.Errorf("OnMounted() should not return an error: %v", err) + } + + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_OnCopySkipped(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + defer ch.StopTracking() + + if err = ch.OnCopySkipped(ctx, mockFetcher.OciImage); err != nil { + t.Errorf("OnCopySkipped() should not return an error: %v", err) + } + + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_PostCopy(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + defer ch.StopTracking() + + if ch.PostCopy(ctx, bogus) == nil { + t.Error("PostCopy() should return an error") + } + + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} + +func TestTTYCopyHandler_PreCopy(t *testing.T) { + pty, slave, err := testutils.NewPty() + if err != nil { + t.Fatal(err) + } + defer slave.Close() + ch := NewTTYCopyHandler(slave) + _, err = ch.StartTracking(&testGraphTarget{memory.New()}) + if err != nil { + t.Fatal(err) + } + defer ch.StopTracking() + + if err = ch.PreCopy(ctx, mockFetcher.OciImage); err != nil { + t.Errorf("PreCopy() should not return an error: %v", err) + } + + if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/oras/root/cp.go b/cmd/oras/root/cp.go index 8cf5ac250..3d0497346 100644 --- a/cmd/oras/root/cp.go +++ b/cmd/oras/root/cp.go @@ -21,9 +21,6 @@ import ( "fmt" "slices" "strings" - "sync" - - "oras.land/oras/cmd/oras/internal/display/status" "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -36,7 +33,7 @@ import ( "oras.land/oras/cmd/oras/internal/argument" "oras.land/oras/cmd/oras/internal/command" "oras.land/oras/cmd/oras/internal/display" - "oras.land/oras/cmd/oras/internal/display/status/track" + "oras.land/oras/cmd/oras/internal/display/status" oerrors "oras.land/oras/cmd/oras/internal/errors" "oras.land/oras/cmd/oras/internal/option" "oras.land/oras/internal/docker" @@ -128,7 +125,7 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error { return err } ctx = registryutil.WithScopeHint(ctx, dst, auth.ActionPull, auth.ActionPush) - copyHandler, handler := display.NewCopyHandler(opts.Printer, dst) + copyHandler, handler := display.NewCopyHandler(opts.Printer, opts.TTY, dst) desc, err := doCopy(ctx, copyHandler, src, dst, opts) if err != nil { @@ -155,22 +152,14 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error { return nil } -func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (ocispec.Descriptor, error) { +func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (desc ocispec.Descriptor, err error) { // Prepare copy options - committed := &sync.Map{} extendedCopyOptions := oras.DefaultExtendedCopyOptions extendedCopyOptions.Concurrency = opts.concurrency extendedCopyOptions.FindPredecessors = func(ctx context.Context, src content.ReadOnlyGraphStorage, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { return registry.Referrers(ctx, src, desc, "") } - const ( - promptExists = "Exists " - promptCopying = "Copying" - promptCopied = "Copied " - promptSkipped = "Skipped" - promptMounted = "Mounted" - ) srcRepo, srcIsRemote := src.(*remote.Repository) dstRepo, dstIsRemote := dst.(*remote.Repository) if srcIsRemote && dstIsRemote && srcRepo.Reference.Registry == dstRepo.Reference.Registry { @@ -178,45 +167,16 @@ func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOn return []string{srcRepo.Reference.Repository}, nil } } - if opts.TTY == nil { - // no TTY output - extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped - extendedCopyOptions.PreCopy = copyHandler.PreCopy - extendedCopyOptions.PostCopy = copyHandler.PostCopy - extendedCopyOptions.OnMounted = copyHandler.OnMounted - } else { - // TTY output - tracked, err := track.NewTarget(dst, promptCopying, promptCopied, opts.TTY) - if err != nil { - return ocispec.Descriptor{}, err - } - defer tracked.Close() - dst = tracked - extendedCopyOptions.OnCopySkipped = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return tracked.Prompt(desc, promptExists) - } - extendedCopyOptions.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - successors, err := graph.FilteredSuccessors(ctx, desc, tracked, status.DeduplicatedFilter(committed)) - if err != nil { - return err - } - for _, successor := range successors { - if err = tracked.Prompt(successor, promptSkipped); err != nil { - return err - } - } - return nil - } - extendedCopyOptions.OnMounted = func(ctx context.Context, desc ocispec.Descriptor) error { - committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle]) - return tracked.Prompt(desc, promptMounted) - } + dst, err = copyHandler.StartTracking(dst) + if err != nil { + return desc, err } + defer copyHandler.StopTracking() + extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped + extendedCopyOptions.PreCopy = copyHandler.PreCopy + extendedCopyOptions.PostCopy = copyHandler.PostCopy + extendedCopyOptions.OnMounted = copyHandler.OnMounted - var desc ocispec.Descriptor - var err error rOpts := oras.DefaultResolveOptions rOpts.TargetPlatform = opts.Platform.Platform if opts.recursive { diff --git a/cmd/oras/root/cp_test.go b/cmd/oras/root/cp_test.go index c02ef5ea1..b0bf7a4d6 100644 --- a/cmd/oras/root/cp_test.go +++ b/cmd/oras/root/cp_test.go @@ -24,8 +24,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "oras.land/oras/cmd/oras/internal/display/status" - "oras.land/oras/cmd/oras/internal/output" "os" "strings" "testing" @@ -34,6 +32,8 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/registry/remote" + "oras.land/oras/cmd/oras/internal/display/status" + "oras.land/oras/cmd/oras/internal/output" "oras.land/oras/internal/testutils" ) @@ -122,53 +122,51 @@ func TestMain(m *testing.M) { func Test_doCopy(t *testing.T) { // prepare - pty, slave, err := testutils.NewPty() - if err != nil { - t.Fatal(err) - } - defer slave.Close() var opts copyOptions - opts.TTY = slave opts.Verbose = true opts.From.Reference = memDesc.Digest.String() dst := memory.New() builder := &strings.Builder{} printer := output.NewPrinter(builder, os.Stderr, opts.Verbose) handler := status.NewTextCopyHandler(printer, dst) + // test - _, err = doCopy(context.Background(), handler, memStore, dst, &opts) + _, err := doCopy(context.Background(), handler, memStore, dst, &opts) if err != nil { t.Fatal(err) } // validate - if err = testutils.MatchPty(pty, slave, "Copied", memDesc.MediaType, "100.00%", memDesc.Digest.String()); err != nil { - t.Fatal(err) + actual := builder.String() + if strings.Contains(actual, configDigest) { + t.Errorf("Expected <%s> to contain <%s>", actual, configDigest) + } + if strings.Contains(actual, configMediaType) { + t.Errorf("Expected <%s> to contain <%s>", actual, configMediaType) } } func Test_doCopy_skipped(t *testing.T) { // prepare - pty, slave, err := testutils.NewPty() - if err != nil { - t.Fatal(err) - } - defer slave.Close() var opts copyOptions - opts.TTY = slave opts.Verbose = true opts.From.Reference = memDesc.Digest.String() + dst := memory.New() builder := &strings.Builder{} printer := output.NewPrinter(builder, os.Stderr, opts.Verbose) - handler := status.NewTextCopyHandler(printer, memStore) + handler := status.NewTextCopyHandler(printer, dst) // test - _, err = doCopy(context.Background(), handler, memStore, memStore, &opts) + _, err := doCopy(context.Background(), handler, memStore, memStore, &opts) if err != nil { t.Fatal(err) } // validate - if err = testutils.MatchPty(pty, slave, "Exists", memDesc.MediaType, "100.00%", memDesc.Digest.String()); err != nil { - t.Fatal(err) + actual := builder.String() + if strings.Contains(actual, configDigest) { + t.Errorf("Expected <%s> to contain <%s>", actual, configDigest) + } + if strings.Contains(actual, configMediaType) { + t.Errorf("Expected <%s> to contain <%s>", actual, configMediaType) } } @@ -194,9 +192,8 @@ func Test_doCopy_mounted(t *testing.T) { t.Fatal(err) } to.PlainHTTP = true - builder := &strings.Builder{} - printer := output.NewPrinter(builder, os.Stderr, opts.Verbose) - handler := status.NewTextCopyHandler(printer, to) + handler := status.NewTTYCopyHandler(slave) + _, _ = handler.StartTracking(memStore) // test _, err = doCopy(context.Background(), handler, from, to, &opts)