Improve WebLLM example
philippjfr committed Sep 12, 2024
1 parent dcf8604 commit 9807172
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions examples/gallery/WebLLM.ipynb
@@ -13,7 +13,7 @@
"\n",
"from panel.custom import JSComponent, ESMEvent\n",
"\n",
"pn.extension(template='material')"
"pn.extension('mathjax', template='material')"
]
},
{
@@ -31,36 +31,48 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"MODELS = {\n",
" 'Mistral-7b-Instruct': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',\n",
" 'SmolLM': 'SmolLM-360M-Instruct-q4f16_1-MLC',\n",
" 'Gemma-2b': 'gemma-2-2b-it-q4f16_1-MLC',\n",
" 'Llama-3.1-8b-Instruct': 'Llama-3.1-8B-Instruct-q4f32_1-MLC-1k'\n",
" 'SmolLM (130MB)': 'SmolLM-135M-Instruct-q4f16_1-MLC',\n",
" 'TinyLlama-1.1B-Chat (675 MB)': 'TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC-1k',\n",
" 'Gemma-2b (1895 MB)': 'gemma-2-2b-it-q4f16_1-MLC',\n",
" 'Mistral-7b-Instruct (4570 MB)': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',\n",
" 'Llama-3.1-8b-Instruct (4598 MB)': 'Llama-3-8B-Instruct-q4f16_1-MLC-1k',\n",
"}\n",
"\n",
"class WebLLM(JSComponent):\n",
"\n",
" loaded = param.Boolean(default=False, doc=\"\"\"\n",
" Whether the model is loaded.\"\"\")\n",
"\n",
" model = param.Selector(default='SmolLM-360M-Instruct-q4f16_1-MLC', objects=MODELS)\n",
"\n",
" temperature = param.Number(default=1, bounds=(0, 2))\n",
" status = param.Dict(default={'text': '', 'progress': 0})\n",
"\n",
" load_model = param.Event()\n",
"\n",
" model = param.Selector(default='SmolLM-135M-Instruct-q4f16_1-MLC', objects=MODELS)\n",
"\n",
" running = param.Boolean(default=False, doc=\"\"\"\n",
" Whether the LLM is currently running.\"\"\")\n",
" \n",
" temperature = param.Number(default=1, bounds=(0, 2), doc=\"\"\"\n",
" Temperature of the model completions.\"\"\")\n",
"\n",
" _esm = \"\"\"\n",
" import * as webllm from \"https://esm.run/@mlc-ai/web-llm\";\n",
"\n",
" const engines = new Map()\n",
"\n",
" export async function render({ model }) {\n",
" model.on(\"msg:custom\", async (event) => {\n",
" console.log(event)\n",
" if (event.type === 'load') {\n",
" if (!engines.has(model.model)) {\n",
" engines.set(model.model, await webllm.CreateMLCEngine(model.model))\n",
" const initProgressCallback = (status) => {\n",
" model.status = status\n",
" }\n",
" const mlc = await webllm.CreateMLCEngine(\n",
" model.model,\n",
" {initProgressCallback}\n",
" )\n",
" engines.set(model.model, mlc)\n",
" }\n",
" model.loaded = true\n",
" } else if (event.type === 'completion') {\n",
@@ -73,7 +85,11 @@
" temperature: model.temperature ,\n",
" stream: true,\n",
" })\n",
" model.running = true\n",
" for await (const chunk of chunks) {\n",
" if (!model.running) {\n",
" break\n",
" }\n",
" model.send_msg(chunk.choices[0])\n",
" }\n",
" }\n",
@@ -83,6 +99,8 @@
"\n",
" def __init__(self, **params):\n",
" super().__init__(**params)\n",
" if pn.state.location:\n",
" pn.state.location.sync(self, {'model': 'model'})\n",
" self._buffer = []\n",
"\n",
" @param.depends('load_model', watch=True)\n",
@@ -93,14 +111,14 @@
" @param.depends('loaded', watch=True)\n",
" def _loaded(self):\n",
" self.loading = False\n",
" self.param.load_model.constant = True\n",
"\n",
" @param.depends('model', watch=True)\n",
" def _update_load_model(self):\n",
" self.param.load_model.constant = False\n",
" self.loaded = False\n",
"\n",
" def _handle_msg(self, msg):\n",
" self._buffer.insert(0, msg)\n",
" if self.running:\n",
" self._buffer.insert(0, msg)\n",
"\n",
" async def create_completion(self, msgs):\n",
" self._send_msg({'type': 'completion', 'messages': msgs})\n",
@@ -119,21 +137,34 @@
"\n",
" async def callback(self, contents: str, user: str):\n",
" if not self.loaded:\n",
" yield f'Model `{self.model}` is loading.' if self.param.load_model.constant else 'Load the model'\n",
" if self.loading:\n",
" yield pn.pane.Markdown(\n",
" f'## `{self.model}`\\n\\n' + self.param.status.rx()['text']\n",
" )\n",
" else:\n",
" yield 'Load the model'\n",
" return\n",
" self.running = False\n",
" self._buffer.clear()\n",
" message = \"\"\n",
" async for chunk in llm.create_completion([{'role': 'user', 'content': contents}]):\n",
" message += chunk['delta'].get('content', '')\n",
" yield message\n",
"\n",
" def menu(self):\n",
" status = self.param.status.rx()\n",
" return pn.Column(\n",
" pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),\n",
" pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),\n",
" pn.widgets.Button.from_param(\n",
" self.param.load_model, sizing_mode='stretch_width',\n",
" loading=self.param.loading\n",
" )\n",
" disabled=self.param.loaded.rx().rx.or_(self.param.loading)\n",
" ),\n",
" pn.indicators.Progress(\n",
" value=(status['progress']*100).rx.pipe(int), visible=self.param.loading,\n",
" sizing_mode='stretch_width'\n",
" ),\n",
" pn.pane.Markdown(status['text'], visible=self.param.loading)\n",
" )"
]
},
@@ -179,7 +210,9 @@
" respond=False,\n",
")\n",
"\n",
"chat_interface.servable(title='WebLLM')"
"llm.param.watch(lambda e: chat_interface.send(f'Loaded `{e.obj.model}`, start chatting!', user='System', respond=False), 'loaded')\n",
"\n",
"pn.Row(chat_interface).servable(title='WebLLM')"
]
}
],
