Skip to content

Commit

Permalink
cmd/atlas/internal/cloudapi: expose http error to callers
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Jan 1, 2025
1 parent d739d6d commit 9a8ae0d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
24 changes: 21 additions & 3 deletions cmd/atlas/internal/cloudapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ func New(endpoint, token string) *Client {
transport = client.HTTPClient.Transport
)
client.HTTPClient.Timeout = time.Second * 30
client.ErrorHandler = func(res *http.Response, err error, _ int) (*http.Response, error) {
return res, err // Let Client.post handle the error.
}
client.HTTPClient.Transport = &roundTripper{
token: token,
base: transport,
Expand Down Expand Up @@ -284,16 +287,21 @@ func (c *Client) post(ctx context.Context, query string, vars, data any) error {
if err != nil {
return err
}
defer req.Body.Close()
defer res.Body.Close()
switch {
case res.StatusCode == http.StatusUnauthorized:
return ErrUnauthorized
case res.StatusCode != http.StatusOK:
buf, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
if err != nil {
return &HTTPError{StatusCode: res.StatusCode, Message: err.Error()}
}
var v struct {
Errors errlist `json:"errors,omitempty"`
}
if err := json.NewDecoder(res.Body).Decode(&v); err != nil || len(v.Errors) == 0 {
return fmt.Errorf("unexpected status code: %d", res.StatusCode)
if err := json.Unmarshal(buf, &v); err != nil || len(v.Errors) == 0 {
// If the error is not a GraphQL error, return the message as is.
return &HTTPError{StatusCode: res.StatusCode, Message: string(bytes.TrimSpace(buf))}
}
return v.Errors
}
Expand Down Expand Up @@ -347,6 +355,16 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.base.RoundTrip(req)
}

// HTTPError represents a generic HTTP error. Hence, non 2xx status codes.
type HTTPError struct {
StatusCode int
Message string
}

func (e *HTTPError) Error() string {
return fmt.Sprintf("unexpected error code %d: %s", e.StatusCode, e.Message)
}

// RedactedURL returns a URL string with the userinfo redacted.
func RedactedURL(s string) (string, error) {
u, err := sqlclient.ParseURL(s)
Expand Down
55 changes: 54 additions & 1 deletion cmd/atlas/internal/cloudapi/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"

"ariga.io/atlas/sql/migrate"
Expand Down Expand Up @@ -53,7 +54,7 @@ func TestClient_Dir(t *testing.T) {
require.Equal(t, dcheck.Sum(), gcheck.Sum())
}

func TestClient_Error(t *testing.T) {
func TestClient_GraphQLError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnprocessableEntity)
_, err := w.Write([]byte(`{"errors":[{"message":"error\n","path":["variable","input","driver"],"extensions":{}}],"data":null}`))
Expand All @@ -69,6 +70,58 @@ func TestClient_Error(t *testing.T) {
require.Empty(t, link)
}

func TestClient_HTTPError(t *testing.T) {
var (
body string
code = http.StatusInternalServerError
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, body, code)
}))
client := New(srv.URL, "atlas")
defer srv.Close()
body = "internal error"
_, err := client.ReportMigration(context.Background(), ReportMigrationInput{
EnvName: "foo",
ProjectName: "bar",
})
require.EqualError(t, err, `unexpected error code 500: internal error`)

// Error should be limited to 1MB.
body = fmt.Sprintf("%s!", strings.Repeat("a", 1<<20))
_, err = client.ReportMigration(context.Background(), ReportMigrationInput{
EnvName: "foo",
ProjectName: "bar",
})
require.ErrorContains(t, err, "unexpected error code 500: a")
require.NotContains(t, err.Error(), "!")

// Unauthorized error.
body = "unauthorized"
code = http.StatusUnauthorized
_, err = client.ReportMigration(context.Background(), ReportMigrationInput{
EnvName: "foo",
ProjectName: "bar",
})
require.ErrorIs(t, err, ErrUnauthorized)

code = http.StatusForbidden
body = "Forbidden"
_, err = client.ReportMigration(context.Background(), ReportMigrationInput{
EnvName: "foo",
ProjectName: "bar",
})
require.EqualError(t, err, "unexpected error code 403: Forbidden")

code = http.StatusConflict
body = `{"errors":[{"message":"conflict\n","path":["variable","input","driver"],"extensions":{}}],"data":null}`
_, err = client.ReportMigration(context.Background(), ReportMigrationInput{
EnvName: "foo",
ProjectName: "bar",
})
require.EqualError(t, err, "variable.input.driver conflict", "GraphQL error")
}

func TestClient_ReportMigration(t *testing.T) {
const project, env = "atlas", "dev"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 9a8ae0d

Please sign in to comment.