diff --git a/cmd/auth_test.go b/cmd/auth_test.go index b27e8c0..6e53a93 100644 --- a/cmd/auth_test.go +++ b/cmd/auth_test.go @@ -17,8 +17,6 @@ import ( ) func Test_Auth_Login_WithToken(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -32,12 +30,7 @@ func Test_Auth_Login_WithToken(t *testing.T) { r.WriteString("token\n") printer := view.NewPrinter(r, w, w) ctx := createDefaultContext("") - _storage := storage.NewLocalStorage(".test_globalping-cli") - defer _storage.Remove() - err := _storage.Init() - if err != nil { - t.Fatal(err) - } + _storage := createDefaultTestStorage(t, utilsMock) _storage.GetProfile().Token = &globalping.Token{ AccessToken: "oldToken", RefreshToken: "oldRefreshToken", @@ -52,7 +45,7 @@ func Test_Auth_Login_WithToken(t *testing.T) { gbMock.EXPECT().RevokeToken("oldRefreshToken").Return(nil) os.Args = []string{"globalping", "auth", "login", "--with-token"} - err = root.Cmd.ExecuteContext(context.TODO()) + err := root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) assert.Equal(t, `Please enter your token: @@ -69,8 +62,6 @@ Logged in as test. } func Test_Auth_Login(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -80,12 +71,7 @@ func Test_Auth_Login(t *testing.T) { w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("") - _storage := storage.NewLocalStorage(".test_globalping-cli") - defer _storage.Remove() - err := _storage.Init() - if err != nil { - t.Fatal(err) - } + _storage := createDefaultTestStorage(t, utilsMock) _storage.GetProfile().Token = &globalping.Token{ AccessToken: "oldToken", RefreshToken: "oldRefreshToken", @@ -101,7 +87,7 @@ func Test_Auth_Login(t *testing.T) { utilsMock.EXPECT().OpenBrowser("http://localhost").Return(nil) os.Args = []string{"globalping", "auth", "login"} - err = root.Cmd.ExecuteContext(context.TODO()) + err := root.Cmd.ExecuteContext(context.TODO()) assert.NoError(t, err) assert.Equal(t, `Please visit the following URL to authenticate: @@ -112,8 +98,6 @@ Can't use the browser-based flow? Use "globalping auth login --with-token" to re } func Test_AuthStatus(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -139,8 +123,6 @@ func Test_AuthStatus(t *testing.T) { } func Test_Logout(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/cmd/common.go b/cmd/common.go index 64231d2..b5e64bb 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -1,40 +1,25 @@ package cmd import ( - "bufio" "errors" "fmt" - "io" - "io/fs" "net" "net/http" "os" - "path/filepath" - "runtime" "slices" "strconv" "strings" - "github.com/icza/backscanner" "github.com/jsdelivr/globalping-cli/globalping" + "github.com/jsdelivr/globalping-cli/storage" "github.com/jsdelivr/globalping-cli/version" "github.com/jsdelivr/globalping-cli/view" - "github.com/shirou/gopsutil/process" ) var ( - ErrNoPreviousMeasurements = errors.New("no previous measurements found") - ErrInvalidIndex = errors.New("invalid index") - ErrIndexOutOfRange = errors.New("index out of range") ErrTargetIPVersionNotAllowed = errors.New("ipVersion is not allowed when target is not a domain") ErrResolverIPVersionNotAllowed = errors.New("ipVersion is not allowed when resolver is not a domain") ) -var ( - saveIdToSessionErr = "failed to save measurement ID: %s" - readMeasuremetsErr = "failed to read previous measurements: %s" -) - -var SESSION_PATH string func (r *Root) updateContext(cmd string, args []string) error { r.ctx.Cmd = cmd // Get the command name @@ -94,7 +79,7 @@ func (r *Root) updateContext(cmd string, args []string) error { func (r *Root) getLocations() ([]globalping.Locations, error) { fromArr := strings.Split(r.ctx.From, ",") if len(fromArr) == 1 { - mId, err := mapFromSession(fromArr[0]) + mId, err := r.mapFromSession(fromArr[0]) if err != nil { return nil, err } @@ -198,139 +183,26 @@ func findAndRemoveResolver(args []string) (string, []string) { } // Maps a location to a measurement ID from history, if possible. -func mapFromSession(location string) (string, error) { +func (r *Root) mapFromSession(location string) (string, error) { if location == "" { return "", nil } if location[0] == '@' { index, err := strconv.Atoi(location[1:]) if err != nil { - return "", ErrInvalidIndex + return "", storage.ErrInvalidIndex } - return getIdFromSession(index) + return r.storage.GetIdFromSession(index) } if location == "first" { - return getIdFromSession(1) + return r.storage.GetIdFromSession(1) } if location == "last" || location == "previous" { - return getIdFromSession(-1) + return r.storage.GetIdFromSession(-1) } return "", nil } -// Returns the measurement ID at the given index from the session history -func getIdFromSession(index int) (string, error) { - if index == 0 { - return "", ErrInvalidIndex - } - f, err := os.Open(getMeasurementsPath()) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - return "", ErrNoPreviousMeasurements - } - return "", fmt.Errorf(readMeasuremetsErr, err) - } - defer f.Close() - // Read ids from the end of the file - if index < 0 { - fStats, err := f.Stat() - if err != nil { - return "", fmt.Errorf(readMeasuremetsErr, err) - } - if fStats.Size() == 0 { - return "", ErrNoPreviousMeasurements - } - scanner := backscanner.New(f, int(fStats.Size()-1)) // -1 to skip last newline - for { - index++ - b, _, err := scanner.LineBytes() - if err != nil { - if err == io.EOF { - return "", ErrIndexOutOfRange - } - return "", fmt.Errorf(readMeasuremetsErr, err) - } - if index == 0 { - return string(b), nil - } - } - } - // Read ids from the beginning of the file - scanner := bufio.NewScanner(f) - for scanner.Scan() { - index-- - if index == 0 { - return scanner.Text(), nil - } - } - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("failed to read previous measurements: %s", err) - } - return "", ErrIndexOutOfRange -} - -// Saves the measurement ID to the session history -func saveIdToSession(id string) error { - _, err := os.Stat(getSessionPath()) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - err := os.Mkdir(getSessionPath(), 0755) - if err != nil { - return fmt.Errorf(saveIdToSessionErr, err) - } - } else { - return fmt.Errorf(saveIdToSessionErr, err) - } - } - f, err := os.OpenFile(getMeasurementsPath(), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return fmt.Errorf(saveIdToSessionErr, err) - } - defer f.Close() - _, err = f.WriteString(id + "\n") - if err != nil { - return fmt.Errorf(saveIdToSessionErr, err) - } - return nil -} - -func getSessionPath() string { - if SESSION_PATH != "" { - return SESSION_PATH - } - SESSION_PATH = filepath.Join(os.TempDir(), getSessionId()) - return SESSION_PATH -} - -func getSessionId() string { - p, err := process.NewProcess(int32(os.Getppid())) - if err != nil { - return "globalping" - } - // Workaround for bash.exe on Windows - // PPID is different on each run. - // https://cygwin.com/git/gitweb.cgi?p=newlib-cygwin.git;a=commit;h=448cf5aa4b429d5a9cebf92a0da4ab4b5b6d23fe - if runtime.GOOS == "windows" { - name, _ := p.Name() - if name == "bash.exe" { - p, err = p.Parent() - if err != nil { - return "globalping" - } - } - } - createTime, _ := p.CreateTime() - return fmt.Sprintf("globalping_%d_%d", p.Pid, createTime) -} - -func getMeasurementsPath() string { - return filepath.Join(getSessionPath(), "measurements") -} - -func getHistoryPath() string { - return filepath.Join(getSessionPath(), "history") -} - func silenceUsageOnCreateMeasurementError(err error) bool { e, ok := err.(*globalping.MeasurementError) if ok { diff --git a/cmd/dns.go b/cmd/dns.go index 415b447..f45411b 100644 --- a/cmd/dns.go +++ b/cmd/dns.go @@ -118,7 +118,7 @@ func (r *Root) RunDNS(cmd *cobra.Command, args []string) error { r.ctx.History.Push(hm) if r.ctx.RecordToSession { r.ctx.RecordToSession = false - err := saveIdToSession(res.ID) + err := r.storage.SaveIdToSession(res.ID) if err != nil { r.printer.Printf("Warning: %s\n", err) } diff --git a/cmd/dns_test.go b/cmd/dns_test.go index 1378f0f..edb3889 100644 --- a/cmd/dns_test.go +++ b/cmd/dns_test.go @@ -14,8 +14,6 @@ import ( ) func Test_Execute_DNS_Default(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -43,7 +41,8 @@ func Test_Execute_DNS_Default(t *testing.T) { w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) + _storage := createDefaultTestStorage(t, utilsMock) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, _storage) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", @@ -68,24 +67,22 @@ func Test_Execute_DNS_Default(t *testing.T) { assert.Equal(t, expectedCtx, ctx) - b, err := os.ReadFile(getMeasurementsPath()) + b, err := _storage.GetMeasurements() assert.NoError(t, err) expectedHistory := measurementID1 + "\n" assert.Equal(t, expectedHistory, string(b)) - b, err = os.ReadFile(getHistoryPath()) + items, err := _storage.GetHistory(0) assert.NoError(t, err) - expectedHistory = createDefaultExpectedHistoryLogItem( + expectedHistoryItems := []string{createDefaultExpectedHistoryItem( "1", - measurementID1, "dns jsdelivr.com from Berlin --limit 2 --type MX --resolver 1.1.1.1 --port 99 --protocol tcp --trace", - ) - assert.Equal(t, expectedHistory, string(b)) + measurementID1, + )} + assert.Equal(t, expectedHistoryItems, items) } func Test_Execute_DNS_IPv4(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -107,7 +104,8 @@ func Test_Execute_DNS_IPv4(t *testing.T) { w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) + _storage := createDefaultTestStorage(t, utilsMock) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, _storage) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", @@ -124,8 +122,6 @@ func Test_Execute_DNS_IPv4(t *testing.T) { } func Test_Execute_DNS_IPv6(t *testing.T) { - t.Cleanup(sessionCleanup) - ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -147,7 +143,8 @@ func Test_Execute_DNS_IPv6(t *testing.T) { w := new(bytes.Buffer) printer := view.NewPrinter(nil, w, w) ctx := createDefaultContext("dns") - root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, nil) + _storage := createDefaultTestStorage(t, utilsMock) + root := NewRoot(printer, ctx, viewerMock, utilsMock, gbMock, nil, _storage) os.Args = []string{"globalping", "dns", "jsdelivr.com", "from", "Berlin", diff --git a/cmd/history.go b/cmd/history.go index e1b662b..159e424 100644 --- a/cmd/history.go +++ b/cmd/history.go @@ -1,35 +1,13 @@ package cmd import ( - "bufio" - "errors" "fmt" - "io" - "io/fs" "os" - "strconv" "strings" - "time" - "github.com/icza/backscanner" - "github.com/jsdelivr/globalping-cli/view" "github.com/spf13/cobra" ) -var ( - ErrReadHistory = errors.New("failed to read history") -) - -var ( - invalidHistoryItemErr = "invalid history item: %s" - saveToHistoryErr = "failed to save to history: %s" -) - -const ( - // ||