Skip to content

Commit

Permalink
add CleanState api
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyberhan123 committed Nov 20, 2023
1 parent 097340f commit 260d855
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
3 changes: 1 addition & 2 deletions .idea/rwkv.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ import (
"github.com/seasonjs/rwkv"
)
func main() {
rwkv, err := NewRwkvAutoModel(RwkvOptions{
model, err := rwkv.NewRwkvAutoModel(rwkv.RwkvOptions{
MaxTokens: 500,
StopString: "\n\n",
Temperature: 0.8,
Expand All @@ -249,9 +249,9 @@ func main() {
return
}

defer rwkv.Close()
defer model.Close()

err = rwkv.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin")
err = model.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin")
if err != nil {
fmt.Print(err.Error())
return
Expand All @@ -275,7 +275,7 @@ func main() {
user := "Bob: 请介绍北京的旅游景点?" +
"\n\n" +
"Alice: "
ctx, err := rwkv.InitState(prompt)
ctx, err := model.InitState(prompt)

if err != nil {
fmt.Print(err.Error())
Expand Down
37 changes: 37 additions & 0 deletions rwkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,43 @@ func (m *RwkvModel) InitState(prompt ...string) (*RwkvState, error) {
}, nil
}

// CleanState will clean old state and set new state for new chat context state
func (s *RwkvState) CleanState(prompt ...string) (*RwkvState, error) {
if err := hasCtx(s.rwkvModel.ctx); err != nil {
return nil, err
}
if s.state != nil {
s.state = nil
}
if s.logits != nil {
s.logits = nil
}
state := make([]float32, s.rwkvModel.cRwkv.RwkvGetStateLength(s.rwkvModel.ctx))
s.rwkvModel.cRwkv.RwkvInitState(s.rwkvModel.ctx, state)
logits := make([]float32, s.rwkvModel.cRwkv.RwkvGetLogitsLength(s.rwkvModel.ctx))
p := ""
if len(prompt) > 0 {
p = prompt[0]
}
if len(p) > 0 {
startT := time.Now()
encode, err := s.rwkvModel.tokenizer.Encode(p)
for _, token := range encode {
err = s.rwkvModel.cRwkv.RwkvEval(s.rwkvModel.ctx, uint32(token), state, state, logits)
if err != nil {
return nil, err
}
}
tc := time.Since(startT)
log.Print("init state time cost: ", tc, "total tokens: ", len(encode))
}
return &RwkvState{
state: state,
rwkvModel: s.rwkvModel,
logits: logits,
}, nil
}

// Predict give current chat a response
func (s *RwkvState) Predict(input string) (string, error) {
err := s.handelInput(input)
Expand Down

0 comments on commit 260d855

Please sign in to comment.