From 679df8c654a04e58bf1b52a47e29e31195917e03 Mon Sep 17 00:00:00 2001 From: Dane Madsen Date: Wed, 24 Apr 2024 15:01:07 +1000 Subject: [PATCH] chat streaming --- lib/providers/session.dart | 7 ++- lib/ui/mobile/pages/home_page.dart | 7 +-- .../widgets/chat_widgets/chat_message.dart | 60 +++++++++---------- packages/maid_llm | 2 +- 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/lib/providers/session.dart b/lib/providers/session.dart index 253fd00c..cc609188 100644 --- a/lib/providers/session.dart +++ b/lib/providers/session.dart @@ -129,7 +129,12 @@ class Session extends ChangeNotifier { final stringStream = model.prompt(messages); - await chat.tail.streamIn(stringStream); + await for (var message in stringStream) { + chat.tail.content += message; + notifyListeners(); + } + + chat.tail.finalised = true; notifyListeners(); } diff --git a/lib/ui/mobile/pages/home_page.dart b/lib/ui/mobile/pages/home_page.dart index 81d59914..8db6a75e 100644 --- a/lib/ui/mobile/pages/home_page.dart +++ b/lib/ui/mobile/pages/home_page.dart @@ -2,7 +2,7 @@ import 'dart:convert'; import 'dart:math'; import 'package:flutter/material.dart'; -import 'package:maid_llm/src/chat_node.dart'; +import 'package:maid_llm/maid_llm.dart'; import 'package:maid/providers/user.dart'; import 'package:maid/providers/character.dart'; import 'package:maid/providers/session.dart'; @@ -44,11 +44,10 @@ class HomePageState extends State { List chat = session.chat.getChat(); if (chat.isEmpty && character.useGreeting) { - final newKey = UniqueKey(); final index = Random().nextInt(character.greetings.length); final message = ChatNode( - key: newKey, + key: UniqueKey(), role: ChatRole.assistant, content: Utilities.formatPlaceholders(character.greetings[index], user.name, character.name), ); @@ -60,7 +59,7 @@ class HomePageState extends State { chatWidgets.clear(); for (final message in chat) { chatWidgets.add(ChatMessage( - node: message, + key: message.key, )); } diff --git a/lib/ui/mobile/widgets/chat_widgets/chat_message.dart b/lib/ui/mobile/widgets/chat_widgets/chat_message.dart index 84628c8b..f707760d 100644 --- a/lib/ui/mobile/widgets/chat_widgets/chat_message.dart +++ b/lib/ui/mobile/widgets/chat_widgets/chat_message.dart @@ -8,21 +8,19 @@ import 'package:maid_ui/maid_ui.dart'; import 'package:provider/provider.dart'; class ChatMessage extends StatefulWidget { - final ChatNode node; - const ChatMessage({ - super.key, - required this.node, + required super.key, }); @override - ChatMessageState createState() => ChatMessageState(); + State createState() => _ChatMessageState(); } -class ChatMessageState extends State with SingleTickerProviderStateMixin { - bool _editing = false; +class _ChatMessageState extends State with SingleTickerProviderStateMixin { + late ChatNode node; + bool editing = false; - Widget _messageBuilder(String message) { + Widget messageBuilder(String message) { List widgets = []; List parts = message.split('```'); @@ -57,8 +55,10 @@ class ChatMessageState extends State with SingleTickerProviderState Widget build(BuildContext context) { return Consumer3( builder: (context, session, user, character, child) { - int currentIndex = session.chat.indexOf(widget.node.key); - int siblingCount = session.chat.siblingCountOf(widget.node.key); + node = session.chat.find(widget.key!)!; + + int currentIndex = session.chat.indexOf(widget.key!); + int siblingCount = session.chat.siblingCountOf(widget.key!); return Column(crossAxisAlignment: CrossAxisAlignment.start, children: [ Row( @@ -66,7 +66,7 @@ class ChatMessageState extends State with SingleTickerProviderState children: [ const SizedBox(width: 10.0), FutureAvatar( - image: widget.node.role == ChatRole.user ? user.profile : character.profile, + image: node.role == ChatRole.user ? user.profile : character.profile, radius: 16, ), const SizedBox(width: 10.0), @@ -83,7 +83,7 @@ class ChatMessageState extends State with SingleTickerProviderState blendMode: BlendMode .srcIn, // This blend mode applies the shader to the text color. child: Text( - widget.node.role == ChatRole.user ? user.name : character.name, + node.role == ChatRole.user ? user.name : character.name, style: const TextStyle( fontWeight: FontWeight.normal, color: Colors @@ -93,7 +93,7 @@ class ChatMessageState extends State with SingleTickerProviderState ), ), const Expanded(child: SizedBox()), // Spacer - if (widget.node.finalised) ..._messageOptions(), + if (node.finalised) ...messageOptions(), Row( mainAxisSize: MainAxisSize.max, mainAxisAlignment: MainAxisAlignment.spaceEvenly, @@ -102,7 +102,7 @@ class ChatMessageState extends State with SingleTickerProviderState padding: const EdgeInsets.all(0), onPressed: () { if (!session.chat.tail.finalised) return; - session.chat.last(widget.node.key); + session.chat.last(node.key); session.notify(); }, icon: Icon(Icons.arrow_left, @@ -113,7 +113,7 @@ class ChatMessageState extends State with SingleTickerProviderState padding: const EdgeInsets.all(0), onPressed: () { if (!session.chat.tail.finalised) return; - session.chat.next(widget.node.key); + session.chat.next(node.key); session.notify(); }, icon: Icon(Icons.arrow_right, @@ -128,24 +128,24 @@ class ChatMessageState extends State with SingleTickerProviderState padding: const EdgeInsets.fromLTRB(20, 10, 20, 10), child: Column( crossAxisAlignment: CrossAxisAlignment.start, - children: _editing ? _editingColumn() : _standardColumn(), + children: editing ? editingColumn() : standardColumn(), )) ]); }, ); } - List _messageOptions() { - return widget.node.role == ChatRole.user ? _userOptions() : _assistantOptions(); + List messageOptions() { + return node.role == ChatRole.user ? userOptions() : assistantOptions(); } - List _userOptions() { + List userOptions() { return [ IconButton( onPressed: () { if (!context.read().chat.tail.finalised) return; setState(() { - _editing = true; + editing = true; }); }, icon: const Icon(Icons.edit), @@ -153,12 +153,12 @@ class ChatMessageState extends State with SingleTickerProviderState ]; } - List _assistantOptions() { + List assistantOptions() { return [ IconButton( onPressed: () { if (!context.read().chat.tail.finalised) return; - context.read().regenerate(widget.node.key, context); + context.read().regenerate(node.key, context); setState(() {}); }, icon: const Icon(Icons.refresh), @@ -166,8 +166,8 @@ class ChatMessageState extends State with SingleTickerProviderState ]; } - List _editingColumn() { - final messageController = TextEditingController(text: widget.node.content); + List editingColumn() { + final messageController = TextEditingController(text: node.content); return [ TextField( @@ -189,16 +189,16 @@ class ChatMessageState extends State with SingleTickerProviderState onPressed: () { if (!context.watch().chat.tail.finalised) return; setState(() { - _editing = false; + editing = false; }); - context.read().edit(widget.node.key, messageController.text, context); + context.read().edit(node.key, messageController.text, context); }, icon: const Icon(Icons.done)), IconButton( padding: const EdgeInsets.all(0), onPressed: () { setState(() { - _editing = false; + editing = false; }); }, icon: const Icon(Icons.close)) @@ -206,12 +206,12 @@ class ChatMessageState extends State with SingleTickerProviderState ]; } - List _standardColumn() { + List standardColumn() { return [ - if (!widget.node.finalised && widget.node.content.isEmpty) + if (!node.finalised && node.content.isEmpty) const TypingIndicator() // Assuming TypingIndicator is a custom widget you've defined. else - _messageBuilder(widget.node.content), + messageBuilder(node.content), ]; } } diff --git a/packages/maid_llm b/packages/maid_llm index 1cf79956..33d0e6a3 160000 --- a/packages/maid_llm +++ b/packages/maid_llm @@ -1 +1 @@ -Subproject commit 1cf799561d988b6e99e3c02f370d77de1e1c4e8b +Subproject commit 33d0e6a3f0cef240bb05f8fd4a5f9afa08fe0868