Skip to content

Commit

Permalink
feat: protoc-gen-rpc-swagger (#6)
Browse files Browse the repository at this point in the history
* feat: protoc-gen-http-swagger plugin

* fix: protoc-gen-http-swagger plugin

* fix: go version

* fix: go mod

* fix: golangci-lint

* fix: license header

* fix: typos

* fix: remove unnecessary documents

* fix: fix some points according to the code review

* fix: remove indirect licenses

* fix: response schema mistake

* chore: optimize sorted imports

* chore: format licenses

* fix: generate different response schema name

* fix: content type

* feat: protoc-gen-rpc-swagger plugin

* fix: pb rpc swagger server gen
  • Loading branch information
EZ4Jam1n authored Sep 21, 2024
1 parent cad164f commit c7a708e
Show file tree
Hide file tree
Showing 31 changed files with 6,261 additions and 136 deletions.
6 changes: 6 additions & 0 deletions common/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ const (

const (
CodeGenerationCommentPbHttp = "// Code generated by protoc-gen-http-swagger."
CodeGenerationCommentPbRpc = "// Code generated by protoc-gen-rpc-swagger."
CodeGenerationCommentThriftHttp = "// Code generated by thrift-gen-http-swagger."
CodeGenerationCommentThriftRpc = "// Code generated by thrift-gen-rpc-swagger."
)

const (
PluginNameProtocHttpSwagger = "protoc-gen-http-swagger"
PluginNameProtocRpcSwagger = "protoc-gen-rpc-swagger"
PluginNameThriftHttpSwagger = "thrift-gen-http-swagger"
PluginNameThriftRpcSwagger = "thrift-gen-rpc-swagger"
)
Expand All @@ -82,6 +84,7 @@ const (
SchemaObjectType = "object"
ComponentSchemaPrefix = "#/components/schemas/"
ComponentSchemaSuffixBody = "Body"
ComponentSchemaSuffixForm = "Form"
ComponentSchemaSuffixRawBody = "RawBody"

ContentTypeJSON = "application/json"
Expand All @@ -106,4 +109,7 @@ const (

CommentPatternRegexp = `//\s*(.*)|/\*([\s\S]*?)\*/`
LinterRulePatternRegexp = `\(-- .* --\)`

ProtobufValueName = "GoogleProtobufValue"
ProtobufAnyName = "GoogleProtobufAny"
)
293 changes: 289 additions & 4 deletions common/tpl/tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,23 @@ func findThriftFile(fileName string) (string, error) {
}
foundPath := ""
relativePath := fileName
err = filepath.Walk(workingDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && info.Name() == fileName {
foundPath = path
return filepath.SkipDir
if !info.IsDir() {
relative, err := filepath.Rel(workingDir, path)
if err != nil {
return err
}
if relative == relativePath {
foundPath = path
return filepath.SkipDir
}
}
return nil
})
Expand Down Expand Up @@ -210,7 +220,282 @@ func initializeGenericClient() genericclient.Client {
g, err := generic.JSONThriftGeneric(p)
if err != nil {
hlog.Fatal("Failed to create HTTPThriftGeneric:", err)
hlog.Fatal("Failed to create JsonThriftGeneric:", err)
}
var opts []client.Option
opts = append(opts, client.WithTransportProtocol(transport.TTHeader))
opts = append(opts, client.WithMetaHandler(transmeta.ClientTTHeaderHandler))
opts = append(opts, client.WithHostPorts(kitexAddr))
cli, err := genericclient.NewClient("swagger", g, opts...)
if err != nil {
hlog.Fatal("Failed to create generic client:", err)
}
return cli
}
func setupSwaggerRoutes(h *server.Hertz) {
h.GET("swagger/*any", swagger.WrapHandler(swaggerFiles.Handler, swagger.URL("/openapi.yaml")))
h.GET("/openapi.yaml", func(c context.Context, ctx *app.RequestContext) {
ctx.Header("Content-Type", "application/x-yaml")
ctx.Write(openapiYAML)
})
}
func setupProxyRoutes(h *server.Hertz, cli genericclient.Client) {
h.Any("/*ServiceMethod", func(c context.Context, ctx *app.RequestContext) {
serviceMethod := ctx.Param("ServiceMethod")
if serviceMethod == "" {
handleError(ctx, "ServiceMethod not provided", http.StatusBadRequest)
return
}
bodyBytes := ctx.Request.Body()
queryMap := formatQueryParams(ctx)
for k, v := range queryMap {
if strings.HasPrefix(k, "p_") {
c = metainfo.WithPersistentValue(c, k, v)
} else {
c = metainfo.WithValue(c, k, v)
}
}
c = metainfo.WithBackwardValues(c)
jReq := string(bodyBytes)
jRsp, err := cli.GenericCall(c, serviceMethod, jReq)
if err != nil {
hlog.Errorf("GenericCall error: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": err.Error(),
})
return
}
result := make(map[string]interface{})
if err := json.Unmarshal([]byte(jRsp.(string)), &result); err != nil {
hlog.Errorf("Failed to unmarshal response body: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": "Failed to unmarshal response body",
})
return
}
m := metainfo.RecvAllBackwardValues(c)
for key, value := range m {
result[key] = value
}
respBody, err := json.Marshal(result)
if err != nil {
hlog.Errorf("Failed to marshal response body: %v", err)
ctx.JSON(500, map[string]interface{}{
"error": "Failed to marshal response body",
})
return
}
ctx.Data(http.StatusOK, "application/json", respBody)
})
}
func formatQueryParams(ctx *app.RequestContext) map[string]string {
var QueryParams = make(map[string]string)
ctx.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
QueryParams[string(key)] = string(value)
})
return QueryParams
}
func handleError(ctx *app.RequestContext, errMsg string, statusCode int) {
hlog.Errorf("Error: %s", errMsg)
ctx.JSON(statusCode, map[string]interface{}{
"error": errMsg,
})
}
`

const ServerTemplateRpcPb = `package swagger
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/bytedance/gopkg/cloud/metainfo"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/route"
"github.com/cloudwego/kitex/client"
"github.com/cloudwego/kitex/client/genericclient"
"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/generic"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/trans/detection"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"github.com/cloudwego/kitex/pkg/transmeta"
"github.com/cloudwego/kitex/transport"
"github.com/hertz-contrib/cors"
"github.com/hertz-contrib/swagger"
swaggerFiles "github.com/swaggo/files"
)
var (
//go:embed openapi.yaml
openapiYAML []byte
hertzEngine *route.Engine
httpReg = regexp.MustCompile("^(?:GET |POST|PUT|DELE|HEAD|OPTI|CONN|TRAC|PATC)$")
)
const (
kitexAddr = "{{.KitexAddr}}"
idlFile = "{{.IdlPath}}"
)
type MixTransHandlerFactory struct {
OriginFactory remote.ServerTransHandlerFactory
}
type transHandler struct {
remote.ServerTransHandler
}
func (t *transHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
t.ServerTransHandler.(remote.InvokeHandleFuncSetter).SetInvokeHandleFunc(inkHdlFunc)
}
func (m MixTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
if hertzEngine == nil {
StartServer()
}
var kitexOrigin remote.ServerTransHandler
var err error
if m.OriginFactory != nil {
kitexOrigin, err = m.OriginFactory.NewTransHandler(opt)
} else {
kitexOrigin, err = detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(opt)
}
if err != nil {
return nil, err
}
return &transHandler{ServerTransHandler: kitexOrigin}, nil
}
func (t *transHandler) OnRead(ctx context.Context, conn net.Conn) error {
c, ok := conn.(network.Conn)
if ok {
pre, _ := c.Peek(4)
if httpReg.Match(pre) {
klog.Info("using Hertz to process request")
err := hertzEngine.Serve(ctx, c)
if err != nil {
err = errors.New(fmt.Sprintf("HERTZ: %s", err.Error()))
}
return err
}
}
return t.ServerTransHandler.OnRead(ctx, conn)
}
func StartServer() {
h := server.Default()
h.Use(cors.Default())
cli := initializeGenericClient()
setupSwaggerRoutes(h)
setupProxyRoutes(h, cli)
hlog.Info("Swagger UI is available at: http://" + kitexAddr + "/swagger/index.html")
err := h.Engine.Init()
if err != nil {
panic(err)
}
hertzEngine = h.Engine
}
func findPbFile(fileName string) (string, error) {
workingDir, err := os.Getwd()
if err != nil {
return "", err
}
foundPath := ""
relativePath := fileName
err = filepath.Walk(workingDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
relative, err := filepath.Rel(workingDir, path)
if err != nil {
return err
}
if relative == relativePath {
foundPath = path
return filepath.SkipDir
}
}
return nil
})
if err == nil && foundPath != "" {
return foundPath, nil
}
parentDir := filepath.Dir(workingDir)
for parentDir != "/" && parentDir != "." && parentDir != workingDir {
filePath := filepath.Join(parentDir, fileName)
if _, err := os.Stat(filePath); err == nil {
return filePath, nil
}
workingDir = parentDir
parentDir = filepath.Dir(parentDir)
}
return "", errors.New("proto file not found: " + fileName)
}
func initializeGenericClient() genericclient.Client {
pbFile, err := findPbFile(idlFile)
if err != nil {
hlog.Fatal("Failed to locate Proto file:", err)
}
dOpts := proto.Options{}
p, err := generic.NewPbFileProviderWithDynamicGo(pbFile, context.Background(), dOpts)
if err != nil {
hlog.Fatal("Failed to create PbFileProvider:", err)
}
g, err := generic.JSONPbGeneric(p)
if err != nil {
hlog.Fatal("Failed to create JsonPbGeneric:", err)
}
var opts []client.Option
opts = append(opts, client.WithTransportProtocol(transport.TTHeader))
Expand Down
Loading

0 comments on commit c7a708e

Please sign in to comment.