From ac7800151ab177ff9ed87408635a65764ba63bb5 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Fri, 10 Nov 2023 08:01:54 +0300 Subject: [PATCH 1/4] fix(gen): handle multipart complex fields according to spec --- gen/_template/request_decode.tmpl | 16 +++++++- gen/_template/request_encode.tmpl | 11 +++++- gen/gen_contents.go | 63 ++++++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/gen/_template/request_decode.tmpl b/gen/_template/request_decode.tmpl index 37a38bb98..d2ad2f0b4 100644 --- a/gen/_template/request_decode.tmpl +++ b/gen/_template/request_decode.tmpl @@ -242,7 +242,21 @@ func (s *{{ if $op.WebhookInfo }}Webhook{{ end }}Server) decode{{ $op.Name }}Req } if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - {{- template "uri/decode" $el }} + {{- if $p.Spec.Content }} + val, err := d.DecodeValue() + if err != nil { + return err + } + if err := func(d *jx.Decoder) error { + {{- template "json/dec" $el }} + return nil + }(jx.DecodeStr(val)); err != nil { + return err + } + return nil + {{- else }} + {{- template "uri/decode" $el }} + {{- end }} }); err != nil { return req, close, errors.Wrap(err, {{ printf "decode %q" $p.Spec.Name | quote }}) } diff --git a/gen/_template/request_encode.tmpl b/gen/_template/request_encode.tmpl index 9d04b20fd..ceeabf733 100644 --- a/gen/_template/request_encode.tmpl +++ b/gen/_template/request_encode.tmpl @@ -149,7 +149,16 @@ func encode{{ $op.Name }}Request( Explode: {{ if $param.Spec.Explode }}true{{ else }}false{{ end }}, } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - {{- template "uri/encode" elem $param.Type (printf "request.%s" $param.Name) }} + {{- $el := elem $param.Type (printf "request.%s" $param.Name) }} + {{- if $param.Spec.Content }} + var enc jx.Encoder + func(e *jx.Encoder) { + {{- template "json/enc" $el }} + }(&enc) + return e.EncodeValue(string(enc.Bytes())) + {{- else }} + {{- template "uri/encode" $el }} + {{- end }} }); err != nil { return errors.Wrap(err, "encode query") } diff --git a/gen/gen_contents.go b/gen/gen_contents.go index 1893ea536..a75e2f846 100644 --- a/gen/gen_contents.go +++ b/gen/gen_contents.go @@ -108,6 +108,16 @@ func (g *Generator) generateFormContent( return nil, &ErrNotImplemented{"complex form schema"} } + getEncoding := func(f *ir.Field) (ct ir.Encoding) { + if e, ok := media.Encoding[f.Tag.JSON]; ok { + ct = ir.Encoding(e.ContentType) + } + if ct == "" && encoding.MultipartForm() && isComplexMultipartType(f.Spec.Schema) { + ct = ir.EncodingJSON + } + return ct + } + var override generateSchemaOverride switch encoding { case ir.EncodingFormURLEncoded: @@ -140,7 +150,17 @@ func (g *Generator) generateFormContent( t.AddFeature("multipart-file") return nil } - f.Type.AddFeature("uri") + switch ct := getEncoding(f); ct { + case "", ir.EncodingFormURLEncoded: + f.Type.AddFeature("uri") + case ir.EncodingJSON: + f.Type.AddFeature("json") + default: + return errors.Wrapf( + &ErrNotImplemented{"form content encoding"}, + "%q", ct, + ) + } return nil } } @@ -188,17 +208,25 @@ func (g *Generator) generateFormContent( if e, ok := media.Encoding[tag]; ok { spec.Style = e.Style spec.Explode = e.Explode - if e.ContentType != "" { - return &ErrNotImplemented{"parameter content-type"} - } - } - - if err := isSupportedParamStyle(spec); err != nil { - return err } + switch ct := getEncoding(f); ct { + case "", ir.EncodingFormURLEncoded: + if err := isSupportedParamStyle(spec); err != nil { + return err + } - if err := isParamAllowed(f.Type, true, map[*ir.Type]struct{}{}); err != nil { - return err + if err := isParamAllowed(f.Type, true, map[*ir.Type]struct{}{}); err != nil { + return err + } + case ir.EncodingJSON: + spec.Content = &openapi.ParameterContent{ + Name: ct.String(), + } + default: + return errors.Wrapf( + &ErrNotImplemented{"form content encoding"}, + "%q", ct, + ) } return nil @@ -211,6 +239,21 @@ func (g *Generator) generateFormContent( return t, nil } +func isComplexMultipartType(s *jsonschema.Schema) bool { + if s == nil { + return true + } + + switch s.Type { + case jsonschema.Object, jsonschema.Empty: + return true + case jsonschema.Array: + return len(s.Items) > 0 || isComplexMultipartType(s.Item) + default: + return false + } +} + func (g *Generator) generateContents( ctx *genctx, name string, From 7ec004361382aed6989f275bf8e14cc923f5d013 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Fri, 10 Nov 2023 11:48:10 +0300 Subject: [PATCH 2/4] chore: adjust form schema to pass --- _testdata/positive/form.json | 1 + 1 file changed, 1 insertion(+) diff --git a/_testdata/positive/form.json b/_testdata/positive/form.json index 96215b525..efa2bfe6d 100644 --- a/_testdata/positive/form.json +++ b/_testdata/positive/form.json @@ -39,6 +39,7 @@ "multipart/form-data": { "encoding": { "deepObject": { + "contentType": "application/x-www-form-urlencoded", "style": "deepObject" } }, From 00c3e8080f82ea88b3da7315b56be2c7c943d9af Mon Sep 17 00:00:00 2001 From: tdakkota Date: Fri, 10 Nov 2023 09:40:13 +0300 Subject: [PATCH 3/4] test(integration): update tests due to generator changes --- internal/integration/form_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/internal/integration/form_test.go b/internal/integration/form_test.go index 8157bf96d..6c6aa9e3f 100644 --- a/internal/integration/form_test.go +++ b/internal/integration/form_test.go @@ -59,13 +59,17 @@ func testFormMultipart() *api.TestFormMultipart { } } -func checkTestFormValues(a *assert.Assertions, form url.Values) { +func checkTestFormValues(a *assert.Assertions, form url.Values, multipartForm bool) { a.Equal("10", form.Get("id")) a.Equal("00000000-0000-0000-0000-000000000000", form.Get("uuid")) a.Equal("foobar", form.Get("description")) a.Equal([]string{"foo", "bar"}, form["array"]) - a.Equal("10", form.Get("min")) - a.Equal("10", form.Get("max")) + if multipartForm { + a.JSONEq(`{"min":10,"max":10}`, form.Get("object")) + } else { + a.Equal("10", form.Get("min")) + a.Equal("10", form.Get("max")) + } a.Equal("10", form.Get("deepObject[min]")) a.Equal("10", form.Get("deepObject[max]")) } @@ -196,7 +200,7 @@ func TestURIEncodingE2E(t *testing.T) { s := tt.serverSetup(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { a.NoError(req.ParseForm()) - checkTestFormValues(a, req.PostForm) + checkTestFormValues(a, req.PostForm, false) apiServer.ServeHTTP(w, req) })) defer s.Close() @@ -224,7 +228,7 @@ func TestMultipartEncodingE2E(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { a.NoError(req.ParseMultipartForm(32 << 20)) form := url.Values(req.MultipartForm.Value) - checkTestFormValues(a, form) + checkTestFormValues(a, form, true) apiServer.ServeHTTP(w, req) })) defer s.Close() From 4ce928c4de6975957df4f84fa6076894aef7cb34 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Fri, 10 Nov 2023 09:40:19 +0300 Subject: [PATCH 4/4] chore: commit generated files --- .../oas_response_decoders_gen.go | 117 +++++++++++ .../test_allof/oas_request_decoders_gen.go | 30 ++- .../test_allof/oas_request_encoders_gen.go | 12 +- .../integration/test_allof/oas_uri_gen.go | 122 ------------ .../integration/test_form/oas_json_gen.go | 181 ++++++++++++++++++ .../test_form/oas_request_decoders_gen.go | 17 +- .../test_form/oas_request_encoders_gen.go | 11 +- internal/integration/test_form/oas_uri_gen.go | 117 ----------- 8 files changed, 352 insertions(+), 255 deletions(-) delete mode 100644 internal/integration/test_allof/oas_uri_gen.go diff --git a/internal/integration/sample_api_no_otel/oas_response_decoders_gen.go b/internal/integration/sample_api_no_otel/oas_response_decoders_gen.go index 5f1108161..4f3915a85 100644 --- a/internal/integration/sample_api_no_otel/oas_response_decoders_gen.go +++ b/internal/integration/sample_api_no_otel/oas_response_decoders_gen.go @@ -180,6 +180,15 @@ func decodeFoobarGetResponse(resp *http.Response) (res FoobarGetRes, _ error) { } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -224,6 +233,15 @@ func decodeFoobarPostResponse(resp *http.Response) (res FoobarPostRes, _ error) } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -460,6 +478,15 @@ func decodePetCreateResponse(resp *http.Response) (res *Pet, _ error) { } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -511,6 +538,15 @@ func decodePetFriendsNamesByIDResponse(resp *http.Response) (res []string, _ err } return res, err } + // Validate response. + if err := func() error { + if response == nil { + return errors.New("nil is invalid value") + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return response, nil default: return res, validate.InvalidContentType(ct) @@ -552,6 +588,15 @@ func decodePetGetResponse(resp *http.Response) (res PetGetRes, _ error) { } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -775,6 +820,15 @@ func decodePetGetByNameResponse(resp *http.Response) (res *Pet, _ error) { } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -939,6 +993,15 @@ func decodeRecursiveArrayGetResponse(resp *http.Response) (res RecursiveArray, _ } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return response, nil default: return res, validate.InvalidContentType(ct) @@ -1155,6 +1218,15 @@ func decodeTestNullableOneofsResponse(resp *http.Response) (res TestNullableOneo } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -1190,6 +1262,15 @@ func decodeTestNullableOneofsResponse(resp *http.Response) (res TestNullableOneo } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -1225,6 +1306,15 @@ func decodeTestNullableOneofsResponse(resp *http.Response) (res TestNullableOneo } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -1266,6 +1356,15 @@ func decodeTestTupleResponse(resp *http.Response) (res *TupleTest, _ error) { } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -1307,6 +1406,15 @@ func decodeTestTupleNamedResponse(resp *http.Response) (res *TupleNamedTest, _ e } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) @@ -1348,6 +1456,15 @@ func decodeTestUniqueItemsResponse(resp *http.Response) (res *UniqueItemsTest, _ } return res, err } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } return &response, nil default: return res, validate.InvalidContentType(ct) diff --git a/internal/integration/test_allof/oas_request_decoders_gen.go b/internal/integration/test_allof/oas_request_decoders_gen.go index f84f19b04..8ceac78cd 100644 --- a/internal/integration/test_allof/oas_request_decoders_gen.go +++ b/internal/integration/test_allof/oas_request_decoders_gen.go @@ -397,11 +397,22 @@ func (s *Server) decodeReferencedAllofRequest(r *http.Request) ( Name: "location", Style: uri.QueryStyleForm, Explode: true, - Fields: []uri.QueryParameterObjectField{{"lat", true}, {"lon", true}}, } if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - return request.Location.DecodeURI(d) + val, err := d.DecodeValue() + if err != nil { + return err + } + if err := func(d *jx.Decoder) error { + if err := request.Location.Decode(d); err != nil { + return err + } + return nil + }(jx.DecodeStr(val)); err != nil { + return err + } + return nil }); err != nil { return req, close, errors.Wrap(err, "decode \"location\"") } @@ -578,11 +589,22 @@ func (s *Server) decodeReferencedAllofOptionalRequest(r *http.Request) ( Name: "location", Style: uri.QueryStyleForm, Explode: true, - Fields: []uri.QueryParameterObjectField{{"lat", true}, {"lon", true}}, } if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - return request.Location.DecodeURI(d) + val, err := d.DecodeValue() + if err != nil { + return err + } + if err := func(d *jx.Decoder) error { + if err := request.Location.Decode(d); err != nil { + return err + } + return nil + }(jx.DecodeStr(val)); err != nil { + return err + } + return nil }); err != nil { return req, close, errors.Wrap(err, "decode \"location\"") } diff --git a/internal/integration/test_allof/oas_request_encoders_gen.go b/internal/integration/test_allof/oas_request_encoders_gen.go index 243fc1ada..d2218e179 100644 --- a/internal/integration/test_allof/oas_request_encoders_gen.go +++ b/internal/integration/test_allof/oas_request_encoders_gen.go @@ -111,7 +111,11 @@ func encodeReferencedAllofRequest( Explode: true, } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - return request.Location.EncodeURI(e) + var enc jx.Encoder + func(e *jx.Encoder) { + request.Location.Encode(e) + }(&enc) + return e.EncodeValue(string(enc.Bytes())) }); err != nil { return errors.Wrap(err, "encode query") } @@ -185,7 +189,11 @@ func encodeReferencedAllofOptionalRequest( Explode: true, } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - return request.Location.EncodeURI(e) + var enc jx.Encoder + func(e *jx.Encoder) { + request.Location.Encode(e) + }(&enc) + return e.EncodeValue(string(enc.Bytes())) }); err != nil { return errors.Wrap(err, "encode query") } diff --git a/internal/integration/test_allof/oas_uri_gen.go b/internal/integration/test_allof/oas_uri_gen.go deleted file mode 100644 index 90bcadc92..000000000 --- a/internal/integration/test_allof/oas_uri_gen.go +++ /dev/null @@ -1,122 +0,0 @@ -// Code generated by ogen, DO NOT EDIT. - -package api - -import ( - "math/bits" - "strconv" - - "github.com/go-faster/errors" - - "github.com/ogen-go/ogen/conv" - "github.com/ogen-go/ogen/uri" - "github.com/ogen-go/ogen/validate" -) - -// EncodeURI encodes Location as URI form. -func (s *Location) EncodeURI(e uri.Encoder) error { - if err := e.EncodeField("lat", func(e uri.Encoder) error { - return e.EncodeValue(conv.Float64ToString(s.Lat)) - }); err != nil { - return errors.Wrap(err, "encode field \"lat\"") - } - if err := e.EncodeField("lon", func(e uri.Encoder) error { - return e.EncodeValue(conv.Float64ToString(s.Lon)) - }); err != nil { - return errors.Wrap(err, "encode field \"lon\"") - } - return nil -} - -var uriFieldsNameOfLocation = [2]string{ - 0: "lat", - 1: "lon", -} - -// DecodeURI decodes Location from URI form. -func (s *Location) DecodeURI(d uri.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode Location to nil") - } - var requiredBitSet [1]uint8 - - if err := d.DecodeFields(func(k string, d uri.Decoder) error { - switch k { - case "lat": - requiredBitSet[0] |= 1 << 0 - if err := func() error { - val, err := d.DecodeValue() - if err != nil { - return err - } - - c, err := conv.ToFloat64(val) - if err != nil { - return err - } - - s.Lat = c - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"lat\"") - } - case "lon": - requiredBitSet[0] |= 1 << 1 - if err := func() error { - val, err := d.DecodeValue() - if err != nil { - return err - } - - c, err := conv.ToFloat64(val) - if err != nil { - return err - } - - s.Lon = c - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"lon\"") - } - default: - return nil - } - return nil - }); err != nil { - return errors.Wrap(err, "decode Location") - } - // Validate required fields. - var failures []validate.FieldError - for i, mask := range [1]uint8{ - 0b00000011, - } { - if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { - // Mask only required fields and check equality to mask using XOR. - // - // If XOR result is not zero, result is not equal to expected, so some fields are missed. - // Bits of fields which would be set are actually bits of missed fields. - missed := bits.OnesCount8(result) - for bitN := 0; bitN < missed; bitN++ { - bitIdx := bits.TrailingZeros8(result) - fieldIdx := i*8 + bitIdx - var name string - if fieldIdx < len(uriFieldsNameOfLocation) { - name = uriFieldsNameOfLocation[fieldIdx] - } else { - name = strconv.Itoa(fieldIdx) - } - failures = append(failures, validate.FieldError{ - Name: name, - Error: validate.ErrFieldRequired, - }) - // Reset bit. - result &^= 1 << bitIdx - } - } - } - if len(failures) > 0 { - return &validate.Error{Fields: failures} - } - - return nil -} diff --git a/internal/integration/test_form/oas_json_gen.go b/internal/integration/test_form/oas_json_gen.go index 74daef679..0a67f6d5c 100644 --- a/internal/integration/test_form/oas_json_gen.go +++ b/internal/integration/test_form/oas_json_gen.go @@ -12,6 +12,41 @@ import ( "github.com/ogen-go/ogen/validate" ) +// Encode encodes int as json. +func (o OptInt) Encode(e *jx.Encoder) { + if !o.Set { + return + } + e.Int(int(o.Value)) +} + +// Decode decodes int from json. +func (o *OptInt) Decode(d *jx.Decoder) error { + if o == nil { + return errors.New("invalid: unable to decode OptInt to nil") + } + o.Set = true + v, err := d.Int() + if err != nil { + return err + } + o.Value = int(v) + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s OptInt) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *OptInt) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + // Encode encodes string as json. func (o OptString) Encode(e *jx.Encoder) { if !o.Set { @@ -47,6 +82,39 @@ func (s *OptString) UnmarshalJSON(data []byte) error { return s.Decode(d) } +// Encode encodes TestFormMultipartObject as json. +func (o OptTestFormMultipartObject) Encode(e *jx.Encoder) { + if !o.Set { + return + } + o.Value.Encode(e) +} + +// Decode decodes TestFormMultipartObject from json. +func (o *OptTestFormMultipartObject) Decode(d *jx.Decoder) error { + if o == nil { + return errors.New("invalid: unable to decode OptTestFormMultipartObject to nil") + } + o.Set = true + if err := o.Value.Decode(d); err != nil { + return err + } + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s OptTestFormMultipartObject) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *OptTestFormMultipartObject) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + // Encode implements json.Marshaler. func (s *SharedRequest) Encode(e *jx.Encoder) { e.ObjStart() @@ -127,6 +195,119 @@ func (s *SharedRequest) UnmarshalJSON(data []byte) error { return s.Decode(d) } +// Encode implements json.Marshaler. +func (s *TestFormMultipartObject) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *TestFormMultipartObject) encodeFields(e *jx.Encoder) { + { + if s.Min.Set { + e.FieldStart("min") + s.Min.Encode(e) + } + } + { + e.FieldStart("max") + e.Int(s.Max) + } +} + +var jsonFieldsNameOfTestFormMultipartObject = [2]string{ + 0: "min", + 1: "max", +} + +// Decode decodes TestFormMultipartObject from json. +func (s *TestFormMultipartObject) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode TestFormMultipartObject to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "min": + if err := func() error { + s.Min.Reset() + if err := s.Min.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"min\"") + } + case "max": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := d.Int() + s.Max = int(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"max\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode TestFormMultipartObject") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000010, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfTestFormMultipartObject) { + name = jsonFieldsNameOfTestFormMultipartObject[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *TestFormMultipartObject) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *TestFormMultipartObject) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + // Encode implements json.Marshaler. func (s *TestMultipartUploadOK) Encode(e *jx.Encoder) { e.ObjStart() diff --git a/internal/integration/test_form/oas_request_decoders_gen.go b/internal/integration/test_form/oas_request_decoders_gen.go index d9fd5ebc4..f9de668ab 100644 --- a/internal/integration/test_form/oas_request_decoders_gen.go +++ b/internal/integration/test_form/oas_request_decoders_gen.go @@ -670,17 +670,22 @@ func (s *Server) decodeTestMultipartRequest(r *http.Request) ( Name: "object", Style: uri.QueryStyleForm, Explode: true, - Fields: []uri.QueryParameterObjectField{{"min", false}, {"max", true}}, } if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - var requestDotObjectVal TestFormMultipartObject - if err := func() error { - return requestDotObjectVal.DecodeURI(d) - }(); err != nil { + val, err := d.DecodeValue() + if err != nil { + return err + } + if err := func(d *jx.Decoder) error { + request.Object.Reset() + if err := request.Object.Decode(d); err != nil { + return err + } + return nil + }(jx.DecodeStr(val)); err != nil { return err } - request.Object.SetTo(requestDotObjectVal) return nil }); err != nil { return req, close, errors.Wrap(err, "decode \"object\"") diff --git a/internal/integration/test_form/oas_request_encoders_gen.go b/internal/integration/test_form/oas_request_encoders_gen.go index 4838312dd..01b991b66 100644 --- a/internal/integration/test_form/oas_request_encoders_gen.go +++ b/internal/integration/test_form/oas_request_encoders_gen.go @@ -290,10 +290,13 @@ func encodeTestMultipartRequest( Explode: true, } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - if val, ok := request.Object.Get(); ok { - return val.EncodeURI(e) - } - return nil + var enc jx.Encoder + func(e *jx.Encoder) { + if request.Object.Set { + request.Object.Encode(e) + } + }(&enc) + return e.EncodeValue(string(enc.Bytes())) }); err != nil { return errors.Wrap(err, "encode query") } diff --git a/internal/integration/test_form/oas_uri_gen.go b/internal/integration/test_form/oas_uri_gen.go index 844aa2a30..272c42f4d 100644 --- a/internal/integration/test_form/oas_uri_gen.go +++ b/internal/integration/test_form/oas_uri_gen.go @@ -247,123 +247,6 @@ func (s *TestFormMultipartDeepObject) DecodeURI(d uri.Decoder) error { return nil } -// EncodeURI encodes TestFormMultipartObject as URI form. -func (s *TestFormMultipartObject) EncodeURI(e uri.Encoder) error { - if err := e.EncodeField("min", func(e uri.Encoder) error { - if val, ok := s.Min.Get(); ok { - return e.EncodeValue(conv.IntToString(val)) - } - return nil - }); err != nil { - return errors.Wrap(err, "encode field \"min\"") - } - if err := e.EncodeField("max", func(e uri.Encoder) error { - return e.EncodeValue(conv.IntToString(s.Max)) - }); err != nil { - return errors.Wrap(err, "encode field \"max\"") - } - return nil -} - -var uriFieldsNameOfTestFormMultipartObject = [2]string{ - 0: "min", - 1: "max", -} - -// DecodeURI decodes TestFormMultipartObject from URI form. -func (s *TestFormMultipartObject) DecodeURI(d uri.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode TestFormMultipartObject to nil") - } - var requiredBitSet [1]uint8 - - if err := d.DecodeFields(func(k string, d uri.Decoder) error { - switch k { - case "min": - if err := func() error { - var sDotMinVal int - if err := func() error { - val, err := d.DecodeValue() - if err != nil { - return err - } - - c, err := conv.ToInt(val) - if err != nil { - return err - } - - sDotMinVal = c - return nil - }(); err != nil { - return err - } - s.Min.SetTo(sDotMinVal) - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"min\"") - } - case "max": - requiredBitSet[0] |= 1 << 1 - if err := func() error { - val, err := d.DecodeValue() - if err != nil { - return err - } - - c, err := conv.ToInt(val) - if err != nil { - return err - } - - s.Max = c - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"max\"") - } - default: - return nil - } - return nil - }); err != nil { - return errors.Wrap(err, "decode TestFormMultipartObject") - } - // Validate required fields. - var failures []validate.FieldError - for i, mask := range [1]uint8{ - 0b00000010, - } { - if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { - // Mask only required fields and check equality to mask using XOR. - // - // If XOR result is not zero, result is not equal to expected, so some fields are missed. - // Bits of fields which would be set are actually bits of missed fields. - missed := bits.OnesCount8(result) - for bitN := 0; bitN < missed; bitN++ { - bitIdx := bits.TrailingZeros8(result) - fieldIdx := i*8 + bitIdx - var name string - if fieldIdx < len(uriFieldsNameOfTestFormMultipartObject) { - name = uriFieldsNameOfTestFormMultipartObject[fieldIdx] - } else { - name = strconv.Itoa(fieldIdx) - } - failures = append(failures, validate.FieldError{ - Name: name, - Error: validate.ErrFieldRequired, - }) - // Reset bit. - result &^= 1 << bitIdx - } - } - } - if len(failures) > 0 { - return &validate.Error{Fields: failures} - } - - return nil -} - // EncodeURI encodes TestFormObject as URI form. func (s *TestFormObject) EncodeURI(e uri.Encoder) error { if err := e.EncodeField("min", func(e uri.Encoder) error {