diff --git a/examples/reference/chat/ChatFeed.ipynb b/examples/reference/chat/ChatFeed.ipynb index 449ba664bc..3c03d464d6 100644 --- a/examples/reference/chat/ChatFeed.ipynb +++ b/examples/reference/chat/ChatFeed.ipynb @@ -66,6 +66,7 @@ "\n", "##### Other\n", "\n", + "* **`stop`**: Cancels the current callback task if possible.\n", "* **`clear`**: Clears the chat log and returns the messages that were cleared.\n", "* **`respond`**: Executes the callback with the latest message in the chat log.\n", "* **`undo`**: Removes the last `count` of messages from the chat log and returns them. Default `count` is 1.\n", @@ -343,7 +344,7 @@ "source": [ "The `ChatFeed` also support *async* `callback`s.\n", "\n", - "In fact, we recommend using *async* `callback`s whenever possible to keep your app fast and responsive." + "In fact, we recommend using *async* `callback`s whenever possible to keep your app fast and responsive, *as long as there's nothing blocking the event loop in the function*." ] }, { @@ -373,6 +374,40 @@ "message = chat_feed.send(\"Are you a parrot?\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Do not mark the function as async if there's something blocking your event loop--if you do, the placeholder will **not** appear." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import panel as pn\n", + "import time\n", + "pn.extension()\n", + "\n", + "async def parrot_message(contents, user, instance):\n", + " time.sleep(2.8)\n", + " return {\"value\": f\"No, {contents.lower()}\", \"user\": \"Parrot\", \"avatar\": \"🦜\"}\n", + "\n", + "chat_feed = pn.chat.ChatFeed(callback=parrot_message, callback_user=\"Echo Bot\")\n", + "chat_feed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "message = chat_feed.send(\"Are you a parrot?\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/examples/reference/chat/ChatInterface.ipynb b/examples/reference/chat/ChatInterface.ipynb index 7f8dd2bf8f..f04fe70e4c 100644 --- a/examples/reference/chat/ChatInterface.ipynb +++ b/examples/reference/chat/ChatInterface.ipynb @@ -25,7 +25,7 @@ "- Remove messages until the previous `user` input [`ChatMessage`](ChatMessage.ipynb).\n", "- Clear the chat log, erasing all [`ChatMessage`](ChatMessage.ipynb) objects.\n", "\n", - "Since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), it features all the capabilities of [`ChatFeed`](ChatFeed.ipynb); please see [ChatFeed.ipynb](ChatFeed.ipynb) for its backend capabilities.\n", + "**Since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), it features all the capabilities of [`ChatFeed`](ChatFeed.ipynb); please see [ChatFeed.ipynb](ChatFeed.ipynb) for its backend capabilities.**\n", "\n", "Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs to see applicable examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc. If you have an example to demo, we'd love to add it to the panel-chat-examples gallery!\n", "\n", @@ -45,6 +45,7 @@ "##### Styling\n", "\n", "* **`show_send`** (`bool`): Whether to show the send button. Default is True.\n", + "* **`show_stop`** (`bool`): Whether to show the stop button, temporarily replacing the send button during callback; has no effect if `callback` is not async.\n", "* **`show_rerun`** (`bool`): Whether to show the rerun button. Default is True.\n", "* **`show_undo`** (`bool`): Whether to show the undo button. Default is True.\n", "* **`show_clear`** (`bool`): Whether to show the clear button. Default is True.\n", @@ -458,7 +459,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs for more examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc." + "Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs for more examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc.\n", + "\n", + "\n", + "Also, since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), be sure to also read [ChatFeed.ipynb](ChatFeed.ipynb) to understand `ChatInterface`'s full potential!" ] } ], diff --git a/panel/chat/feed.py b/panel/chat/feed.py index 4b7900b17b..a6c48b2c43 100644 --- a/panel/chat/feed.py +++ b/panel/chat/feed.py @@ -8,9 +8,9 @@ import asyncio import traceback -from inspect import ( - isasyncgen, isasyncgenfunction, isawaitable, isgenerator, -) +from enum import Enum +from functools import partial +from inspect import isasyncgen, isawaitable, isgenerator from io import BytesIO from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, @@ -94,6 +94,18 @@ """ # noqa: E501 +class CallbackState(Enum): + IDLE = "idle" + RUNNING = "running" + GENERATING = "generating" + STOPPING = "stopping" + STOPPED = "stopped" + + +class StopCallback(Exception): + pass + + class ChatFeed(ListPanel): """ A widget to display a list of `ChatMessage` objects and interact with them. @@ -193,6 +205,12 @@ class ChatFeed(ListPanel): The placeholder wrapped in a ChatMessage object; primarily to prevent recursion error in _update_placeholder.""") + _callback_future = param.ClassSelector(class_=asyncio.Future, allow_None=True, doc=""" + The current, cancellable async task being executed.""") + + _callback_state = param.ObjectSelector(objects=list(CallbackState), doc=""" + The current state of the callback.""") + _stylesheets: ClassVar[List[str]] = [f"{CDN_DIST}css/chat_feed.css"] def __init__(self, *objects, **params): @@ -324,9 +342,13 @@ def _upsert_message( Replace the placeholder message with the response or update the message's value with the response. """ + is_stopping = self._callback_state == CallbackState.STOPPING + is_stopped = self._callback_future is not None and self._callback_future.cancelled() if value is None: # don't add new message if the callback returns None return + elif is_stopping or is_stopped: + raise StopCallback("Callback was stopped.") user = self.callback_user avatar = None @@ -356,7 +378,7 @@ def _upsert_message( self._replace_placeholder(new_message) return new_message - def _extract_contents(self, message: ChatMessage) -> Any: + def _gather_callback_args(self, message: ChatMessage) -> Any: """ Extracts the contents from the message's panel object. """ @@ -369,7 +391,7 @@ def _extract_contents(self, message: ChatMessage) -> Any: contents = value.value else: contents = value - return contents + return contents, message.user, self async def _serialize_response(self, response: Any) -> ChatMessage | None: """ @@ -378,9 +400,11 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None: """ response_message = None if isasyncgen(response): + self._callback_state = CallbackState.GENERATING async for token in response: response_message = self._upsert_message(token, response_message) elif isgenerator(response): + self._callback_state = CallbackState.GENERATING for token in response: response_message = self._upsert_message(token, response_message) elif isawaitable(response): @@ -389,12 +413,6 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None: response_message = self._upsert_message(response, response_message) return response_message - async def _handle_callback(self, message: ChatMessage) -> ChatMessage | None: - contents = self._extract_contents(message) - response = self.callback(contents, message.user, self) - response_message = await self._serialize_response(response) - return response_message - async def _schedule_placeholder( self, task: asyncio.Task, @@ -407,13 +425,10 @@ async def _schedule_placeholder( if self.placeholder_threshold == 0: return - callable_is_async = asyncio.iscoroutinefunction( - self.callback - ) or isasyncgenfunction(self.callback) start = asyncio.get_event_loop().time() while not task.done() and num_entries == len(self._chat_log): duration = asyncio.get_event_loop().time() - start - if duration > self.placeholder_threshold or not callable_is_async: + if duration > self.placeholder_threshold: self.append(self._placeholder) return await asyncio.sleep(0.28) @@ -428,16 +443,31 @@ async def _prepare_response(self, _) -> None: disabled = self.disabled try: - self.disabled = True + with param.parameterized.batch_call_watchers(self): + self.disabled = True + self._callback_state = CallbackState.RUNNING + message = self._chat_log[-1] if not isinstance(message, ChatMessage): return num_entries = len(self._chat_log) - task = asyncio.create_task(self._handle_callback(message)) - await self._schedule_placeholder(task, num_entries) - await task - task.result() + callback_args = self._gather_callback_args(message) + loop = asyncio.get_event_loop() + if asyncio.iscoroutinefunction(self.callback): + future = loop.create_task(self.callback(*callback_args)) + else: + future = loop.run_in_executor(None, partial(self.callback, *callback_args)) + self._callback_future = future + await self._schedule_placeholder(future, num_entries) + + if not future.cancelled(): + await future + response = future.result() + await self._serialize_response(response) + except StopCallback: + # callback was stopped by user + self._callback_state = CallbackState.STOPPED except Exception as e: send_kwargs = dict(user="Exception", respond=False) if self.callback_exception == "summary": @@ -449,8 +479,10 @@ async def _prepare_response(self, _) -> None: else: raise e finally: - self._replace_placeholder(None) - self.disabled = disabled + with param.parameterized.batch_call_watchers(self): + self._replace_placeholder(None) + self.disabled = disabled + self._callback_state = CallbackState.IDLE # Public API @@ -528,6 +560,9 @@ def stream( ------- The message that was updated. """ + if self._callback_future is not None and self._callback_future.cancelled(): + raise StopCallback("Callback was stopped.") + if isinstance(value, ChatMessage) and (user is not None or avatar is not None): raise ValueError( "Cannot set user or avatar when explicitly streaming " @@ -559,6 +594,25 @@ def respond(self): """ self._callback_trigger.param.trigger("clicks") + def stop(self) -> bool: + """ + Cancels the current callback task if possible. + + Returns + ------- + Whether the task was successfully stopped or done. + """ + if self._callback_future is None: + return False + elif self._callback_state == CallbackState.GENERATING: + # cannot cancel generator directly as it's already "finished" + # by the time cancel is called; instead, set the state to STOPPING + # and let upsert_message raise StopCallback + self._callback_state = CallbackState.STOPPING + return True + else: + return self._callback_future.cancel() + def undo(self, count: int = 1) -> List[Any]: """ Removes the last `count` of messages from the chat log and returns them. diff --git a/panel/chat/interface.py b/panel/chat/interface.py index 695b21fbd3..3ef6e68e77 100644 --- a/panel/chat/interface.py +++ b/panel/chat/interface.py @@ -22,7 +22,7 @@ from ..widgets.base import Widget from ..widgets.button import Button from ..widgets.input import FileInput, TextInput -from .feed import ChatFeed +from .feed import CallbackState, ChatFeed from .message import _FileInputMessage @@ -87,6 +87,10 @@ class ChatInterface(ChatFeed): show_send = param.Boolean(default=True, doc=""" Whether to show the send button.""") + show_stop = param.Boolean(default=True, doc=""" + Whether to show the stop button temporarily replacing the send button during + callback; has no effect if `callback` is not async.""") + show_rerun = param.Boolean(default=True, doc=""" Whether to show the rerun button.""") @@ -135,6 +139,9 @@ class ChatInterface(ChatFeed): _button_data = param.Dict(default={}, doc=""" Metadata and data related to the buttons.""") + _buttons = param.Dict(default={}, doc=""" + The rendered buttons.""") + _stylesheets: ClassVar[List[str]] = [f"{CDN_DIST}css/chat_interface.css"] def __init__(self, *objects, **params): @@ -188,6 +195,7 @@ def _init_widgets(self): """ default_button_properties = { "send": {"icon": "send", "_default_callback": self._click_send}, + "stop": {"icon": "player-stop", "_default_callback": self._click_stop}, "rerun": {"icon": "repeat", "_default_callback": self._click_rerun}, "undo": {"icon": "arrow-back", "_default_callback": self._click_undo}, "clear": {"icon": "trash", "_default_callback": self._click_clear}, @@ -270,11 +278,11 @@ def _init_widgets(self): css_classes=["chat-interface-input-widget"] ) - buttons = [] + self._buttons = {} for button_data in self._button_data.values(): action = button_data.name try: - visible = self.param[f'show_{action}'] + visible = self.param[f'show_{action}'] if action != "stop" else False except KeyError: visible = True show_expr = self.param.show_button_name.rx() @@ -288,15 +296,16 @@ def _init_widgets(self): align="start", visible=visible ) - self._link_disabled_loading(button) + if action != "stop": + self._link_disabled_loading(button) callback = partial(button_data.callback, self) button.on_click(callback) - buttons.append(button) + self._buttons[action] = button button_data.buttons.append(button) message_row = Row( widget, - *buttons, + *list(self._buttons.values()), sizing_mode="stretch_width", css_classes=["chat-interface-input-row"], stylesheets=self._stylesheets, @@ -380,6 +389,16 @@ def _click_send( self._reset_button_data() self.send(value=value, user=self.user, avatar=self.avatar, respond=True) + def _click_stop( + self, + event: param.parameterized.Event | None = None, + instance: "ChatInterface" | None = None + ) -> bool: + """ + Cancel the callback when the user presses the Stop button. + """ + return self.stop() + def _get_last_user_entry_index(self) -> int: """ Get the index of the last user message. @@ -557,3 +576,15 @@ def _serialize_for_transformers( "assistant": [self.callback_user], } return super()._serialize_for_transformers(role_names, default_role, custom_serializer) + + @param.depends("_callback_state", watch=True) + async def _update_input_disabled(self): + busy_states = (CallbackState.RUNNING, CallbackState.GENERATING) + if not self.show_stop or self._callback_state not in busy_states: + with param.parameterized.batch_call_watchers(self): + self._buttons["send"].visible = True + self._buttons["stop"].visible = False + else: + with param.parameterized.batch_call_watchers(self): + self._buttons["send"].visible = False + self._buttons["stop"].visible = True diff --git a/panel/tests/chat/test_feed.py b/panel/tests/chat/test_feed.py index d75ecefd26..841779ec4a 100644 --- a/panel/tests/chat/test_feed.py +++ b/panel/tests/chat/test_feed.py @@ -1,8 +1,6 @@ import asyncio import time -from unittest.mock import MagicMock - import pytest from panel.chat.feed import ChatFeed @@ -10,7 +8,7 @@ from panel.layout import Column, Row from panel.pane.image import Image from panel.pane.markup import HTML -from panel.tests.util import wait_until +from panel.tests.util import async_wait_until, wait_until from panel.widgets.indicators import LinearGauge from panel.widgets.input import TextAreaInput, TextInput @@ -343,7 +341,8 @@ def test_default_avatars_message_params(self, chat_feed): def test_no_recursion_error(self, chat_feed): chat_feed.send("Some time ago, there was a recursion error like this") - def test_chained_response(self, chat_feed): + @pytest.mark.asyncio + async def test_chained_response(self, chat_feed): async def callback(contents, user, instance): if user == "User": yield { @@ -363,7 +362,7 @@ async def callback(contents, user, instance): chat_feed.callback = callback chat_feed.send("Testing!", user="User") - wait_until(lambda: len(chat_feed.objects) == 3) + await async_wait_until(lambda: len(chat_feed.objects) == 3) assert chat_feed.objects[1].user == "arm" assert chat_feed.objects[1].avatar == "🦾" assert chat_feed.objects[1].object == "Hey, leg! Did you hear the user?" @@ -529,100 +528,69 @@ async def echo(contents, user, instance): def test_placeholder_disabled(self, chat_feed): def echo(contents, user, instance): - time.sleep(0.25) - yield "hey testing" + time.sleep(1.25) + assert instance._placeholder not in instance._chat_log + return "hey testing" chat_feed.placeholder_threshold = 0 chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) - # only append sent message - assert chat_feed.append.call_count == 2 + assert chat_feed._placeholder not in chat_feed._chat_log def test_placeholder_enabled(self, chat_feed): def echo(contents, user, instance): - time.sleep(0.25) - yield "hey testing" + time.sleep(1.25) + assert instance._placeholder in instance._chat_log + return chat_feed.stream("hey testing") chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) + assert chat_feed._placeholder not in chat_feed._chat_log # append sent message and placeholder - assert chat_feed.append.call_args_list[1].args[0] == chat_feed._placeholder def test_placeholder_threshold_under(self, chat_feed): async def echo(contents, user, instance): await asyncio.sleep(0.25) + assert instance._placeholder not in instance._chat_log return "hey testing" chat_feed.placeholder_threshold = 5 chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) - assert chat_feed.append.call_args_list[1].args[0] != chat_feed._placeholder + assert chat_feed._placeholder not in chat_feed._chat_log def test_placeholder_threshold_under_generator(self, chat_feed): async def echo(contents, user, instance): + assert instance._placeholder not in instance._chat_log await asyncio.sleep(0.25) + assert instance._placeholder not in instance._chat_log yield "hey testing" chat_feed.placeholder_threshold = 5 chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) - assert chat_feed.append.call_args_list[1].args[0] != chat_feed._placeholder def test_placeholder_threshold_exceed(self, chat_feed): async def echo(contents, user, instance): await asyncio.sleep(0.5) - yield "hello testing" + assert instance._placeholder in instance._chat_log + return "hello testing" chat_feed.placeholder_threshold = 0.1 chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) - assert chat_feed.append.call_args_list[1].args[0] == chat_feed._placeholder + assert chat_feed._placeholder not in chat_feed._chat_log def test_placeholder_threshold_exceed_generator(self, chat_feed): async def echo(contents, user, instance): await asyncio.sleep(0.5) + assert instance._placeholder in instance._chat_log yield "hello testing" chat_feed.placeholder_threshold = 0.1 chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) - chat_feed.send("Message", respond=True) - assert chat_feed.append.call_args_list[1].args[0] == chat_feed._placeholder - - def test_placeholder_threshold_sync(self, chat_feed): - """ - Placeholder should always be appended if the - callback is synchronous. - """ - - def echo(contents, user, instance): - time.sleep(0.25) - yield "hey testing" - - chat_feed.placeholder_threshold = 5 - chat_feed.callback = echo - chat_feed.append = MagicMock( - side_effect=lambda message: chat_feed._chat_log.append(message) - ) chat_feed.send("Message", respond=True) - assert chat_feed.append.call_args_list[1].args[0] == chat_feed._placeholder + assert chat_feed._placeholder not in chat_feed._chat_log def test_renderers_pane(self, chat_feed): chat_feed.renderers = [HTML] @@ -699,6 +667,48 @@ def callback(msg, user, instance): chat_feed.send("Message", respond=True) wait_until(lambda: len(chat_feed.objects) == 1) + def test_callback_stop_async_generator(self, chat_feed): + async def callback(msg, user, instance): + yield "A" + assert chat_feed.stop() + await asyncio.sleep(0.5) + yield "B" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + # use sleep here instead of wait for because + # the callback is timed and I want to confirm stop works + time.sleep(1) + assert chat_feed.objects[-1].object == "A" + + def test_callback_stop_async_function(self, chat_feed): + async def callback(msg, user, instance): + message = instance.stream("A") + assert chat_feed.stop() + await asyncio.sleep(0.5) + instance.stream("B", message=message) + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + # use sleep here instead of wait for because + # the callback is timed and I want to confirm stop works + time.sleep(1) + assert chat_feed.objects[-1].object == "A" + + def test_callback_stop_sync_function(self, chat_feed): + def callback(msg, user, instance): + message = instance.stream("A") + assert chat_feed.stop() + time.sleep(0.5) + instance.stream("B", message=message) # should not reach this point + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + # use sleep here instead of wait for because + # the callback is timed and I want to confirm stop works + time.sleep(1) + assert chat_feed.objects[-1].object == "A" + @pytest.mark.xdist_group("chat") class TestChatFeedSerializeForTransformers: diff --git a/panel/tests/chat/test_interface.py b/panel/tests/chat/test_interface.py index dc749bb214..ce22fd30e5 100644 --- a/panel/tests/chat/test_interface.py +++ b/panel/tests/chat/test_interface.py @@ -1,6 +1,3 @@ - - - from io import BytesIO import pytest @@ -19,7 +16,7 @@ def chat_interface(self): return ChatInterface() def test_init(self, chat_interface): - assert len(chat_interface._button_data) == 4 + assert len(chat_interface._button_data) == 5 assert len(chat_interface._widgets) == 1 assert isinstance(chat_interface._input_layout, Row) assert isinstance(chat_interface._widgets["TextInput"], TextInput) @@ -86,7 +83,51 @@ def test_click_send(self, chat_interface: ChatInterface): assert len(chat_interface.objects) == 0 chat_interface._click_send(None) assert len(chat_interface.objects) == 1 - assert chat_interface.objects[0].object == "Message" + + def test_show_stop_disabled(self, chat_interface: ChatInterface): + async def callback(msg, user, instance): + yield "A" + send_button = chat_interface._input_layout[1] + stop_button = chat_interface._input_layout[2] + assert send_button.name == "Send" + assert stop_button.name == "Stop" + assert send_button.visible + assert not stop_button.visible + yield "B" # should not stream this + + chat_interface.callback = callback + chat_interface.show_stop = False + chat_interface.send("Message", respond=True) + send_button = chat_interface._input_layout[1] + stop_button = chat_interface._input_layout[2] + assert send_button.name == "Send" + assert stop_button.name == "Stop" + assert send_button.visible + assert not stop_button.visible + + def test_show_stop_for_async(self, chat_interface: ChatInterface): + async def callback(msg, user, instance): + send_button = instance._input_layout[1] + stop_button = instance._input_layout[2] + assert send_button.name == "Send" + assert stop_button.name == "Stop" + assert not send_button.visible + assert stop_button.visible + + chat_interface.callback = callback + chat_interface.send("Message", respond=True) + + def test_show_stop_for_sync(self, chat_interface: ChatInterface): + def callback(msg, user, instance): + send_button = instance._input_layout[1] + stop_button = instance._input_layout[2] + assert send_button.name == "Send" + assert stop_button.name == "Stop" + assert not send_button.visible + assert stop_button.visible + + chat_interface.callback = callback + chat_interface.send("Message", respond=True) @pytest.mark.parametrize("widget", [TextInput(), TextAreaInput()]) def test_auto_send_types(self, chat_interface: ChatInterface, widget): diff --git a/panel/tests/util.py b/panel/tests/util.py index 91e8718414..86f927e5a8 100644 --- a/panel/tests/util.py +++ b/panel/tests/util.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import os import platform @@ -192,6 +193,77 @@ def timed_out(): time.sleep(interval / 1000) +async def async_wait_until(fn, page=None, timeout=5000, interval=100): + """ + Exercise a test function in a loop until it evaluates to True + or times out. + + The function can either be a simple lambda that returns True or False: + >>> await async_wait_until(lambda: x.values() == ['x']) + + Or a defined function with an assert: + >>> async def _() + >>> assert x.values() == ['x'] + >>> await async_wait_until(_) + + In a Playwright context test, you should pass the page fixture: + >>> await async_wait_until(lambda: x.values() == ['x'], page) + + Parameters + ---------- + fn : callable + Callback + page : playwright.async_api.Page, optional + Playwright page + timeout : int, optional + Total timeout in milliseconds, by default 5000 + interval : int, optional + Waiting interval, by default 100 + + Adapted from pytest-qt. + """ + # Hide this function traceback from the pytest output if the test fails + __tracebackhide__ = True + + start = time.time() + + def timed_out(): + elapsed = time.time() - start + elapsed_ms = elapsed * 1000 + return elapsed_ms > timeout + + timeout_msg = f"wait_until timed out in {timeout} milliseconds" + + while True: + try: + result = fn() + if asyncio.iscoroutine(result): + result = await result + except AssertionError as e: + if timed_out(): + raise TimeoutError(timeout_msg) from e + else: + if result not in (None, True, False): + raise ValueError( + "`wait_until` callback must return None, True, or " + f"False, returned {result!r}" + ) + # None is returned when the function has an assert + if result is None: + return + # When the function returns True or False + if result: + return + if timed_out(): + raise TimeoutError(timeout_msg) + if page: + # Playwright recommends against using time.sleep + # https://playwright.dev/python/docs/intro#timesleep-leads-to-outdated-state + await page.wait_for_timeout(interval) + else: + await asyncio.sleep(interval / 1000) + + def get_ctrl_modifier(): """ Get the CTRL modifier on the current platform.