Skip to content

Commit

Permalink
add chat model support to README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyberhan123 committed Oct 22, 2023
1 parent f79f19b commit fc0774c
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 5 deletions.
64 changes: 63 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func main() {
}

```
Now GPU is supported!! you can use `NewRwkvAutoModel` to set `GpuEnable`. see `AutoModel Compatibility` about gpu support.
Now GPU is supported.you can use `NewRwkvAutoModel` to set `GpuEnable`. see `AutoModel Compatibility` about gpu support.

```go
package main
Expand Down Expand Up @@ -222,6 +222,68 @@ func main() {

```

### Chat With Bot
This library can be used to chat with bot.

```go

package main
import (
"fmt"
"github.com/seasonjs/rwkv"
)
func main() {
rwkv, err := NewRwkvAutoModel(RwkvOptions{
MaxTokens: 500,
StopString: "\\n\\n",
Temperature: 0.8,
TopP: 0.5,
TokenizerType: World, //or World
PrintError: true,
CpuThreads: 10,
GpuEnable: true,
})

if err != nil {
fmt.Print(err.Error())
return
}

defer rwkv.Close()

err = rwkv.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin")
if err != nil {
fmt.Print(err.Error())
return
}
prompt := "\\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob." +
" Alice is very intelligent, creative and friendly." +
" Alice likes to tell Bob a lot about herself and her opinions." +
" Alice usually gives Bob kind, helpful and informative advices." +
"\\n\\nBob: lhc\\n\\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。" +
"LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。" +
"\\n\\nBob: 企鹅会飞吗\\n\\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。" +
"企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\\n\\n"

user := "Bob: 一加一在什么情况下等于三?" +
"\\n\\n" +
"Alice: "
ctx, err := rwkv.InitState(prompt)

if err != nil {
fmt.Print(err.Error())
}

out, err := ctx.Predict(user)

if err != nil {
fmt.Print(err.Error())
}

fmt.Print(out)
}
```

## Packaging

To ship a working program that includes this AI, you will need to include the following files:
Expand Down
4 changes: 4 additions & 0 deletions rwkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"os"
"strings"
"time"
)

type RwkvModel struct {
Expand Down Expand Up @@ -153,13 +154,16 @@ func (m *RwkvModel) InitState(prompt ...string) (*RwkvState, error) {
p = prompt[0]
}
if len(p) > 0 {
startT := time.Now()
encode, err := m.tokenizer.Encode(p)
for _, token := range encode {
err = m.cRwkv.RwkvEval(m.ctx, uint32(token), state, state, logits)
if err != nil {
return nil, err
}
}
tc := time.Since(startT)
log.Print("init state time cost: ", tc, "total tokens: ", len(encode))
}
return &RwkvState{
state: state,
Expand Down
65 changes: 61 additions & 4 deletions rwkv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ func TestRwkvModel_QuantizeModelFile(t *testing.T) {

func TestNewRwkvAutoModelGPU(t *testing.T) {
rwkv, err := NewRwkvAutoModel(RwkvOptions{
MaxTokens: 100,
StopString: "/n",
MaxTokens: 50,
StopString: "\\n\\n",
Temperature: 0.8,
TopP: 0.5,
TokenizerType: World, //or World
Expand Down Expand Up @@ -190,11 +190,13 @@ func TestNewRwkvAutoModelGPU(t *testing.T) {
}

t.Run("test predict", func(t *testing.T) {
ctx, err := rwkv.InitState("\\n我希望你能充当一个英语翻译、拼写纠正和改进助手。我会用任何语言与你交谈,你将检测语言、翻译并用修正和改进后的版本回答我,用中文表达。我希望你能用更加美丽、优雅且高级的英语词汇和句子替换我的简化A0级词汇和句子。保持意思相同,但使其更具文学性。请只回复纠正和改进的部分,不要写解释。")
ctx, err := rwkv.InitState()
if err != nil {
t.Error(err)
}
out, err := ctx.Predict("幸運を \\n\\n")
out, err := ctx.Predict("天元大陆上有五个国家,分别是北方的天金帝国,南方的华盛帝国,西方的落日帝国和东方的索域联邦,而处于四大国中央,分别和四国接壤的一片面积不大呈六角形的土地就是天元大陆上最著名的神圣教廷。" +
"四大王国中除了落日帝国和华盛帝国关系不佳以外,其他国家到是可以和平相处。" +
"每年,各个国家都要向教廷上交一定的“保护费”以作为教廷的开销。")
if err != nil {
t.Error(err.Error())
}
Expand All @@ -220,3 +222,58 @@ func TestNewRwkvAutoModelGPU(t *testing.T) {
assert(t, len(responseText) >= 0)
})
}

func TestChat(t *testing.T) {
rwkv, err := NewRwkvAutoModel(RwkvOptions{
MaxTokens: 500,
StopString: "\\n\\n",
Temperature: 0.8,
TopP: 0.5,
TokenizerType: World, //or World
PrintError: true,
CpuThreads: 10,
GpuEnable: true,
})

if err != nil {
t.Error(err)
return
}

defer func(rwkv *RwkvModel) {
err := rwkv.Close()
if err != nil {
t.Error(err)
}
}(rwkv)
err = rwkv.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin")
if err != nil {
t.Error(err)
return
}
prompt := "\\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob." +
" Alice is very intelligent, creative and friendly." +
" Alice likes to tell Bob a lot about herself and her opinions." +
" Alice usually gives Bob kind, helpful and informative advices." +
"\\n\\nBob: lhc\\n\\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。" +
"LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。" +
"\\n\\nBob: 企鹅会飞吗\\n\\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。" +
"企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\\n\\n"

user := "Bob: 一加一在什么情况下等于三?" +
"\\n\\n" +
"Alice: "
t.Run("test chat with Chinese", func(t *testing.T) {
ctx, err := rwkv.InitState(prompt)
if err != nil {
t.Error(err)
}
out, err := ctx.Predict(user)
if err != nil {
t.Error(err.Error())
}
t.Log(out)
assert(t, len(out) >= 0)
})

}

0 comments on commit fc0774c

Please sign in to comment.