Skip to content

Commit

Permalink
添加 OracleDB 插件及其 SQL 执行功能,更新 go.mod 和 go.sum 文件以包含相关依赖
Browse files Browse the repository at this point in the history
  • Loading branch information
zgxkbtl committed Dec 2, 2024
1 parent 1e972c9 commit 8d67644
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 5 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ require (
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/pterm/pterm v0.12.79 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/sijms/go-ora v1.3.2 // indirect
github.com/sijms/go-ora/v2 v2.8.22 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/crypto v0.24.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgY
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/sijms/go-ora v1.3.2 h1:v9Ca63acRbrE5vYlHpABzlOvt8bI1Sj5PCVDwaAJjp8=
github.com/sijms/go-ora v1.3.2/go.mod h1:ZGVmJgxUfyGIVmYgA7MVGEq6BX5aoFECRMtHW5DEcs4=
github.com/sijms/go-ora/v2 v2.8.22 h1:3ABgRzVKxS439cEgSLjFKutIwOyhnyi4oOSBywEdOlU=
github.com/sijms/go-ora/v2 v2.8.22/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
Expand Down
201 changes: 201 additions & 0 deletions pkg/plugins/oracledb_plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package plugins

import (
"context"
"database/sql"
"fmt"
"reflect"
"time"

"github.com/open-dingtalk/dingtalk-stream-sdk-go/payload"
"github.com/open-dingtalk/ipaas-agent/pkg/logger"
v1 "github.com/open-dingtalk/ipaas-agent/pkg/plugins/v1"
go_ora "github.com/sijms/go-ora/v2"
"github.com/spf13/viper"
)

type OracleDBPlugin struct {
Name string
AllowRemote bool
Configs []Body
}

func (p *OracleDBPlugin) GetConnection(body *Body) (*sql.DB, error) {
var urlOptions map[string]string
if body.SID != "" {
urlOptions = map[string]string{
"SID": body.SID,
}
}
connString := go_ora.BuildUrl(
body.Host, int(body.Port), body.ServiceName, body.User, body.Password, urlOptions,
)

return sql.Open("oracle", connString)
}

// doSQLExecute 执行SQL查询
func (p *OracleDBPlugin) DoSQLExecute(body *Body) (qr *QueryResult) {
startTime := time.Now()
defer func() {
logger.Log1.WithField("cost", time.Since(startTime).String()).Infof("SQL查询结束")
}()

// 获取数据库连接
db, err := p.GetConnection(body)
if err != nil {
return &QueryResult{
Result: nil,
Columns: nil,
Message: err.Error(),
}
}
defer db.Close()

// Sleep for 10000ms to simulate processing time
// time.Sleep(10000 * time.Millisecond)

rows, err := db.Query(body.SQL)
if err != nil {
return &QueryResult{
Result: nil,
Columns: nil,
Message: err.Error(),
}
}
defer rows.Close()

// 获取列名
columns, err := rows.Columns()
if err != nil {
return &QueryResult{
Result: nil,
Columns: nil,
Message: err.Error(),
}
}

// 获取列的类型信息
columnTypes, err := rows.ColumnTypes()
if err != nil {
return &QueryResult{
Result: nil,
Columns: nil,
Message: err.Error(),
}
}

// 准备结果集
var result []map[string]interface{}

// 扫描每一行
for rows.Next() {
// 创建一个切片,用于存储一行的值
values := make([]interface{}, len(columns))
for i, colType := range columnTypes {
// 根据列的扫描类型创建对应的变量
values[i] = reflect.New(colType.ScanType()).Interface()
}

// 扫描行数据
err := rows.Scan(values...)
if err != nil {
continue
}

// 将行数据转换为 map
row := make(map[string]interface{})
for i, col := range columns {
// 处理指针类型,获取实际的值
val := values[i]
if bv, ok := val.(*interface{}); ok {
row[col] = *bv
} else {
row[col] = val
}
}
result = append(result, row)
}

return &QueryResult{
Result: result,
Columns: columns,
Message: "success",
}
}

func (p *OracleDBPlugin) findConfigByKey(key string) *Body {
for _, config := range p.Configs {
if config.ConfigKey == key {
logger.Log1.WithField("config", config).Info("找到配置")
return &config
}
}
return nil
}

func NewOracleDBPlugin() *OracleDBPlugin {
return &OracleDBPlugin{
Name: "oracledb_plugin",
}
}

func (p *OracleDBPlugin) Init() error {
// 定义一个变量来存储 SQL 配置
var sqlConfigs []Body

// 解析 SQL 配置
if err := viper.UnmarshalKey("plugins.oracledb", &sqlConfigs); err != nil {
logger.Log1.Fatalf("解析 oracle 数据库配置出错: %v", err)
}

p.Configs = sqlConfigs

p.AllowRemote = viper.GetBool("auth.oracledb.allow_remote")

logger.Log1.
WithField("插件名", p.Name).
WithField("配置列表", p.Configs).
WithField("允许远程配置", p.AllowRemote).
Info("插件已初始化")
return nil
}

func (p *OracleDBPlugin) HandleMessage(ctx context.Context, df *v1.DFWrap) (*payload.DataFrameResponse, error) {
// 初始化 Data
data, err := df.GetPluginDataWithType(reflect.TypeOf(Body{}))

if err != nil {
return payload.NewErrorDataFrameResponse(err), err
}

remoteConf := data.(*Body)
if remoteConf.ConfigKey == "" && p.AllowRemote {
logger.Log1.WithField("config", remoteConf).Info("使用远程配置")
} else {
localConf := p.findConfigByKey(remoteConf.ConfigKey)
if localConf == nil {
logger.Log1.WithField("configKey", remoteConf.ConfigKey).
WithField("是否允许远程配置", p.AllowRemote).
Error("未找到配置或不允许远程配置")
return payload.NewErrorDataFrameResponse(fmt.Errorf("未找到配置或不允许远程配置: %s", remoteConf.ConfigKey)), nil
}
remoteConf.completeFrom(localConf)
}

callBackResponse := &CallbackResponse{
Response: p.DoSQLExecute(remoteConf),
}

resp := payload.NewSuccessDataFrameResponse()

resp.SetJson(callBackResponse)

return resp, nil
}

func (p *OracleDBPlugin) Close() error {
// 关闭插件
logger.Log1.WithField("plugin", p.Name).Info("插件已关闭")
return nil
}
27 changes: 22 additions & 5 deletions pkg/plugins/pgsql_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,44 @@ func (p *PGSQLPlugin) DoSQLExecute(body *Body) (qr *QueryResult) {
}
}

// 获取列的类型信息
columnTypes, err := rows.ColumnTypes()
if err != nil {
return &QueryResult{
Result: nil,
Columns: nil,
Message: err.Error(),
}
}

// 准备结果集
var result []map[string]interface{}

// 扫描每一行
for rows.Next() {
// 创建一个切片,用于存储一行的值
values := make([]interface{}, len(columns))
for i := range values {
values[i] = new(interface{})
for i, colType := range columnTypes {
// 根据列的扫描类型创建对应的变量
values[i] = reflect.New(colType.ScanType()).Interface()
}

// 扫描行数据
err := rows.Scan(values...)
if err != nil {
continue
}

// 将行数据转换为map
// 将行数据转换为 map
row := make(map[string]interface{})
for i, col := range columns {
val := values[i].(*interface{})
row[col] = *val
// 处理指针类型,获取实际的值
val := values[i]
if bv, ok := val.(*interface{}); ok {
row[col] = *bv
} else {
row[col] = val
}
}
result = append(result, row)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/plugins/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ func (pm *PluginManager) LoadPlugins() error {
}
pm.RegisterPlugin(pgsqlPlugin.Name, pgsqlPlugin)

// 6. oracle db 插件
oracleDBPlugin := NewOracleDBPlugin()
err = oracleDBPlugin.Init()
if err != nil {
logger.Log1.Errorf("初始化 OracleDB 插件失败: %v", err)
}
pm.RegisterPlugin(oracleDBPlugin.Name, oracleDBPlugin)

return nil
}

Expand Down
9 changes: 9 additions & 0 deletions pkg/plugins/sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func (fi *FlexInt) UnmarshalJSON(data []byte) error {
}

// 将字符串转换为整数
if s == "" {
return nil
}
v, err := strconv.Atoi(s)
if err != nil {
return err
Expand All @@ -45,6 +48,9 @@ type Body struct {
Password string `json:"password,omitempty" mapstructure:"password,omitempty"`
Database string `json:"database,omitempty" mapstructure:"database,omitempty"`
SQL string `json:"sql,omitempty" mapstructure:"sql,omitempty"`
// 以下字段 Oracle DB 专用
ServiceName string `json:"service_name,omitempty" mapstructure:"service_name,omitempty"`
SID string `json:"sid,omitempty" mapstructure:"sid,omitempty"`
// 以下字段用于本地网关配置
Address string `json:"address,omitempty" mapstructure:"address,omitempty"`
ConfigKey string `json:"config_key,omitempty" mapstructure:"config_key,omitempty"`
Expand All @@ -64,6 +70,9 @@ func (b *Body) completeFrom(other *Body) {
b.Address = other.Address
b.ConfigKey = other.ConfigKey
b.ConnString = other.ConnString
// Oracle DB 专用
b.ServiceName = other.ServiceName
b.SID = other.SID
}

// QueryResult 结构体定义查询结果
Expand Down
36 changes: 36 additions & 0 deletions pkg/plugins/sql_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,39 @@ func TestMYSQLPlugin_doSQLExecute(t *testing.T) {

require.Equal(t, "success", qr.Message)
}

func TestORACLEDBPlugin_doSQLExecute(t *testing.T) {
// 创建一个 MSSQL 插件
p := &plugin.OracleDBPlugin{
Name: "",
AllowRemote: true,
}
// 创建一个 Body
body := &plugin.Body{
Host: "localhost",
Port: 1521,
User: "system",
Password: "example",
SID: "FREE",
SQL: "SELECT * FROM HELP WHERE ROWNUM <= 10",
}
// 执行 SQL 查询
qr := p.DoSQLExecute(body)
// 断言结果
require.NotNil(t, qr)
require.NotNil(t, qr.Result)
require.NotNil(t, qr.Columns)
// 打印到控制台
for _, row := range qr.Result {
for key, col := range row {
switch v := col.(type) {
case []byte:
t.Logf("%s: %s", key, string(v))
default:
t.Logf("%s: %v", key, v)
}
}
}

require.Equal(t, "success", qr.Message)
}

0 comments on commit 8d67644

Please sign in to comment.