From 75e1defd6cc2f30ec2a4fb61c334dd2c23bf04d4 Mon Sep 17 00:00:00 2001 From: sakura <32236524+yuemingming@users.noreply.github.com> Date: Tue, 27 Aug 2024 19:24:39 +0800 Subject: [PATCH] feat(plugin) add ai-history plugin (#1224) --- .../wasm-go/extensions/ai-history/.buildrc | 1 + .../wasm-go/extensions/ai-history/.gitignore | 19 + .../wasm-go/extensions/ai-history/README.md | 172 +++++++ plugins/wasm-go/extensions/ai-history/go.mod | 22 + plugins/wasm-go/extensions/ai-history/go.sum | 20 + plugins/wasm-go/extensions/ai-history/main.go | 480 ++++++++++++++++++ .../extensions/ai-history/main_test.go | 36 ++ .../wasm-go/extensions/ai-history/option.yaml | 52 ++ 8 files changed, 802 insertions(+) create mode 100644 plugins/wasm-go/extensions/ai-history/.buildrc create mode 100644 plugins/wasm-go/extensions/ai-history/.gitignore create mode 100644 plugins/wasm-go/extensions/ai-history/README.md create mode 100644 plugins/wasm-go/extensions/ai-history/go.mod create mode 100644 plugins/wasm-go/extensions/ai-history/go.sum create mode 100644 plugins/wasm-go/extensions/ai-history/main.go create mode 100644 plugins/wasm-go/extensions/ai-history/main_test.go create mode 100644 plugins/wasm-go/extensions/ai-history/option.yaml diff --git a/plugins/wasm-go/extensions/ai-history/.buildrc b/plugins/wasm-go/extensions/ai-history/.buildrc new file mode 100644 index 0000000000..f76a2883ac --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/.buildrc @@ -0,0 +1 @@ +EXTRA_TAGS=proxy_wasm_version_0_2_100 \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-history/.gitignore b/plugins/wasm-go/extensions/ai-history/.gitignore new file mode 100644 index 0000000000..b9340139b9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/.gitignore @@ -0,0 +1,19 @@ +# File generated by hgctl. Modify as required. + +* + +!/.gitignore +!/.buildrc +!*.go +!go.sum +!go.mod + +!LICENSE +!*.md +!*.yaml +!*.yml + +!*/ + +/out +/test diff --git a/plugins/wasm-go/extensions/ai-history/README.md b/plugins/wasm-go/extensions/ai-history/README.md new file mode 100644 index 0000000000..8d1e6a3a1d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/README.md @@ -0,0 +1,172 @@ +--- +title: AI 历史对话 +keywords: [ AI网关, AI历史对话 ] +description: AI 历史对话插件配置参考 +--- + +## 功能说明 + +`AI 历史对话` 基于请求头实现用户身份识别,并自动缓存对应用户的历史对话,且在后续对话中自动填充到上下文。同时支持用户主动查询历史对话。 +**Note** + +> 需要数据面的proxy wasm版本大于等于0.2.100 + +> 编译时,需要带上版本的tag,例如: +`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./` + +> 路径后缀匹配 `ai-history/query` 时,会返回历史对话 + +## 配置字段 + +| 名称 | 数据类型 | 填写要求 | 默认值 | Description | +|-------------------|---------|----------|-----------------------|---------------------------------------------------------------------------| +| identityHeader | string | optional | "Authorization" | 身份解析对应的请求头,可用 Authorization,X-Mse-Consumer等 | +| fillHistoryCnt | integer | optional | 3 | 默认填充历史对话轮次 | +| cacheKeyPrefix | string | optional | "higress-ai-history:" | Redis缓存Key的前缀 | +| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 | +| redis.serviceName | string | required | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local | +| redis.servicePort | integer | optional | 6379 | redis 服务端口 | +| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 | +| redis.username | string | optional | - | 登陆 redis 的用户名 | +| redis.password | string | optional | - | 登陆 redis 的密码 | + +## 用法示例 + +### 配置信息 + +```yaml +redis: + serviceName: my-redis.dns + timeout: 2000 +``` + +### 请求示例 + +**自动填充请求示例:** + +第一轮请求: + +``` + curl 'http://example.com/api/openai/v1/chat/completions?fill_history_cnt=3' \ + -H 'Accept: application/json, text/event-stream' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-Nzf7RtkdS4s0zFyn5575124129254d9bAf9473A5D7D06dD3' + --data-raw '{"model":"qwen-long","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[ + { + "role": "user", + "content": "Higress 可以替换 Nginx 吗?" + } + ],"presence_penalty":0,"temperature":0.7,"top_p":0.95}' +``` + +请求填充之后: +> 第一轮请求,无填充。和原始请求一致。 + +第一轮响应: + +```json +{ + "id": "02f4c621-820e-97d4-a905-1e3d0d8f59c6", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Higress 和 Nginx 虽然都有作为网关的功能,但它们的设计理念和应用场景有所不同。Nginx 更多是作为一个高性能的 HTTP 和反向代理服务器被大家熟知,而 Higress 是一个云原生网关,除了基础的路由转发能力外,还集成了服务网格、可观测性、安全管理等众多云原生特性。\n\n因此,如果你想在云原生环境中部署应用,并且希望获得现代应用所需的高级功能,比如服务治理、灰度发布、熔断限流、安全认证等功能,那么 Higress 可以作为一个很好的 Nginx 替代方案。但如果是较为简单的静态网站或者仅需要基本的反向代理功能,传统的 Nginx 配置可能会更为简单直接。" + }, + "finish_reason": "stop" + } + ], + "created": 1724077770, + "model": "qwen-long", + "object": "chat.completion", + "usage": { + "prompt_tokens": 7316, + "completion_tokens": 164, + "total_tokens": 7480 + } +} +``` + +第二轮请求: + +``` + curl 'http://example.com/api/openai/v1/chat/completions?fill_history_cnt=3' \ + -H 'Accept: application/json, text/event-stream' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-Nzf7RtkdS4s0zFyn5575124129254d9bAf9473A5D7D06dD3' + --data-raw '{"model":"qwen-long","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[ + { + "role": "user", + "content": "Spring Cloud GateWay 呢?" + } + ],"presence_penalty":0,"temperature":0.7,"top_p":0.95}' +``` + +请求填充之后: +> 第二轮请求,自动填充上一轮的历史对话。 + +``` + curl 'http://example.com/api/openai/v1/chat/completions?fill_history_cnt=3' \ + -H 'Accept: application/json, text/event-stream' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-Nzf7RtkdS4s0zFyn5575124129254d9bAf9473A5D7D06dD3' + --data-raw '{"model":"qwen-long","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[ + { + "role": "user", + "content": "Higress 可以替换 Nginx 吗?" + }, + { + "role": "assistant", + "content": "Higress 和 Nginx 虽然都有作为网关的功能,但它们的设计理念和应用场景有所不同。Nginx 更多是作为一个高性能的 HTTP 和反向代理服务器被大家熟知,而 Higress 是一个云原生网关,除了基础的路由转发能力外,还集成了服务网格、可观测性、安全管理等众多云原生特性。\n\n因此,如果你想在云原生环境中部署应用,并且希望获得现代应用所需的高级功能,比如服务治理、灰度发布、熔断限流、安全认证等功能,那么 Higress 可以作为一个很好的 Nginx 替代方案。但如果是较为简单的静态网站或者仅需要基本的反向代理功能,传统的 Nginx 配置可能会更为简单直接。" + }, + { + "role": "user", + "content": "Spring Cloud GateWay 呢?" + } + ],"presence_penalty":0,"temperature":0.7,"top_p":0.95}' +``` + +每轮请求只需要带上当前问题,以及当前需要填充的历史对话轮数,即可自动完成历史对话填充。 + +**获取历史数据示例:** + +``` +curl 'http://example.com/api/openai/v1/chat/completions/ai-history/query?cnt=3' \ + -H 'Accept: application/json, text/event-stream' \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer sk-Nzf7RtkdS4s0zFyn5575124129254d9bAf9473A5D7D06dD3' +``` + +响应示例: + +```json +[ + { + "role": "user", + "content": "Higress 可以替换 Nginx 吗?" + }, + { + "role": "assistant", + "content": "Higress 和 Nginx 虽然都有作为网关的功能,但它们的设计理念和应用场景有所不同。Nginx 更多是作为一个高性能的 HTTP 和反向代理服务器被大家熟知,而 Higress 是一个云原生网关,除了基础的路由转发能力外,还集成了服务网格、可观测性、安全管理等众多云原生特性。\\n\\n因此,如果你想在云原生环境中部署应用,并且希望获得现代应用所需的高级功能,比如服务治理、灰度发布、熔断限流、安全认证等功能,那么 Higress 可以作为一个很好的 Nginx 替代方案。但如果是较为简单的静态网站或者仅需要基本的反向代理功能,传统的 Nginx 配置可能会更为简单直接。" + }, + { + "role": "user", + "content": "SpringCloud GateWay 呢?" + }, + { + "role": "assistant", + "content": "与 Spring Cloud Gateway 相比,Higress 也是一个 API 网关,但它们之间存在一些关键的区别:\\n\\n- **设计理念**:Spring Cloud Gateway 主要针对微服务架构中的服务间通信和路由,它作为 Spring Cloud 生态系统的一部分,更加专注于 Java 开发者的微服务场景。而 Higress 作为云原生网关,不仅关注服务间的通信,还提供了一系列云原生功能,如服务网格、可观测性、安全管理等。\\n- **部署方式**:Spring Cloud Gateway 通常作为微服务应用的一部分运行在应用服务器内,而 Higress 通常以独立的微服务或者容器化服务的形式部署在 Kubernetes 环境中,适用于现代云原生部署模型。\\n- **扩展性和集成**:Higress 提供了更广泛的集成和支持,例如与 Istio、Kubernetes 等生态系统的深度集成,这使得它可以更好地适应复杂的云原生环境。\\n\\n因此,如果你的应用程序是基于 Spring Cloud 构建的,并且你想要一个轻量级的、易于集成的服务网关,那么 Spring Cloud Gateway 可能是一个合适的选择。但是,如果你正在构建或重构云原生应用,并且需要更强大的路由规则、服务治理、可观测性等功能,那么 Higress 将是一个更好的选择。" + }, + { + "role": "user", + "content": "Higress 可以替换 Nginx 吗?" + }, + { + "role": "assistant", + "content": "Higress 和 Nginx 虽然都有作为网关的功能,但它们的设计理念和应用场景有所不同。Nginx 更多是作为一个高性能的 HTTP 和反向代理服务器被大家熟知,而 Higress 是一个云原生网关,除了基础的路由转发能力外,还集成了服务网格、可观测性、安全管理等众多云原生特性。\\n\\n因此,如果你想在云原生环境中部署应用,并且希望获得现代应用所需的高级功能,比如服务治理、灰度发布、熔断限流、安全认证等功能,那么 Higress 可以作为一个很好的 Nginx 替代方案。但如果是较为简单的静态网站或者仅需要基本的反向代理功能,传统的 Nginx 配置可能会更为简单直接。" + } +] +``` + +返回三个历史对话,如果未传入 cnt 默认返回所有缓存历史对话。 \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-history/go.mod b/plugins/wasm-go/extensions/ai-history/go.mod new file mode 100644 index 0000000000..7b1c337098 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/go.mod @@ -0,0 +1,22 @@ +// File generated by hgctl. Modify as required. + +module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-history + +go 1.19 + +replace github.com/alibaba/higress/plugins/wasm-go => ../.. + +require ( + github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f + github.com/tidwall/gjson v1.14.3 + github.com/tidwall/resp v0.1.1 +) + +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect + github.com/magefile/mage v1.14.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect +) diff --git a/plugins/wasm-go/extensions/ai-history/go.sum b/plugins/wasm-go/extensions/ai-history/go.sum new file mode 100644 index 0000000000..f473e12b2d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/go.sum @@ -0,0 +1,20 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= +github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-history/main.go b/plugins/wasm-go/extensions/ai-history/main.go new file mode 100644 index 0000000000..512e13f1c6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/main.go @@ -0,0 +1,480 @@ +// File generated by hgctl. Modify as required. +// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 + +package main + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" + "github.com/tidwall/resp" +) + +const ( + QuestionContextKey = "question" + AnswerContentContextKey = "answer" + PartialMessageContextKey = "partialMessage" + ToolCallsContextKey = "toolCalls" + StreamContextKey = "stream" + DefaultCacheKeyPrefix = "higress-ai-history:" + IdentityKey = "identity" + ChatHistories = "chatHistories" +) + +func main() { + wrapper.SetCtx( + "ai-history", + wrapper.ParseConfigBy(parseConfig), + wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), + wrapper.ProcessRequestBodyBy(onHttpRequestBody), + wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), + wrapper.ProcessStreamingResponseBodyBy(onHttpStreamResponseBody), + ) +} + +// @Name ai-history +// @Category protocol +// @Phase AUTHN +// @Priority 10 +// @Title zh-CN AI History +// @Description zh-CN 大模型对话历史缓存 +// @IconUrl +// @Version 0.1.0 +// +// @Contact.name sakura +// @Contact.url +// @Contact.email +// +// @Example +// redis: +// serviceName: my-redis.dns +// timeout: 2000 +// +// @End + +type RedisInfo struct { + // @Title zh-CN redis 服务名称 + // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local + ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` + // @Title zh-CN redis 服务端口 + // @Description zh-CN 默认值为6379 + ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` + // @Title zh-CN 用户名 + // @Description zh-CN 登陆 redis 的用户名,非必填 + Username string `required:"false" yaml:"username" json:"username"` + // @Title zh-CN 密码 + // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 + Password string `required:"false" yaml:"password" json:"password"` + // @Title zh-CN 请求超时 + // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 + Timeout int `required:"false" yaml:"timeout" json:"timeout"` +} + +type KVExtractor struct { + // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` + // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` +} + +type PluginConfig struct { + // @Title zh-CN Redis 地址信息 + // @Description zh-CN 用于存储缓存结果的 Redis 地址 + RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` + // @Title zh-CN 缓存 key 的来源 + // @Description zh-CN 往 redis 里存时,问题的提取方式 + QuestionFrom KVExtractor `required:"true" yaml:"questionFrom" json:"questionFrom"` + // @Title zh-CN 缓存 value 的来源 + // @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式 + AnswerValueFrom KVExtractor `required:"true" yaml:"answerValueFrom" json:"answerValueFrom"` + // @Title zh-CN 流式响应下,缓存 value 的来源 + // @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式 + AnswerStreamValueFrom KVExtractor `required:"true" yaml:"answerStreamValueFrom" json:"answerStreamValueFrom"` + // @Title zh-CN Redis缓存Key的前缀 + // @Description zh-CN 默认值是"higress-ai-cache:" + CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` + // @Title zh-CN 身份解析方式 + // @Description zh-CN 默认值是"Authorization" + IdentityHeader string `required:"false" yaml:"identityHeader" json:"identityHeader"` + // @Title zh-CN 默认填充历史对话轮数 + // @Description zh-CN 默认值是 3 + FillHistoryCnt int `required:"false" yaml:"fillHistoryCnt" json:"fillHistoryCnt"` + // @Title zh-CN 缓存的过期时间 + // @Description zh-CN 单位是秒,默认值为0,即永不过期 + CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` + redisClient wrapper.RedisClient `yaml:"-" json:"-"` +} + +type ChatHistory struct { + Role string `json:"role"` + Content string `json:"content"` +} + +func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { + c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() + if c.RedisInfo.ServiceName == "" { + return errors.New("redis service name must not be empty") + } + c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) + if c.RedisInfo.ServicePort == 0 { + if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { + // use default logic port which is 80 for static service + c.RedisInfo.ServicePort = 80 + } else { + c.RedisInfo.ServicePort = 6379 + } + } + c.RedisInfo.Username = json.Get("redis.username").String() + c.RedisInfo.Password = json.Get("redis.password").String() + c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) + if c.RedisInfo.Timeout == 0 { + c.RedisInfo.Timeout = 1000 + } + c.QuestionFrom.RequestBody = "messages.@reverse.0.content" + c.AnswerValueFrom.ResponseBody = "choices.0.message.content" + c.AnswerStreamValueFrom.ResponseBody = "choices.0.delta.content" + + c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() + if c.CacheKeyPrefix == "" { + c.CacheKeyPrefix = DefaultCacheKeyPrefix + } + c.IdentityHeader = json.Get("identityHeader").String() + if c.IdentityHeader == "" { + c.IdentityHeader = "Authorization" + } + c.FillHistoryCnt = int(json.Get("fillHistoryCnt").Int()) + if c.FillHistoryCnt == 0 { + c.FillHistoryCnt = 3 + } + c.CacheTTL = int(json.Get("cacheTTL").Int()) + c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: c.RedisInfo.ServiceName, + Port: int64(c.RedisInfo.ServicePort), + }) + return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { + contentType, _ := proxywasm.GetHttpRequestHeader("content-type") + if !strings.Contains(contentType, "application/json") { + log.Warnf("content is not json, can't process:%s", contentType) + ctx.DontReadRequestBody() + return types.ActionContinue + } + // get identity key + identityKey, _ := proxywasm.GetHttpRequestHeader(config.IdentityHeader) + if identityKey == "" { + log.Warnf("identity key is empty") + return types.ActionContinue + } + identityKey = strings.ReplaceAll(identityKey, " ", "") + ctx.SetContext(IdentityKey, identityKey) + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + // The request has a body and requires delaying the header transmission until a cache miss occurs, + // at which point the header should be sent. + return types.HeaderStopIteration +} + +func TrimQuote(source string) string { + return strings.Trim(source, `"`) +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { + bodyJson := gjson.ParseBytes(body) + if bodyJson.Get("stream").Bool() { + ctx.SetContext(StreamContextKey, struct{}{}) + } + identityKey := ctx.GetStringContext(IdentityKey, "") + err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { + if err := response.Error(); err != nil { + log.Errorf("redis get failed, err:%v", err) + _ = proxywasm.ResumeHttpRequest() + return + } + if response.IsNull() { + log.Debugf("cache miss, identityKey:%s", identityKey) + _ = proxywasm.ResumeHttpRequest() + return + } + chatHistories := response.String() + ctx.SetContext(ChatHistories, chatHistories) + var chat []ChatHistory + err := json.Unmarshal([]byte(chatHistories), &chat) + if err != nil { + log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err) + _ = proxywasm.ResumeHttpRequest() + return + } + path := ctx.Path() + if isQueryHistory(path) { + cnt := getIntQueryParameter("cnt", path, len(chat)/2) * 2 + if cnt > len(chat) { + cnt = len(chat) + } + chat = chat[len(chat)-cnt:] + res, err := json.Marshal(chat) + if err != nil { + log.Errorf("marshal chat:%v failed, err:%v", chat, err) + _ = proxywasm.ResumeHttpRequest() + return + } + _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) + return + } + question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) + if question == "" { + log.Debug("parse question from request body failed") + _ = proxywasm.ResumeHttpRequest() + return + } + ctx.SetContext(QuestionContextKey, question) + fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 + currJson := bodyJson.Get("messages").String() + var currMessage []ChatHistory + err = json.Unmarshal([]byte(currJson), &currMessage) + if err != nil { + log.Errorf("unmarshal currMessage:%s failed, err:%v", currJson, err) + _ = proxywasm.ResumeHttpRequest() + return + } + finalChat := fillHistory(chat, currMessage, fillHistoryCnt) + var parameter map[string]any + err = json.Unmarshal(body, ¶meter) + if err != nil { + log.Errorf("unmarshal body:%s failed, err:%v", body, err) + _ = proxywasm.ResumeHttpRequest() + return + } + parameter["messages"] = finalChat + parameterJson, err := json.Marshal(parameter) + if err != nil { + log.Errorf("marshal parameter:%v failed, err:%v", parameter, err) + _ = proxywasm.ResumeHttpRequest() + return + } + log.Infof("start to replace request body, parameter:%s", string(parameterJson)) + _ = proxywasm.ReplaceHttpRequestBody(parameterJson) + _ = proxywasm.ResumeHttpRequest() + }) + if err != nil { + log.Error("redis access failed") + return types.ActionContinue + } + return types.ActionPause +} + +func fillHistory(chat []ChatHistory, currMessage []ChatHistory, fillHistoryCnt int) []ChatHistory { + userInputCnt := 0 + for i := 0; i < len(currMessage); i++ { + if currMessage[i].Role == "user" { + userInputCnt++ + } + } + if userInputCnt > 1 { + return currMessage + } + if fillHistoryCnt > len(chat) { + fillHistoryCnt = len(chat) + } + finalChat := append(chat[len(chat)-fillHistoryCnt:], currMessage...) + return finalChat +} + +func isQueryHistory(path string) bool { + return strings.Contains(path, "ai-history/query") +} + +func getIntQueryParameter(name string, path string, defaultValue int) int { + // 解析 URL + parsedURL, err := url.ParseRequestURI(path) + if err != nil { + fmt.Println("Error parsing URL:", err) + return defaultValue + } + + // 获取查询参数 + values := parsedURL.Query() + + // 获取特定的查询参数 "defaultValue" + queryStr := values.Get(name) + if queryStr == "" { + return defaultValue + } + num, err := strconv.Atoi(queryStr) + if err != nil { + return defaultValue + } + return num +} + +func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { + subMessages := strings.Split(sseMessage, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + log.Errorf("invalid message:%s", message) + return "" + } + // skip the prefix "data:" + bodyJson := message[5:] + if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI == nil { + content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + ctx.SetContext(AnswerContentContextKey, content) + return content + } + append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + content := tempContentI.(string) + append + ctx.SetContext(AnswerContentContextKey, content) + return content + } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { + // TODO: compatible with other providers + ctx.SetContext(ToolCallsContextKey, struct{}{}) + return "" + } + log.Debugf("unknown message:%s", bodyJson) + return "" +} + +func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + if strings.Contains(contentType, "text/event-stream") { + ctx.SetContext(StreamContextKey, struct{}{}) + } + return types.ActionContinue +} +func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + if ctx.GetContext(ToolCallsContextKey) != nil { + // we should not cache tool call result + return chunk + } + questionI := ctx.GetContext(QuestionContextKey) + if questionI == nil { + return chunk + } + if isQueryHistory(ctx.Path()) { + return chunk + } + if !isLastChunk { + stream := ctx.GetContext(StreamContextKey) + if stream == nil { + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI == nil { + ctx.SetContext(AnswerContentContextKey, chunk) + return chunk + } + tempContent := tempContentI.([]byte) + tempContent = append(tempContent, chunk...) + ctx.SetContext(AnswerContentContextKey, tempContent) + } else { + var partialMessage []byte + partialMessageI := ctx.GetContext(PartialMessageContextKey) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + for i, msg := range messages { + if i < len(messages)-1 { + // process complete message + processSSEMessage(ctx, config, msg, log) + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PartialMessageContextKey, nil) + } + } + return chunk + } + + stream := ctx.GetContext(StreamContextKey) + var value string + if stream == nil { + var body []byte + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI != nil { + body = append(tempContentI.([]byte), chunk...) + } else { + body = chunk + } + bodyJson := gjson.ParseBytes(body) + + value = TrimQuote(bodyJson.Get(config.AnswerValueFrom.ResponseBody).Raw) + if value == "" { + log.Warnf("parse value from response body failded, body:%s", body) + return chunk + } + } else { + if len(chunk) > 0 { + var lastMessage []byte + partialMessageI := ctx.GetContext(PartialMessageContextKey) + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + if !strings.HasSuffix(string(lastMessage), "\n\n") { + log.Warnf("invalid lastMessage:%s", lastMessage) + return chunk + } + // remove the last \n\n + lastMessage = lastMessage[:len(lastMessage)-2] + value = processSSEMessage(ctx, config, string(lastMessage), log) + } else { + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI == nil { + return chunk + } + value = tempContentI.(string) + } + } + saveChatHistory(ctx, config, questionI, value, log) + return chunk +} + +func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log wrapper.Log) { + question := questionI.(string) + identityKey := ctx.GetStringContext(IdentityKey, "") + var chat []ChatHistory + chatHistories := ctx.GetStringContext(ChatHistories, "") + if chatHistories != "" { + err := json.Unmarshal([]byte(chatHistories), &chat) + if err != nil { + log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err) + return + } + } + chat = append(chat, ChatHistory{Role: "user", Content: question}) + chat = append(chat, ChatHistory{Role: "assistant", Content: value}) + if len(chat) > config.FillHistoryCnt*2 { + chat = chat[len(chat)-config.FillHistoryCnt*2:] + } + str, err := json.Marshal(chat) + if err != nil { + log.Errorf("marshal chat:%v failed, err:%v", chat, err) + return + } + log.Infof("start to Set history, identityKey:%s, chat:%s", identityKey, string(str)) + _ = config.redisClient.Set(config.CacheKeyPrefix+identityKey, string(str), nil) + if config.CacheTTL != 0 { + _ = config.redisClient.Expire(config.CacheKeyPrefix+identityKey, config.CacheTTL, nil) + } +} diff --git a/plugins/wasm-go/extensions/ai-history/main_test.go b/plugins/wasm-go/extensions/ai-history/main_test.go new file mode 100644 index 0000000000..400a924907 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/main_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "reflect" + "testing" +) + +func TestDistinctChat(t *testing.T) { + type args struct { + chat []ChatHistory + currMessage []ChatHistory + } + firstChat := []ChatHistory{{Role: "user", Content: "userInput1"}, {Role: "assistant", Content: "assistantOutput1"}} + sendUser := []ChatHistory{{Role: "user", Content: "userInput2"}} + tests := []struct { + name string + args args + want []ChatHistory + }{ + {name: "填充历史", args: args{ + chat: append([]ChatHistory{}, firstChat...), + currMessage: append([]ChatHistory{}, sendUser...)}, + want: append(append([]ChatHistory{}, firstChat...), sendUser...)}, + {name: "无需填充", args: args{ + chat: append([]ChatHistory{}, firstChat...), + currMessage: append(append([]ChatHistory{}, firstChat...), sendUser...)}, + want: append(append([]ChatHistory{}, firstChat...), sendUser...)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := fillHistory(tt.args.chat, tt.args.currMessage, 3); !reflect.DeepEqual(got, tt.want) { + t.Errorf("fillHistory() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugins/wasm-go/extensions/ai-history/option.yaml b/plugins/wasm-go/extensions/ai-history/option.yaml new file mode 100644 index 0000000000..8be6bfd1cf --- /dev/null +++ b/plugins/wasm-go/extensions/ai-history/option.yaml @@ -0,0 +1,52 @@ +# File generated by hgctl. Modify as required. + +version: 1.0.0 + +build: + # The official builder image version + builder: + go: 1.19 + tinygo: 0.28.1 + oras: 1.0.0 + # The WASM plugin project directory + input: ./ + # The output of the build products + output: + # Choose between 'files' and 'image' + type: files + # Destination address: when type=files, specify the local directory path, e.g., './out' or + # type=image, specify the remote docker repository, e.g., 'docker.io//' + dest: ./out + # The authentication configuration for pushing image to the docker repository + docker-auth: ~/.docker/config.json + # The directory for the WASM plugin configuration structure + model-dir: ./ + # The WASM plugin configuration structure name + model: PluginConfig + # Enable debug mode + debug: false + +test: + # Test environment name, that is a docker compose project name + name: wasm-test + # The output path to build products, that is the source of test configuration parameters + from-path: ./out + # The test configuration source + test-path: ./test + # Docker compose configuration, which is empty, looks for the following files from 'test-path': + # compose.yaml, compose.yml, docker-compose.yml, docker-compose.yaml + compose-file: + # Detached mode: Run containers in the background + detach: false + +install: + # The namespace of the installation + namespace: higress-system + # Use to validate WASM plugin configuration when install by yaml + spec-yaml: ./out/spec.yaml + # Installation source. Choose between 'from-yaml' and 'from-go-project' + from-yaml: ./test/plugin-conf.yaml + # If 'from-go-src' is non-empty, the output type of the build option must be 'image' + from-go-src: + # Enable debug mode + debug: false