From 6680e8ccc319bd865cdca966c792aa0a50e684c4 Mon Sep 17 00:00:00 2001 From: danemadsen Date: Tue, 23 Apr 2024 21:27:20 +1000 Subject: [PATCH] work on simplifying chat node logic --- lib/providers/session.dart | 33 +------ lib/ui/mobile/pages/home_page.dart | 26 ++++-- .../widgets/chat_widgets/chat_field.dart | 6 +- .../widgets/chat_widgets/chat_message.dart | 90 +++++-------------- lib/ui/mobile/widgets/home_drawer.dart | 4 +- .../mobile/widgets/session_busy_overlay.dart | 2 +- lib/ui/mobile/widgets/tiles/session_tile.dart | 2 +- packages/maid_llm | 2 +- 8 files changed, 50 insertions(+), 115 deletions(-) diff --git a/lib/providers/session.dart b/lib/providers/session.dart index 200e374b..253fd00c 100644 --- a/lib/providers/session.dart +++ b/lib/providers/session.dart @@ -17,20 +17,16 @@ import 'package:shared_preferences/shared_preferences.dart'; class Session extends ChangeNotifier { Key _key = UniqueKey(); - bool _busy = false; LargeLanguageModel model = LlamaCppModel(); ChatNodeTree chat = ChatNodeTree(); String _name = ""; - bool get isBusy => _busy; - String get name => _name; Key get key => _key; set busy(bool value) { - _busy = value; notifyListeners(); } @@ -114,9 +110,6 @@ class Session extends ChangeNotifier { } void prompt(BuildContext context) async { - _busy = true; - notifyListeners(); - final user = context.read(); final character = context.read(); @@ -136,12 +129,8 @@ class Session extends ChangeNotifier { final stringStream = model.prompt(messages); - await for (var message in stringStream) { - stream(message); - } + await chat.tail.streamIn(stringStream); - _busy = false; - finalise(); notifyListeners(); } @@ -171,31 +160,11 @@ class Session extends ChangeNotifier { void stop() { (model as LlamaCppModel).stop(); - _busy = false; Logger.log('Local generation stopped'); - finalise(); notifyListeners(); } - void stream(String? message) async { - if (message == null) { - finalise(); - } else { - chat.buffer += message; - - if (!(chat.tail.messageController.isClosed)) { - chat.tail.messageController.add(chat.buffer); - chat.buffer = ""; - } - - chat.tail.content += message; - } - } - void finalise() { - _busy = false; - - chat.tail.messageController.close(); SharedPreferences.getInstance().then((prefs) { prefs.setString("last_session", json.encode(toMap())); diff --git a/lib/ui/mobile/pages/home_page.dart b/lib/ui/mobile/pages/home_page.dart index 8fa62527..81d59914 100644 --- a/lib/ui/mobile/pages/home_page.dart +++ b/lib/ui/mobile/pages/home_page.dart @@ -1,3 +1,4 @@ +import 'dart:convert'; import 'dart:math'; import 'package:flutter/material.dart'; @@ -11,6 +12,7 @@ import 'package:maid/ui/mobile/widgets/chat_widgets/chat_field.dart'; import 'package:maid/ui/mobile/widgets/appbars/home_app_bar.dart'; import 'package:maid/ui/mobile/widgets/home_drawer.dart'; import 'package:provider/provider.dart'; +import 'package:shared_preferences/shared_preferences.dart'; class HomePage extends StatefulWidget { final String title; @@ -36,23 +38,29 @@ class HomePageState extends State { Widget _buildBody() { return Consumer3( builder: (context, session, user, character, child) { - Map history = session.chat.getHistory(); - if (history.isEmpty && character.useGreeting) { + SharedPreferences.getInstance().then((prefs) { + prefs.setString("last_session", json.encode(session.toMap())); + }); + + List chat = session.chat.getChat(); + if (chat.isEmpty && character.useGreeting) { final newKey = UniqueKey(); final index = Random().nextInt(character.greetings.length); - session.chat.add( - newKey, + + final message = ChatNode( + key: newKey, + role: ChatRole.assistant, content: Utilities.formatPlaceholders(character.greetings[index], user.name, character.name), - role: ChatRole.assistant ); - history = {newKey: ChatRole.assistant}; + + session.chat.addNode(message); + chat = [message]; } chatWidgets.clear(); - for (var key in history.keys) { + for (final message in chat) { chatWidgets.add(ChatMessage( - key: key, - role: history[key] ?? ChatRole.assistant, + node: message, )); } diff --git a/lib/ui/mobile/widgets/chat_widgets/chat_field.dart b/lib/ui/mobile/widgets/chat_widgets/chat_field.dart index cc11972a..ec50a1ec 100644 --- a/lib/ui/mobile/widgets/chat_widgets/chat_field.dart +++ b/lib/ui/mobile/widgets/chat_widgets/chat_field.dart @@ -84,7 +84,7 @@ class _ChatFieldState extends State { padding: const EdgeInsets.all(8.0), child: Row( children: [ - if (session.isBusy && + if (!session.chat.tail.finalised && session.model.type != LargeLanguageModelType.ollama) IconButton( onPressed: session.stop, @@ -109,14 +109,14 @@ class _ChatFieldState extends State { ), IconButton( onPressed: () { - if (!session.isBusy) { + if (session.chat.tail.finalised ) { send(); } }, iconSize: 50, icon: Icon( Icons.arrow_circle_right, - color: session.isBusy + color: !session.chat.tail.finalised ? Theme.of(context).colorScheme.onPrimary : Theme.of(context).colorScheme.secondary, )), diff --git a/lib/ui/mobile/widgets/chat_widgets/chat_message.dart b/lib/ui/mobile/widgets/chat_widgets/chat_message.dart index 22f7d8e4..84628c8b 100644 --- a/lib/ui/mobile/widgets/chat_widgets/chat_message.dart +++ b/lib/ui/mobile/widgets/chat_widgets/chat_message.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid_llm/src/chat_node.dart'; +import 'package:maid_llm/maid_llm.dart'; import 'package:maid/providers/character.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/providers/user.dart'; @@ -8,54 +8,20 @@ import 'package:maid_ui/maid_ui.dart'; import 'package:provider/provider.dart'; class ChatMessage extends StatefulWidget { - final ChatRole role; + final ChatNode node; const ChatMessage({ - required super.key, - this.role = ChatRole.assistant, + super.key, + required this.node, }); @override ChatMessageState createState() => ChatMessageState(); } -class ChatMessageState extends State - with SingleTickerProviderStateMixin { - late Session session; - final TextEditingController _messageController = TextEditingController(); - String _message = ""; - bool _finalised = false; +class ChatMessageState extends State with SingleTickerProviderStateMixin { bool _editing = false; - @override - void initState() { - super.initState(); - session = context.read(); - - if (session.chat.messageOf(widget.key!).isNotEmpty) { - _message = session.chat.messageOf(widget.key!); - _finalised = true; - } else { - session.chat.getMessageStream(widget.key!).stream.listen((textChunk) { - setState(() { - _message += textChunk; - }); - }).onDone(() { - _message = _message.trim(); - - session.chat.add( - widget.key!, - content: _message, - role: widget.role - ); - - session.notify(); - - _finalised = true; - }); - } - } - Widget _messageBuilder(String message) { List widgets = []; List parts = message.split('```'); @@ -91,9 +57,8 @@ class ChatMessageState extends State Widget build(BuildContext context) { return Consumer3( builder: (context, session, user, character, child) { - int currentIndex = session.chat.indexOf(widget.key!); - int siblingCount = session.chat.siblingCountOf(widget.key!); - bool busy = session.isBusy; + int currentIndex = session.chat.indexOf(widget.node.key); + int siblingCount = session.chat.siblingCountOf(widget.node.key); return Column(crossAxisAlignment: CrossAxisAlignment.start, children: [ Row( @@ -101,7 +66,7 @@ class ChatMessageState extends State children: [ const SizedBox(width: 10.0), FutureAvatar( - image: widget.role == ChatRole.user ? user.profile : character.profile, + image: widget.node.role == ChatRole.user ? user.profile : character.profile, radius: 16, ), const SizedBox(width: 10.0), @@ -118,7 +83,7 @@ class ChatMessageState extends State blendMode: BlendMode .srcIn, // This blend mode applies the shader to the text color. child: Text( - widget.role == ChatRole.user ? user.name : character.name, + widget.node.role == ChatRole.user ? user.name : character.name, style: const TextStyle( fontWeight: FontWeight.normal, color: Colors @@ -128,7 +93,7 @@ class ChatMessageState extends State ), ), const Expanded(child: SizedBox()), // Spacer - if (_finalised) ..._messageOptions(), + if (widget.node.finalised) ..._messageOptions(), Row( mainAxisSize: MainAxisSize.max, mainAxisAlignment: MainAxisAlignment.spaceEvenly, @@ -136,8 +101,8 @@ class ChatMessageState extends State IconButton( padding: const EdgeInsets.all(0), onPressed: () { - if (busy) return; - session.chat.last(widget.key!); + if (!session.chat.tail.finalised) return; + session.chat.last(widget.node.key); session.notify(); }, icon: Icon(Icons.arrow_left, @@ -147,8 +112,8 @@ class ChatMessageState extends State IconButton( padding: const EdgeInsets.all(0), onPressed: () { - if (busy) return; - session.chat.next(widget.key!); + if (!session.chat.tail.finalised) return; + session.chat.next(widget.node.key); session.notify(); }, icon: Icon(Icons.arrow_right, @@ -171,18 +136,16 @@ class ChatMessageState extends State } List _messageOptions() { - return widget.role == ChatRole.user ? _userOptions() : _assistantOptions(); + return widget.node.role == ChatRole.user ? _userOptions() : _assistantOptions(); } List _userOptions() { return [ IconButton( onPressed: () { - if (session.isBusy) return; + if (!context.read().chat.tail.finalised) return; setState(() { - _messageController.text = _message; _editing = true; - _finalised = false; }); }, icon: const Icon(Icons.edit), @@ -194,8 +157,8 @@ class ChatMessageState extends State return [ IconButton( onPressed: () { - if (session.isBusy) return; - session.regenerate(widget.key!, context); + if (!context.read().chat.tail.finalised) return; + context.read().regenerate(widget.node.key, context); setState(() {}); }, icon: const Icon(Icons.refresh), @@ -204,11 +167,11 @@ class ChatMessageState extends State } List _editingColumn() { - final busy = context.watch().isBusy; + final messageController = TextEditingController(text: widget.node.content); return [ TextField( - controller: _messageController, + controller: messageController, autofocus: true, cursorColor: Theme.of(context).colorScheme.secondary, style: Theme.of(context).textTheme.bodyMedium, @@ -224,23 +187,18 @@ class ChatMessageState extends State IconButton( padding: const EdgeInsets.all(0), onPressed: () { - if (busy) return; - final inputMessage = _messageController.text; + if (!context.watch().chat.tail.finalised) return; setState(() { - _messageController.text = _message; _editing = false; - _finalised = true; }); - session.edit(widget.key!, inputMessage, context); + context.read().edit(widget.node.key, messageController.text, context); }, icon: const Icon(Icons.done)), IconButton( padding: const EdgeInsets.all(0), onPressed: () { setState(() { - _messageController.text = _message; _editing = false; - _finalised = true; }); }, icon: const Icon(Icons.close)) @@ -250,10 +208,10 @@ class ChatMessageState extends State List _standardColumn() { return [ - if (!_finalised && _message.isEmpty) + if (!widget.node.finalised && widget.node.content.isEmpty) const TypingIndicator() // Assuming TypingIndicator is a custom widget you've defined. else - _messageBuilder(_message), + _messageBuilder(widget.node.content), ]; } } diff --git a/lib/ui/mobile/widgets/home_drawer.dart b/lib/ui/mobile/widgets/home_drawer.dart index 245b345c..bf1ba072 100644 --- a/lib/ui/mobile/widgets/home_drawer.dart +++ b/lib/ui/mobile/widgets/home_drawer.dart @@ -123,7 +123,7 @@ class _HomeDrawerState extends State { ), FilledButton( onPressed: () { - if (session.isBusy) return; + if (!session.chat.tail.finalised) return; setState(() { final newSession = Session(); sessions.add(newSession); @@ -144,7 +144,7 @@ class _HomeDrawerState extends State { return SessionTile( session: sessions[index], onDelete: () { - if (session.isBusy) return; + if (!session.chat.tail.finalised) return; setState(() { if (sessions[index].key == session.key) { session.from(sessions.firstOrNull ?? Session()); diff --git a/lib/ui/mobile/widgets/session_busy_overlay.dart b/lib/ui/mobile/widgets/session_busy_overlay.dart index 88229567..8225b1eb 100644 --- a/lib/ui/mobile/widgets/session_busy_overlay.dart +++ b/lib/ui/mobile/widgets/session_busy_overlay.dart @@ -17,7 +17,7 @@ class _SessionBusyOverlayState extends State { return Stack( children: [ widget.child, - if (context.watch().isBusy) + if (!context.watch().chat.tail.finalised) Positioned.fill( child: Container( color: Colors.black.withOpacity(0.4), diff --git a/lib/ui/mobile/widgets/tiles/session_tile.dart b/lib/ui/mobile/widgets/tiles/session_tile.dart index 07d7c10c..7aac590b 100644 --- a/lib/ui/mobile/widgets/tiles/session_tile.dart +++ b/lib/ui/mobile/widgets/tiles/session_tile.dart @@ -27,7 +27,7 @@ class _SessionTileState extends State { onSecondaryTapUp: _onSecondaryTapUp, onLongPressStart: _onLongPressStart, onTap: () { - if (session.isBusy) return; + if (!session.chat.tail.finalised) return; session.from(widget.session); }, child: ListTile( diff --git a/packages/maid_llm b/packages/maid_llm index 2c063e67..1cf79956 160000 --- a/packages/maid_llm +++ b/packages/maid_llm @@ -1 +1 @@ -Subproject commit 2c063e678c2943671b55a935fdfcbdd95795cbb3 +Subproject commit 1cf799561d988b6e99e3c02f370d77de1e1c4e8b