Skip to content

Commit

Permalink
Allow stopping respond callbacks midway (#5962)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Dec 10, 2023
1 parent e792464 commit 4b36d7f
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 89 deletions.
37 changes: 36 additions & 1 deletion examples/reference/chat/ChatFeed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"\n",
"##### Other\n",
"\n",
"* **`stop`**: Cancels the current callback task if possible.\n",
"* **`clear`**: Clears the chat log and returns the messages that were cleared.\n",
"* **`respond`**: Executes the callback with the latest message in the chat log.\n",
"* **`undo`**: Removes the last `count` of messages from the chat log and returns them. Default `count` is 1.\n",
Expand Down Expand Up @@ -343,7 +344,7 @@
"source": [
"The `ChatFeed` also support *async* `callback`s.\n",
"\n",
"In fact, we recommend using *async* `callback`s whenever possible to keep your app fast and responsive."
"In fact, we recommend using *async* `callback`s whenever possible to keep your app fast and responsive, *as long as there's nothing blocking the event loop in the function*."
]
},
{
Expand Down Expand Up @@ -373,6 +374,40 @@
"message = chat_feed.send(\"Are you a parrot?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Do not mark the function as async if there's something blocking your event loop--if you do, the placeholder will **not** appear."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import panel as pn\n",
"import time\n",
"pn.extension()\n",
"\n",
"async def parrot_message(contents, user, instance):\n",
" time.sleep(2.8)\n",
" return {\"value\": f\"No, {contents.lower()}\", \"user\": \"Parrot\", \"avatar\": \"🦜\"}\n",
"\n",
"chat_feed = pn.chat.ChatFeed(callback=parrot_message, callback_user=\"Echo Bot\")\n",
"chat_feed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"message = chat_feed.send(\"Are you a parrot?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
8 changes: 6 additions & 2 deletions examples/reference/chat/ChatInterface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"- Remove messages until the previous `user` input [`ChatMessage`](ChatMessage.ipynb).\n",
"- Clear the chat log, erasing all [`ChatMessage`](ChatMessage.ipynb) objects.\n",
"\n",
"Since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), it features all the capabilities of [`ChatFeed`](ChatFeed.ipynb); please see [ChatFeed.ipynb](ChatFeed.ipynb) for its backend capabilities.\n",
"**Since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), it features all the capabilities of [`ChatFeed`](ChatFeed.ipynb); please see [ChatFeed.ipynb](ChatFeed.ipynb) for its backend capabilities.**\n",
"\n",
"Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs to see applicable examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc. If you have an example to demo, we'd love to add it to the panel-chat-examples gallery!\n",
"\n",
Expand All @@ -45,6 +45,7 @@
"##### Styling\n",
"\n",
"* **`show_send`** (`bool`): Whether to show the send button. Default is True.\n",
"* **`show_stop`** (`bool`): Whether to show the stop button, temporarily replacing the send button during callback; has no effect if `callback` is not async.\n",
"* **`show_rerun`** (`bool`): Whether to show the rerun button. Default is True.\n",
"* **`show_undo`** (`bool`): Whether to show the undo button. Default is True.\n",
"* **`show_clear`** (`bool`): Whether to show the clear button. Default is True.\n",
Expand Down Expand Up @@ -458,7 +459,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs for more examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc."
"Check out the [panel-chat-examples](https://holoviz-topics.github.io/panel-chat-examples/) docs for more examples related to [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAI](https://openai.com/blog/chatgpt), [Mistral](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjZtP35yvSBAxU00wIHHerUDZAQFnoECBEQAQ&url=https%3A%2F%2Fdocs.mistral.ai%2F&usg=AOvVaw2qpx09O_zOzSksgjBKiJY_&opi=89978449), [Llama](https://ai.meta.com/llama/), etc.\n",
"\n",
"\n",
"Also, since `ChatInterface` inherits from [`ChatFeed`](ChatFeed.ipynb), be sure to also read [ChatFeed.ipynb](ChatFeed.ipynb) to understand `ChatInterface`'s full potential!"
]
}
],
Expand Down
98 changes: 76 additions & 22 deletions panel/chat/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import asyncio
import traceback

from inspect import (
isasyncgen, isasyncgenfunction, isawaitable, isgenerator,
)
from enum import Enum
from functools import partial
from inspect import isasyncgen, isawaitable, isgenerator
from io import BytesIO
from typing import (
TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal,
Expand Down Expand Up @@ -94,6 +94,18 @@
""" # noqa: E501


class CallbackState(Enum):
IDLE = "idle"
RUNNING = "running"
GENERATING = "generating"
STOPPING = "stopping"
STOPPED = "stopped"


class StopCallback(Exception):
pass


class ChatFeed(ListPanel):
"""
A widget to display a list of `ChatMessage` objects and interact with them.
Expand Down Expand Up @@ -193,6 +205,12 @@ class ChatFeed(ListPanel):
The placeholder wrapped in a ChatMessage object;
primarily to prevent recursion error in _update_placeholder.""")

_callback_future = param.ClassSelector(class_=asyncio.Future, allow_None=True, doc="""
The current, cancellable async task being executed.""")

_callback_state = param.ObjectSelector(objects=list(CallbackState), doc="""
The current state of the callback.""")

_stylesheets: ClassVar[List[str]] = [f"{CDN_DIST}css/chat_feed.css"]

def __init__(self, *objects, **params):
Expand Down Expand Up @@ -324,9 +342,13 @@ def _upsert_message(
Replace the placeholder message with the response or update
the message's value with the response.
"""
is_stopping = self._callback_state == CallbackState.STOPPING
is_stopped = self._callback_future is not None and self._callback_future.cancelled()
if value is None:
# don't add new message if the callback returns None
return
elif is_stopping or is_stopped:
raise StopCallback("Callback was stopped.")

user = self.callback_user
avatar = None
Expand Down Expand Up @@ -356,7 +378,7 @@ def _upsert_message(
self._replace_placeholder(new_message)
return new_message

def _extract_contents(self, message: ChatMessage) -> Any:
def _gather_callback_args(self, message: ChatMessage) -> Any:
"""
Extracts the contents from the message's panel object.
"""
Expand All @@ -369,7 +391,7 @@ def _extract_contents(self, message: ChatMessage) -> Any:
contents = value.value
else:
contents = value
return contents
return contents, message.user, self

async def _serialize_response(self, response: Any) -> ChatMessage | None:
"""
Expand All @@ -378,9 +400,11 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None:
"""
response_message = None
if isasyncgen(response):
self._callback_state = CallbackState.GENERATING
async for token in response:
response_message = self._upsert_message(token, response_message)
elif isgenerator(response):
self._callback_state = CallbackState.GENERATING
for token in response:
response_message = self._upsert_message(token, response_message)
elif isawaitable(response):
Expand All @@ -389,12 +413,6 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None:
response_message = self._upsert_message(response, response_message)
return response_message

async def _handle_callback(self, message: ChatMessage) -> ChatMessage | None:
contents = self._extract_contents(message)
response = self.callback(contents, message.user, self)
response_message = await self._serialize_response(response)
return response_message

async def _schedule_placeholder(
self,
task: asyncio.Task,
Expand All @@ -407,13 +425,10 @@ async def _schedule_placeholder(
if self.placeholder_threshold == 0:
return

callable_is_async = asyncio.iscoroutinefunction(
self.callback
) or isasyncgenfunction(self.callback)
start = asyncio.get_event_loop().time()
while not task.done() and num_entries == len(self._chat_log):
duration = asyncio.get_event_loop().time() - start
if duration > self.placeholder_threshold or not callable_is_async:
if duration > self.placeholder_threshold:
self.append(self._placeholder)
return
await asyncio.sleep(0.28)
Expand All @@ -428,16 +443,31 @@ async def _prepare_response(self, _) -> None:

disabled = self.disabled
try:
self.disabled = True
with param.parameterized.batch_call_watchers(self):
self.disabled = True
self._callback_state = CallbackState.RUNNING

message = self._chat_log[-1]
if not isinstance(message, ChatMessage):
return

num_entries = len(self._chat_log)
task = asyncio.create_task(self._handle_callback(message))
await self._schedule_placeholder(task, num_entries)
await task
task.result()
callback_args = self._gather_callback_args(message)
loop = asyncio.get_event_loop()
if asyncio.iscoroutinefunction(self.callback):
future = loop.create_task(self.callback(*callback_args))
else:
future = loop.run_in_executor(None, partial(self.callback, *callback_args))
self._callback_future = future
await self._schedule_placeholder(future, num_entries)

if not future.cancelled():
await future
response = future.result()
await self._serialize_response(response)
except StopCallback:
# callback was stopped by user
self._callback_state = CallbackState.STOPPED
except Exception as e:
send_kwargs = dict(user="Exception", respond=False)
if self.callback_exception == "summary":
Expand All @@ -449,8 +479,10 @@ async def _prepare_response(self, _) -> None:
else:
raise e
finally:
self._replace_placeholder(None)
self.disabled = disabled
with param.parameterized.batch_call_watchers(self):
self._replace_placeholder(None)
self.disabled = disabled
self._callback_state = CallbackState.IDLE

# Public API

Expand Down Expand Up @@ -528,6 +560,9 @@ def stream(
-------
The message that was updated.
"""
if self._callback_future is not None and self._callback_future.cancelled():
raise StopCallback("Callback was stopped.")

if isinstance(value, ChatMessage) and (user is not None or avatar is not None):
raise ValueError(
"Cannot set user or avatar when explicitly streaming "
Expand Down Expand Up @@ -559,6 +594,25 @@ def respond(self):
"""
self._callback_trigger.param.trigger("clicks")

def stop(self) -> bool:
"""
Cancels the current callback task if possible.
Returns
-------
Whether the task was successfully stopped or done.
"""
if self._callback_future is None:
return False
elif self._callback_state == CallbackState.GENERATING:
# cannot cancel generator directly as it's already "finished"
# by the time cancel is called; instead, set the state to STOPPING
# and let upsert_message raise StopCallback
self._callback_state = CallbackState.STOPPING
return True
else:
return self._callback_future.cancel()

def undo(self, count: int = 1) -> List[Any]:
"""
Removes the last `count` of messages from the chat log and returns them.
Expand Down
43 changes: 37 additions & 6 deletions panel/chat/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..widgets.base import Widget
from ..widgets.button import Button
from ..widgets.input import FileInput, TextInput
from .feed import ChatFeed
from .feed import CallbackState, ChatFeed
from .message import _FileInputMessage


Expand Down Expand Up @@ -87,6 +87,10 @@ class ChatInterface(ChatFeed):
show_send = param.Boolean(default=True, doc="""
Whether to show the send button.""")

show_stop = param.Boolean(default=True, doc="""
Whether to show the stop button temporarily replacing the send button during
callback; has no effect if `callback` is not async.""")

show_rerun = param.Boolean(default=True, doc="""
Whether to show the rerun button.""")

Expand Down Expand Up @@ -135,6 +139,9 @@ class ChatInterface(ChatFeed):
_button_data = param.Dict(default={}, doc="""
Metadata and data related to the buttons.""")

_buttons = param.Dict(default={}, doc="""
The rendered buttons.""")

_stylesheets: ClassVar[List[str]] = [f"{CDN_DIST}css/chat_interface.css"]

def __init__(self, *objects, **params):
Expand Down Expand Up @@ -188,6 +195,7 @@ def _init_widgets(self):
"""
default_button_properties = {
"send": {"icon": "send", "_default_callback": self._click_send},
"stop": {"icon": "player-stop", "_default_callback": self._click_stop},
"rerun": {"icon": "repeat", "_default_callback": self._click_rerun},
"undo": {"icon": "arrow-back", "_default_callback": self._click_undo},
"clear": {"icon": "trash", "_default_callback": self._click_clear},
Expand Down Expand Up @@ -270,11 +278,11 @@ def _init_widgets(self):
css_classes=["chat-interface-input-widget"]
)

buttons = []
self._buttons = {}
for button_data in self._button_data.values():
action = button_data.name
try:
visible = self.param[f'show_{action}']
visible = self.param[f'show_{action}'] if action != "stop" else False
except KeyError:
visible = True
show_expr = self.param.show_button_name.rx()
Expand All @@ -288,15 +296,16 @@ def _init_widgets(self):
align="start",
visible=visible
)
self._link_disabled_loading(button)
if action != "stop":
self._link_disabled_loading(button)
callback = partial(button_data.callback, self)
button.on_click(callback)
buttons.append(button)
self._buttons[action] = button
button_data.buttons.append(button)

message_row = Row(
widget,
*buttons,
*list(self._buttons.values()),
sizing_mode="stretch_width",
css_classes=["chat-interface-input-row"],
stylesheets=self._stylesheets,
Expand Down Expand Up @@ -380,6 +389,16 @@ def _click_send(
self._reset_button_data()
self.send(value=value, user=self.user, avatar=self.avatar, respond=True)

def _click_stop(
self,
event: param.parameterized.Event | None = None,
instance: "ChatInterface" | None = None
) -> bool:
"""
Cancel the callback when the user presses the Stop button.
"""
return self.stop()

def _get_last_user_entry_index(self) -> int:
"""
Get the index of the last user message.
Expand Down Expand Up @@ -557,3 +576,15 @@ def _serialize_for_transformers(
"assistant": [self.callback_user],
}
return super()._serialize_for_transformers(role_names, default_role, custom_serializer)

@param.depends("_callback_state", watch=True)
async def _update_input_disabled(self):
busy_states = (CallbackState.RUNNING, CallbackState.GENERATING)
if not self.show_stop or self._callback_state not in busy_states:
with param.parameterized.batch_call_watchers(self):
self._buttons["send"].visible = True
self._buttons["stop"].visible = False
else:
with param.parameterized.batch_call_watchers(self):
self._buttons["send"].visible = False
self._buttons["stop"].visible = True
Loading

0 comments on commit 4b36d7f

Please sign in to comment.