diff --git a/impl_request.go b/impl_request.go index aa970c28..2512bd6d 100644 --- a/impl_request.go +++ b/impl_request.go @@ -509,3 +509,77 @@ func (r *defaultHttpClient) Do(ctx context.Context, req *http.Request) (*http.Re req = req.WithContext(ctx) return r.ins.Do(req) } + +func FormatRequestBody(body interface{}) interface{} { + if body == nil { + return nil + } + vt := reflect.TypeOf(body) + for vt.Kind() == reflect.Ptr { + vt = vt.Elem() + } + if vt.Kind() != reflect.Struct { + return body + } + + queries := url.Values{} + res := map[string]interface{}{} + + _ = rangeStruct(body, func(fieldVV reflect.Value, fieldVT reflect.StructField) error { + key, val := formatRequestBodyKey(fieldVV, fieldVT, &queries) + if key != "" { + res[key] = val + } + return nil + }) + + for k, v := range queries { + if len(v) == 1 { + res[k] = v[0] + } else if len(v) > 1 { + res[k] = v + } + } + return res +} + +func formatRequestBodyKey(fieldVV reflect.Value, fieldVT reflect.StructField, queries *url.Values) (string, interface{}) { + if path := fieldVT.Tag.Get("path"); path != "" { + return path, internal.ReflectToString(fieldVV) + } else if queryKey := fieldVT.Tag.Get("query"); queryKey != "" { + value := internal.ReflectToQueryString(fieldVV) + sep := fieldVT.Tag.Get("join_sep") + if sep != "" { + queries.Add(queryKey, strings.Join(value, sep)) + } else { + for _, v := range value { + queries.Add(queryKey, v) + } + } + return "", nil + } else if header := fieldVT.Tag.Get("header"); header != "" { + switch header { + case "range": + if fieldVV.Kind() != reflect.Array || fieldVV.Len() != 2 || fieldVV.Index(0).Kind() != reflect.Int64 { + return "", nil + } + from := fieldVV.Index(0).Int() + to := fieldVV.Index(1).Int() + if from != 0 || to != 0 { + return "range", fmt.Sprintf("bytes=%d-%d", from, to) + } + } + } else if j := fieldVT.Tag.Get("json"); j != "" { + if strings.HasSuffix(j, ",omitempty") { + j = j[:len(j)-10] + } + if _, ok := fieldVV.Interface().(io.Reader); ok { + return j, "" + } + for fieldVV.Kind() == reflect.Ptr { + fieldVV = fieldVV.Elem() + } + return j, FormatRequestBody(fieldVV.Interface()) + } + return "", nil +} diff --git a/impl_request_test.go b/impl_request_test.go new file mode 100644 index 00000000..892ed837 --- /dev/null +++ b/impl_request_test.go @@ -0,0 +1,42 @@ +package lark + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatRequestBody(t *testing.T) { + t.Run("", func(t *testing.T) { + assert.Equal(t, map[string]interface{}{ + "app_id": "app", + "lang": "lang", + "user_id_type": "app_id", + }, FormatRequestBody(GetApplicationReq{ + AppID: "app", + Lang: "lang", + UserIDType: IDTypePtr(IDTypeAppID), + })) + }) + t.Run("", func(t *testing.T) { + assert.Equal(t, map[string]interface{}{ + "user_ids": []string{"1"}, + "user_id_type": "app_id", + "app_feed_card": map[string]interface{}{ + "biz_id": "biz", + "link": map[string]interface{}{ + "link": "link", + }, + }, + }, FormatRequestBody(CreateAppFeedCardReq{ + UserIDType: IDTypePtr(IDTypeAppID), + AppFeedCard: &CreateAppFeedCardReqAppFeedCard{ + BizID: &[]string{"biz"}[0], + Link: &CreateAppFeedCardReqAppFeedCardLink{ + Link: &[]string{"link"}[0], + }, + }, + UserIDs: []string{"1"}, + })) + }) +}