diff --git a/.sonarcloud.properties b/.sonarcloud.properties index 2f16527..cbdc1f3 100644 --- a/.sonarcloud.properties +++ b/.sonarcloud.properties @@ -1 +1 @@ -sonar.cpd.exclusions=examples/** +sonar.exclusions=/examples/**,*_test.go diff --git a/caller.go b/caller.go index e20f8d1..4c8ca9a 100644 --- a/caller.go +++ b/caller.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -43,7 +44,7 @@ func (c *Caller) Call(ctx context.Context, in interface{}, opts ...CallOption) ( var data []byte if inVal := reflect.ValueOf(in); !options.ForceBody && (in == nil || inVal.Kind() == reflect.Struct || (inVal.Kind() == reflect.Ptr && inVal.Elem().Kind() == reflect.Struct)) && - (c.method == http.MethodHead || c.method == http.MethodGet) { + (c.method == http.MethodHead || c.method == http.MethodGet || c.method == http.MethodDelete) { var values url.Values values, err = structToValues(in) if err != nil { @@ -84,10 +85,22 @@ func (c *Caller) Call(ctx context.Context, in interface{}, opts ...CallOption) ( result.Data = data if contentType := resp.Header.Get("Content-Type"); contentType != "" { - err = validateJSONContentType(contentType) + validMediaTypes := []string{"application/json"} + if resp.StatusCode != http.StatusOK { + validMediaTypes = append(validMediaTypes, "text/plain") + } + var mediaType string + mediaType, _, err = validateContentType(contentType, validMediaTypes...) if err != nil { return result, &InvalidContentTypeError{err, contentType} } + if mediaType == "text/plain" { + text := bytes.TrimSpace(data) + if len(text) > 1024 { + text = text[:1024] + } + return result, &PlainTextError{errors.New(string(text))} + } } isErr := resp.StatusCode != http.StatusOK && c.options.ErrOut != nil @@ -101,7 +114,8 @@ func (c *Caller) Call(ctx context.Context, in interface{}, opts ...CallOption) ( return result, fmt.Errorf("unable to copy output: %w", err) } - if len(data) > 0 || (isErr && req.Method != http.MethodHead) { + //if len(data) > 0 || (isErr && req.Method != http.MethodHead) { + if len(data) > 0 { err = json.Unmarshal(data, copiedOutVal.Interface()) if err != nil { return result, fmt.Errorf("unable to unmarshal response body: %w", err) @@ -161,6 +175,9 @@ func (f *Factory) Caller(endpoint string, method string, out interface{}, opts . method: strings.ToUpper(method), out: out, } + if !strings.HasPrefix(result.url.Path, "/") { + result.url.Path = "/" + result.url.Path + } if endpoint != "" { result.url.Path = path.Join(result.url.Path, endpoint) } diff --git a/common.go b/common.go index d2f8541..ff78f84 100644 --- a/common.go +++ b/common.go @@ -22,7 +22,7 @@ type Response struct { type DoFunc func(req *Request, send SendFunc) // MiddlewareFunc is a function type to process requests as middleware from Handler. -type MiddlewareFunc func(req *Request, send SendFunc, do DoFunc) +type MiddlewareFunc func(req *Request, send SendFunc, next DoFunc) // SendFunc is a function type to send response in DoFunc or MiddlewareFunc. type SendFunc func(out interface{}, code int, header ...http.Header) diff --git a/errors.go b/errors.go index 8470e21..8e30bb5 100644 --- a/errors.go +++ b/errors.go @@ -4,13 +4,18 @@ import "fmt" // RequestError is the request error from http.Client. // It is returned from Caller.Call. -type RequestError struct{ error } +type RequestError struct{ error error } // Error is the implementation of error. func (e *RequestError) Error() string { return fmt.Errorf("request error: %w", e.error).Error() } +// Unwrap unwraps the underlying error. +func (e *RequestError) Unwrap() error { + return e.error +} + // InvalidContentTypeError occurs when the request or response body content type is invalid. type InvalidContentTypeError struct { error error @@ -22,7 +27,26 @@ func (e *InvalidContentTypeError) Error() string { return fmt.Errorf("invalid content type %q: %w", e.contentType, e.error).Error() } +// Unwrap unwraps the underlying error. +func (e *InvalidContentTypeError) Unwrap() error { + return e.error +} + // ContentType returns the invalid content type. func (e *InvalidContentTypeError) ContentType() string { return e.contentType } + +// PlainTextError is the plain text error returned from http server. +// It is returned from Caller.Call. +type PlainTextError struct{ error error } + +// Error is the implementation of error. +func (e *PlainTextError) Error() string { + return fmt.Errorf("plain text error: %w", e.error).Error() +} + +// Unwrap unwraps the underlying error. +func (e *PlainTextError) Unwrap() error { + return e.error +} diff --git a/handler.go b/handler.go index da19f3d..f838322 100644 --- a/handler.go +++ b/handler.go @@ -85,6 +85,16 @@ func (h *patternHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *patternHandler) Register(method string, in interface{}, do DoFunc, opts ...HandlerOption) Registrar { method = strings.ToUpper(method) + switch method { + case http.MethodGet: + case http.MethodPost: + case http.MethodPut: + case http.MethodPatch: + case http.MethodDelete: + default: + panic(fmt.Errorf("method %q not allowed", method)) + } + h.methodHandlersMu.Lock() defer h.methodHandlersMu.Unlock() @@ -94,6 +104,9 @@ func (h *patternHandler) Register(method string, in interface{}, do DoFunc, opts } mh = newMethodhandler(in, do, h.options, opts...) h.methodHandlers[method] = mh + if method == http.MethodGet { + h.methodHandlers[http.MethodHead] = mh + } return &struct{ Registrar }{h} } @@ -129,8 +142,8 @@ func (h *methodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - var nopwc io.WriteCloser = nopWriteCloser{w} - wc := nopwc + var nopcw io.WriteCloser = nopCloserForWriter{w} + wc := nopcw if h.options.AllowEncoding { wc, err = getContentEncoder(w, r.Header.Get("Accept-Encoding")) if err != nil { @@ -157,7 +170,7 @@ func (h *methodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } w.Header().Set("Content-Type", "application/json; charset=utf-8") - if wc == nopwc { + if wc == nopcw { w.Header().Set("Content-Length", strconv.FormatInt(int64(len(data)), 10)) } else { w.Header().Del("Content-Length") @@ -203,7 +216,7 @@ func (h *methodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { contentType := r.Header.Get("Content-Type") if contentType != "" { - err = validateJSONContentType(contentType) + _, _, err = validateContentType(contentType, "application/json") if err != nil { h.options.PerformError(&InvalidContentTypeError{err, contentType}, r) httpError(r, w, "invalid content type", http.StatusBadRequest) @@ -220,7 +233,7 @@ func (h *methodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if contentType == "" && copiedInVal.Elem().Kind() == reflect.Struct && - (r.Method == http.MethodHead || r.Method == http.MethodGet) { + (r.Method == http.MethodHead || r.Method == http.MethodGet || r.Method == http.MethodDelete) { err = valuesToStruct(r.URL.Query(), copiedInVal.Interface()) if err != nil { h.options.PerformError(fmt.Errorf("invalid query: %w", err), r) diff --git a/utils.go b/utils.go index 13e293b..644d821 100644 --- a/utils.go +++ b/utils.go @@ -26,26 +26,38 @@ func httpError(r *http.Request, w http.ResponseWriter, error string, code int) { http.Error(w, error, code) } -// validateJSONContentType validates whether the content type is 'application/json; charset=utf-8'. -func validateJSONContentType(contentType string) error { - mediatype, params, err := mime.ParseMediaType(contentType) +// validateContentType validates whether the content type is in the given valid media types. +func validateContentType(contentType string, validMediaTypes ...string) (mediaType, charset string, err error) { + mediaType, params, err := mime.ParseMediaType(contentType) if err != nil { - return fmt.Errorf("media type parse error: %w", err) + return "", "", fmt.Errorf("media type parse error: %w", err) } - switch mediatype { - case "application/json": - default: - return fmt.Errorf("invalid media type %q", mediatype) + mediaType = strings.ToLower(mediaType) + + ok := false + for _, validMediaType := range validMediaTypes { + validMediaType = strings.ToLower(validMediaType) + if mediaType == validMediaType { + ok = true + break + } + } + if !ok { + return mediaType, "", fmt.Errorf("invalid media type %q", mediaType) } - if charset, ok := params["charset"]; ok { + + charset, ok = params["charset"] + if ok { charset = strings.ToLower(charset) switch charset { + case "ascii": case "utf-8": default: - return fmt.Errorf("invalid charset %q", charset) + return mediaType, charset, fmt.Errorf("invalid charset %q", charset) } } - return nil + + return mediaType, charset, nil } // copyReflectValue copies val and always returns pointer value if val is not pointer. @@ -103,7 +115,7 @@ func valuesToStruct(values url.Values, target interface{}) (err error) { continue } - fieldName := getJSONFieldName(field) + fieldName, _ := parseJSONField(field) if fieldName == "" { continue } @@ -173,7 +185,7 @@ func structToValues(source interface{}) (values url.Values, err error) { continue } - fieldName := getJSONFieldName(field) + fieldName, fieldOmitempty := parseJSONField(field) if fieldName == "" { continue } @@ -182,8 +194,13 @@ func structToValues(source interface{}) (values url.Values, err error) { ifc, kind := fieldVal.Interface(), fieldVal.Kind() - if kind == reflect.Ptr && fieldVal.IsNil() { - continue + if fieldOmitempty { + if fieldVal.IsZero() { + continue + } + if (kind == reflect.Array || kind == reflect.Slice || kind == reflect.Map) && fieldVal.Len() == 0 { + continue + } } switch ifc.(type) { @@ -216,20 +233,29 @@ func structToValues(source interface{}) (values url.Values, err error) { return values, nil } -// getJSONFieldName retrieves the JSON field name from the structure field. -func getJSONFieldName(sf reflect.StructField) string { - fieldName := toJSONFieldName(sf.Name) +// parseJSONField parses the JSON field from the structure field. +func parseJSONField(sf reflect.StructField) (name string, omitempty bool) { + name = toJSONFieldName(sf.Name) if v, ok := sf.Tag.Lookup("json"); ok { - s := strings.SplitN(v, ",", 2)[0] - if s == "-" { - return "" + sl := strings.Split(v, ",") + s := sl[0] + if s != "-" { + s = toJSONFieldName(s) + if s != "" { + name = s + } + } else { + name = "" } - s = toJSONFieldName(s) - if s != "" { - fieldName = s + for _, s = range sl[1:] { + switch s { + case "omitempty": + omitempty = true + case "string": + } } } - return fieldName + return } // toJSONFieldName converts the given string to the JSON field name. @@ -292,16 +318,16 @@ func getContentEncoder(w http.ResponseWriter, acceptEncoding string) (result io. } } - return nopWriteCloser{w}, nil + return nopCloserForWriter{w}, nil } -// nopWriteCloser implements io.WriteCloser with a no-op Close method wrapping the provided io.Writer. -type nopWriteCloser struct { +// nopCloserForWriter implements io.WriteCloser with a no-op Close method wrapping the provided io.Writer. +type nopCloserForWriter struct { io.Writer } // Close is the implementation of io.WriteCloser. -func (nopWriteCloser) Close() error { return nil } +func (nopCloserForWriter) Close() error { return nil } // httpHeaderOption defines single http header option. type httpHeaderOption struct {