diff --git a/.idea/rwkv.iml b/.idea/rwkv.iml index 0795325..97376f8 100644 --- a/.idea/rwkv.iml +++ b/.idea/rwkv.iml @@ -2,7 +2,7 @@ - + @@ -12,6 +12,5 @@ - \ No newline at end of file diff --git a/README.md b/README.md index 9ce72e5..7d3298a 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ import ( "github.com/seasonjs/rwkv" ) func main() { - rwkv, err := NewRwkvAutoModel(RwkvOptions{ + model, err := rwkv.NewRwkvAutoModel(rwkv.RwkvOptions{ MaxTokens: 500, StopString: "\n\n", Temperature: 0.8, @@ -249,9 +249,9 @@ func main() { return } - defer rwkv.Close() + defer model.Close() - err = rwkv.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin") + err = model.LoadFromFile("./models/RWKV-novel-4-World-7B-20230810-ctx128k-ggml-f16.bin") if err != nil { fmt.Print(err.Error()) return @@ -275,7 +275,7 @@ func main() { user := "Bob: 请介绍北京的旅游景点?" + "\n\n" + "Alice: " - ctx, err := rwkv.InitState(prompt) + ctx, err := model.InitState(prompt) if err != nil { fmt.Print(err.Error()) diff --git a/rwkv.go b/rwkv.go index 164f22d..cfa7c18 100644 --- a/rwkv.go +++ b/rwkv.go @@ -172,6 +172,43 @@ func (m *RwkvModel) InitState(prompt ...string) (*RwkvState, error) { }, nil } +// CleanState will clean old state and set new state for new chat context state +func (s *RwkvState) CleanState(prompt ...string) (*RwkvState, error) { + if err := hasCtx(s.rwkvModel.ctx); err != nil { + return nil, err + } + if s.state != nil { + s.state = nil + } + if s.logits != nil { + s.logits = nil + } + state := make([]float32, s.rwkvModel.cRwkv.RwkvGetStateLength(s.rwkvModel.ctx)) + s.rwkvModel.cRwkv.RwkvInitState(s.rwkvModel.ctx, state) + logits := make([]float32, s.rwkvModel.cRwkv.RwkvGetLogitsLength(s.rwkvModel.ctx)) + p := "" + if len(prompt) > 0 { + p = prompt[0] + } + if len(p) > 0 { + startT := time.Now() + encode, err := s.rwkvModel.tokenizer.Encode(p) + for _, token := range encode { + err = s.rwkvModel.cRwkv.RwkvEval(s.rwkvModel.ctx, uint32(token), state, state, logits) + if err != nil { + return nil, err + } + } + tc := time.Since(startT) + log.Print("init state time cost: ", tc, "total tokens: ", len(encode)) + } + return &RwkvState{ + state: state, + rwkvModel: s.rwkvModel, + logits: logits, + }, nil +} + // Predict give current chat a response func (s *RwkvState) Predict(input string) (string, error) { err := s.handelInput(input)