Skip to content

Commit

Permalink
Merge pull request #1367 from tdakkota/feat/request-options
Browse files Browse the repository at this point in the history
feat(gen): add request options
  • Loading branch information
tdakkota authored Dec 16, 2024
2 parents 66e8bcd + 3dc4b07 commit b987c22
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 11 deletions.
117 changes: 107 additions & 10 deletions gen/_template/client.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@
{{- /*gotype: github.com/ogen-go/ogen/gen.TemplateConfig*/ -}}
{{ template "header" $ }}

{{- if $.RequestOptionsEnabled }}
type requestConfig struct {
Client ht.Client
ServerURL *url.URL
EditRequest func(req *http.Request) error
EditResponse func(resp *http.Response) error
}

func (cfg *requestConfig) setDefaults(c baseClient) {
if cfg.Client == nil {
cfg.Client = c.cfg.Client
}
}

func (cfg *requestConfig) onRequest(req *http.Request) error {
if fn := cfg.EditRequest; fn != nil {
return fn(req)
}
return nil
}

func (cfg *requestConfig) onResponse(resp *http.Response) error {
if fn := cfg.EditResponse; fn != nil {
return fn(resp)
}
return nil
}

// RequestOption defines options for request.
type RequestOption func(cfg *requestConfig)

// WithRequestClient sets client for request.
func WithRequestClient(client ht.Client) RequestOption {
return func(cfg *requestConfig) {
cfg.Client = client
}
}

// WithEditRequest sets function to edit request.
func WithEditRequest(fn func(req *http.Request) error) RequestOption {
return func(cfg *requestConfig) {
cfg.EditRequest = fn
}
}

// WithEditResponse sets function to edit response.
func WithEditResponse(fn func(resp *http.Response) error) RequestOption {
return func(cfg *requestConfig) {
cfg.EditResponse = fn
}
}
{{- end }}

{{- if $.PathsClientEnabled }}

// Invoker invokes operations described by OpenAPI v3 specification.
Expand All @@ -16,7 +69,8 @@ type Invoker interface {
{{ $op.Name }}(ctx context.Context
{{- if $op.WebhookInfo }}, targetURL string{{ end }}
{{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}
{{- if $.RequestOptionsEnabled }}, options ...RequestOption{{ end }}) {{ $op.Responses.ResultTuple "" "" }}
{{- end }}
}

Expand All @@ -32,7 +86,8 @@ type {{ $group.Name }}Invoker interface {
{{ $op.Name }}(ctx context.Context
{{- if $op.WebhookInfo }}, targetURL string{{ end }}
{{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}
{{- if $.RequestOptionsEnabled }}, options ...RequestOption {{ end }}) {{ $op.Responses.ResultTuple "" "" }}
{{- end }}
}
{{- end }}
Expand All @@ -46,7 +101,7 @@ type Client struct {
baseClient
}

{{- if $.PathsServerEnabled }}
{{- if and $.PathsServerEnabled (not $.RequestOptionsEnabled) }}
{{- if $.Error }}
type errorHandler interface {
NewError(ctx context.Context, err error) {{ $.ErrorGoType }}
Expand Down Expand Up @@ -84,6 +139,7 @@ func NewClient(serverURL string, {{- if $.Securities }}sec SecuritySource,{{- en
}, nil
}

{{- if not $.RequestOptionsEnabled }}
type serverURLKey struct{}

// WithServerURL sets context key to override server URL.
Expand All @@ -98,6 +154,7 @@ func (c *Client) requestURL(ctx context.Context) *url.URL {
}
return u
}
{{- end }}

{{- range $op := $.Operations }}
{{ template "client/operation" op_elem $op $ }}
Expand Down Expand Up @@ -142,7 +199,8 @@ func NewWebhookClient(opts ...ClientOption) (*WebhookClient, error) {
func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) {{ $op.Name }}(ctx context.Context
{{- if $op.WebhookInfo }}, targetURL string{{ end }}
{{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) {{ $op.Responses.ResultTuple "" "" }} {
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}
{{- if $cfg.RequestOptionsEnabled }}, options ...RequestOption {{ end }}) {{ $op.Responses.ResultTuple "" "" }} {
{{ if $op.Responses.DoPass }}res{{ else }}_{{ end }}, err := c.send{{ $op.Name }}(ctx
{{- if $op.WebhookInfo }},targetURL{{ end -}}
{{- if $op.Request }},request{{ end -}}
Expand All @@ -154,7 +212,8 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) {{ $op.Name }}(ctx cont
func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx context.Context
{{- if $op.WebhookInfo }}, targetURL string{{ end }}
{{- if $op.Request }}, request {{ $op.Request.GoType }}{{ end }}
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}) (res {{ $op.Responses.GoType }}, err error) {
{{- if $op.Params }}, params {{ $op.Name }}Params {{ end }}
{{- if $cfg.RequestOptionsEnabled }}, requestOptions ...RequestOption {{ end }}) (res {{ $op.Responses.GoType }}, err error) {

{{- if and $op.Request $cfg.RequestValidationEnabled }}{{/* Request validation */}}
{{- if $op.Request.Type.IsInterface }}
Expand Down Expand Up @@ -238,14 +297,36 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx
}()
{{- end }}

{{ if $cfg.RequestOptionsEnabled -}}
var reqCfg requestConfig
reqCfg.setDefaults(c.baseClient)
for _, o := range requestOptions {
o(&reqCfg)
}
{{- end }}

{{ if $otel }}stage = "BuildURL"{{ end }}
{{- if $op.WebhookInfo }}
u, err := url.Parse(targetURL)
if err != nil {
return res, errors.Wrap(err, "parse target URL")
}
trimTrailingSlashes(u)
u, err := url.Parse(targetURL)
if err != nil {
return res, errors.Wrap(err, "parse target URL")
}
{{- if $cfg.RequestOptionsEnabled }}
if override := reqCfg.ServerURL; override != nil {
u = uri.Clone(override)
}
{{- end }}
trimTrailingSlashes(u)
{{- else }}
{{- if $cfg.RequestOptionsEnabled }}
u := c.serverURL
if override := reqCfg.ServerURL; override != nil {
u = override
}
u = uri.Clone(u)
{{- else }}
u := uri.Clone(c.requestURL(ctx))
{{- end }}
{{- template "encode_path_parameters" $op }}
{{- end }}

Expand Down Expand Up @@ -317,13 +398,29 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx
}
{{- end }}

{{ if $cfg.RequestOptionsEnabled -}}
if err := reqCfg.onRequest(r); err != nil {
return res, errors.Wrap(err, "edit request")
}
{{- end }}

{{ if $otel }}stage = "SendRequest"{{ end }}
{{- if $cfg.RequestOptionsEnabled }}
resp, err := reqCfg.Client.Do(r)
{{- else }}
resp, err := c.cfg.Client.Do(r)
{{- end }}
if err != nil {
return res, errors.Wrap(err, "do request")
}
defer resp.Body.Close()

{{ if $cfg.RequestOptionsEnabled -}}
if err := reqCfg.onResponse(resp); err != nil {
return res, errors.Wrap(err, "edit response")
}
{{- end }}

{{ if $otel }}stage = "DecodeResponse"{{ end }}
result, err := decode{{ $op.Name }}Response(resp)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion gen/_template/parameter_encode.tmpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{{ define "encode_path_parameters" }}{{/*gotype: github.com/ogen-go/ogen/gen/ir.Operation*/}}
u := uri.Clone(c.requestURL(ctx))
var pathParts [{{ len $.PathParts }}]string
{{- range $idx, $part := $.PathParts }}{{/* Range over path parts */}}
{{- if $part.Raw }}
Expand Down
5 changes: 5 additions & 0 deletions gen/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ var (
"client/security/reentrant",
`Enables client usage in security source implementations`,
}
ClientRequestOptions = Feature{
"client/request/options",
`Enables function options for client requests`,
}
ClientRequestValidation = Feature{
"client/request/validation",
`Enables validation of client requests`,
Expand Down Expand Up @@ -152,6 +156,7 @@ var AllFeatures = []Feature{
WebhooksClient,
WebhooksServer,
ClientSecurityReentrant,
ClientRequestOptions,
ClientRequestValidation,
ServerResponseValidation,
OgenOtel,
Expand Down
2 changes: 2 additions & 0 deletions gen/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type TemplateConfig struct {
WebhookServerEnabled bool
OpenTelemetryEnabled bool
SecurityReentrantEnabled bool
RequestOptionsEnabled bool
RequestValidationEnabled bool
ResponseValidationEnabled bool

Expand Down Expand Up @@ -270,6 +271,7 @@ func (g *Generator) WriteSource(fs FileSystem, pkgName string) error {
WebhookServerEnabled: features.Has(WebhooksServer) && len(g.webhooks) > 0,
OpenTelemetryEnabled: features.Has(OgenOtel),
SecurityReentrantEnabled: features.Has(ClientSecurityReentrant),
RequestOptionsEnabled: features.Has(ClientRequestOptions),
RequestValidationEnabled: features.Has(ClientRequestValidation),
ResponseValidationEnabled: features.Has(ServerResponseValidation),
// Unused for now.
Expand Down

0 comments on commit b987c22

Please sign in to comment.