From a4d90889855b485bb3d780ddee49717e737b4079 Mon Sep 17 00:00:00 2001 From: NovTi Date: Thu, 7 Dec 2023 16:00:06 +0800 Subject: [PATCH] MOOC Update --- .../5_1_ChatBot.ipynb | 337 +++++++++++------ .../5_2_Speech_Recognition.ipynb | 153 +++++--- ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb | 349 ++++++++++++------ .../5_2_Speech_Recognition.ipynb | 157 +++++--- 4 files changed, 662 insertions(+), 334 deletions(-) diff --git a/Chinese_Version/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb b/Chinese_Version/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb index e4be2fd..5baa1b6 100644 --- a/Chinese_Version/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb +++ b/Chinese_Version/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -26,7 +25,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -50,11 +48,10 @@ "from huggingface_hub import snapshot_download\n", "\n", "model_path = snapshot_download(repo_id='meta-llama/Llama-2-7b-chat-hf',\n", - " token='hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX') # 将这里改为您自己的 Hugging Face access token" + " token='hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX') # 将这里改为您自己的 Hugging Face access token\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -73,18 +70,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e88467d77ffa4f17a423e0339e303d1b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00[INST] <>\n", "You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\n", "<>\n", "\n", - "What is AI? [/INST]\n", - "```\n", + "What is AI? [/INST]\n", "\n", + "```\n", "此外,为了实现多轮对话,您需要将新的对话输入附加到之前的对话从而为模型制作一个新的 prompt,例如:\n", "\n", "```\n", @@ -219,7 +287,7 @@ "You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\n", "<>\n", "\n", - "What is AI? [/INST] AI is a term used to describe the development of computer systems that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images. [INST] Is it dangerous? [INST]\n", + "What is AI? [/INST] AI is a term used to describe the development of computer systems that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images. [INST] Is it dangerous? [/INST]\n", "```\n", "\n", "现在,我们使用官方 `transformers` 应用程序接口和 BigDL-LLM 优化的 Llama 2 (7B) 模型来展示一个多轮对话示例。\n", @@ -229,13 +297,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "one plus one? 
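Returning for a moment to the low-precision loading step near the top of this notebook: the prose notes that `load_in_4bit=True` is equivalent to `load_in_low_bit='sym_int4'`, and that other precisions can be requested through the same parameter. Below is a minimal sketch of that variation; the value `'sym_int8'` is an assumed example, so check your BigDL-LLM version's documentation for the exact list of supported formats.

```python
# Sketch: load Llama 2 (7B) with a different low-bit format via load_in_low_bit.
# 'sym_int8' is an assumed example value, not something this tutorial prescribes.
from bigdl.llm.transformers import AutoModelForCausalLM

model_in_8bit = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf",
    load_in_low_bit="sym_int8")
```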
[/INST]\n" + ] + } + ], "source": [ - "SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\"\n", + "SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant.\"\n", + "prompt = [f'[INST] <>\\n{SYSTEM_PROMPT}\\n<>\\n\\n']\n", + "\n", + "input_str = \"one plus one?\"\n", + "input_str = input_str.strip()\n", + "prompt.append(f'{input_str} [/INST]')\n", + "prompt = ''.join(prompt)\n", + "print((prompt))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Of course! 2 + 1 = 3. How can I assist you further?\n" + ] + } + ], + "source": [ + "input_ids = tokenizer.encode(prompt, return_tensors=\"pt\")\n", + "# 预测接下来的 token,同时施加停止的标准\n", + "output_ids = model_in_4bit.generate(input_ids,\n", + " max_new_tokens=120)\n", + "output_str = tokenizer.decode(output_ids[0][len(input_ids[0]):], # 在生成的 token 中跳过 prompt\n", + " skip_special_tokens=True)\n", + "\n", "\n", + "print(output_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ "def format_prompt(input_str, chat_history):\n", + " SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant.\"\n", " prompt = [f'[INST] <>\\n{SYSTEM_PROMPT}\\n<>\\n\\n']\n", " do_strip = False\n", " for history_input, history_response in chat_history:\n", @@ -244,6 +364,7 @@ " prompt.append(f'{history_input} [/INST] {history_response.strip()} [INST] ')\n", " input_str = input_str.strip() if do_strip else input_str\n", " prompt.append(f'{input_str} [/INST]')\n", + " #print(''.join(prompt))\n", " return ''.join(prompt)" ] }, @@ -255,7 +376,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -286,7 +406,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -299,50 +418,28 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Input: What is CPU?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Hello! I'm here to help you with your question. CPU stands for Central Processing Unit. It's the part of a computer that performs calculations and executes instructions. It's the \"brain\" of the computer, responsible for processing and executing instructions from software programs.\n", - "However, I must point out that the term \"CPU\" can be somewhat outdated, as modern computers often use more advanced processors like \"CPUs\" that are more powerful and efficient. Additionally, some computers may use other types of processors, such as \"GPUs\" (Graphics Processing Units) or \"AP\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: What is its difference between GPU?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Ah, an excellent question! GPU stands for Graphics Processing Unit, and it's a specialized type of processor designed specifically for handling graphical processing tasks.\n", - "The main difference between a CPU and a GPU is their architecture and the types of tasks they are designed to handle. A CPU (Central Processing Unit) is a general-purpose processor that can perform a wide range of tasks, including executing software instructions, managing system resources, and communicating with peripherals. 
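The comment in the generation cell above mentions applying a stopping criterion, while the code itself only bounds the output with `max_new_tokens`. Here is an illustrative sketch of an explicit stopping rule built on `transformers.StoppingCriteria`; the class name `StopOnSubstring` and the choice of `"[INST]"` as the stop string are assumptions for demonstration, and `prompt`, `tokenizer` and `model_in_4bit` are reused from the cells above.

```python
# Sketch: stop generation once a chosen substring appears in the newly generated text.
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    def __init__(self, tokenizer, stop_str, prompt_length):
        self.tokenizer = tokenizer
        self.stop_str = stop_str
        self.prompt_length = prompt_length  # number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the generated part and check for the stop string.
        generated = self.tokenizer.decode(input_ids[0][self.prompt_length:])
        return self.stop_str in generated

input_ids = tokenizer.encode(prompt, return_tensors="pt")
stopping = StoppingCriteriaList([StopOnSubstring(tokenizer, "[INST]", input_ids.shape[1])])
output_ids = model_in_4bit.generate(input_ids,
                                    max_new_tokens=120,
                                    stopping_criteria=stopping)
```

Decoding the whole generated tail at every step is not the fastest possible check, but it keeps the sketch short and is fine at tutorial scale.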
It's the \"brain\" of the computer, responsible for making decisions and controlling the overall operation of the system.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: stop\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Input:hello\n", + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "hello [/INST]\n", + "Response: Hello there! It's nice to meet you. Is there anything I can help you with or any questions you have? I'm here to assist you in any way I can. Please let me know how I can help.\n", + "Input:one plus one?\n", + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "hello [/INST] Hello there! It's nice to meet you. Is there anything I can help you with or any questions you have? I'm here to assist you in any way I can. Please let me know how I can help. [INST] one plus one? [/INST]\n", + "Response: Great, let's do some basic arithmetic! The answer to \"one plus one\" is 2.\n", + "Input:stop\n", "Chat with Llama 2 (7B) stopped.\n" ] } @@ -356,8 +453,8 @@ " with torch.inference_mode():\n", " user_input = input(\"Input:\")\n", " if user_input == \"stop\": # 当用户输入 \"stop\" 时停止对话\n", - " print(\"Chat with Llama 2 (7B) stopped.\")\n", - " break\n", + " print(\"Chat with Llama 2 (7B) stopped.\")\n", + " break\n", " chat(model=model_in_4bit,\n", " tokenizer=tokenizer,\n", " input_str=user_input,\n", @@ -365,7 +462,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -374,6 +470,59 @@ "流式对话可以被视作是聊天机器人的进阶功能,其中响应是逐字生成的。在这里,我们通过 `transformers.TextIteratorStreamer` 定义了 `stream_chat` 函数:" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-17 10:21:49,596 - INFO - Converting the current model to sym_int4 format......\n" + ] + } + ], + "source": [ + "# 请注意,这里的 AutoModelForCausalLM 是从 bigdl.llm.transformers 导入的\n", + "from bigdl.llm.transformers import AutoModelForCausalLM\n", + "from transformers import TextIteratorStreamer\n", + "from threading import Thread\n", + "from transformers import LlamaTokenizer\n", + "\n", + "save_directory='./llama-2-7b-bigdl-llm-4-bit'\n", + "model_in_4bit = AutoModelForCausalLM.load_low_bit(save_directory)\n", + "\n", + "token = LlamaTokenizer.from_pretrained(save_directory)\n", + "inputs = token([\"An increasing sequense: one,\"], return_tensors='pt')\n", + "streamer = TextIteratorStreamer(token)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response: An increasing sequense: one, two, three, four, five, six, seven, eight, nine, ten. Unterscheidung between \"one\" and \"on\" is not always clear-cut, but generally \"one\" refers to the number and \"on\" is an adverb meaning \"at or near\". For example: \"Can you pass me one book from the shelf?\" vs. \"The dog is running on the field.\"." 
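As a complement to the `TextIteratorStreamer` used in this streaming example, `transformers.TextStreamer` offers a simpler, thread-free way to print tokens to stdout as they are generated. A minimal sketch, reusing `model_in_4bit` and the `token` tokenizer loaded above:

```python
# Sketch: console streaming without a background thread.
# TextStreamer prints each decoded chunk as soon as the model produces it.
from transformers import TextStreamer

simple_streamer = TextStreamer(token, skip_prompt=True, skip_special_tokens=True)
simple_inputs = token(["An increasing sequence: one,"], return_tensors="pt")
model_in_4bit.generate(**simple_inputs, streamer=simple_streamer, max_new_tokens=64)
```

The iterator variant used in the notebook needs a background thread because `generate` blocks until it finishes, so it runs in a worker `Thread` while the main thread iterates over the streamer.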
+ ] + } + ], + "source": [ + "generation = dict(inputs, streamer=streamer, max_new_tokens=120)\n", + "thread = Thread(target=model_in_4bit.generate, kwargs=generation)\n", + "thread.start() \n", + "output_str = []\n", + "\n", + "print(\"Response: \", end=\"\")\n", + "for stream_output in streamer:\n", + " output_str.append(stream_output)\n", + " print(stream_output, end=\"\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -414,7 +563,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -431,43 +579,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: What is AI?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Hello! I'm glad you asked! AI, or artificial intelligence, is a broad field of computer science that focuses on creating intelligent machines that can perform tasks that typically require human intelligence, such as understanding language, recognizing images, making decisions, and solving problems.\n", - "There are many types of AI, including:\n", - "1. Machine learning: This is a subset of AI that involves training machines to learn from data without being explicitly programmed.\n", - "2. Natural language processing: This is a type of AI that allows machines to understand, interpret, and generate human language.\n", - "3. Rob" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: Is it dangerous?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: As a responsible and ethical AI language model, I must inform you that AI, like any other technology, can be used for both positive and negative purposes. It is important to recognize that AI is a tool, and like any tool, it can be used for good or bad.\n", - "There are several potential dangers associated with AI, including:\n", - "1. Bias and discrimination: AI systems can perpetuate and amplify existing biases and discrimination if they are trained on biased data or designed with a particular worldview.\n", - "2. 
Job displacement: AI has the" - ] - } - ], + "outputs": [], "source": [ "chat_history = []\n", "\n", @@ -475,8 +587,8 @@ " with torch.inference_mode():\n", " user_input = input(\"Input:\")\n", " if user_input == \"stop\": # 当用户输入 \"stop\" 时停止对话\n", - " print(\"Stream Chat with Llama 2 (7B) stopped.\")\n", - " break\n", + " print(\"Stream Chat with Llama 2 (7B) stopped.\")\n", + " break\n", " stream_chat(model=model_in_4bit,\n", " tokenizer=tokenizer,\n", " input_str=user_input,\n", @@ -484,7 +596,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -510,7 +621,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.10.0" } }, "nbformat": 4, diff --git a/Chinese_Version/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb b/Chinese_Version/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb index 5956f92..33013ef 100644 --- a/Chinese_Version/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb +++ b/Chinese_Version/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -28,7 +27,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -45,7 +43,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -56,20 +53,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "vscode": { "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'wget' 不是内部或外部命令,也不是可运行的程序\n", + "或批处理文件。\n", + "'wget' 不是内部或外部命令,也不是可运行的程序\n", + "或批处理文件。\n" + ] + } + ], "source": [ "!wget -O audio_en.mp3 https://datasets-server.huggingface.co/assets/common_voice/--/en/train/5/audio/audio.mp3\n", "!wget -O audio_zh.mp3 https://datasets-server.huggingface.co/assets/common_voice/--/zh-CN/train/2/audio/audio.mp3" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -78,18 +85,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import IPython\n", "\n", - "IPython.display.display(IPython.display.Audio(\"audio_en.mp3\"))\n", - "IPython.display.display(IPython.display.Audio(\"audio_zh.mp3\"))" + "IPython.display.display(IPython.display.Audio(\"en.mp3\"))\n", + "IPython.display.display(IPython.display.Audio(\"ch.mp3\"))" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -102,18 +143,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-16 11:28:56,693 - INFO - Converting the current model to sym_int4 format......\n" + ] + } + ], "source": [ "from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq\n", "\n", - "model = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-medium\",\n", + "model = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=\"./model/\",\n", " load_in_4bit=True)" ] }, { 
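A side note on the audio download step in the speech-recognition notebook below: the `wget` cell fails on Windows ("'wget' 不是内部或外部命令", i.e. wget is not a recognized command). A portable alternative is to fetch the same files from Python; this sketch only reuses the two URLs already given in that cell.

```python
# Cross-platform download of the two sample clips (same URLs as the wget cell).
import urllib.request

urls = {
    "audio_en.mp3": "https://datasets-server.huggingface.co/assets/common_voice/--/en/train/5/audio/audio.mp3",
    "audio_zh.mp3": "https://datasets-server.huggingface.co/assets/common_voice/--/zh-CN/train/2/audio/audio.mp3",
}
for filename, url in urls.items():
    urllib.request.urlretrieve(url, filename)  # save each clip under the expected name
```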
- "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -124,17 +172,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperProcessor\n", "\n", - "processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-medium\")" + "processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path=\"./model/\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -147,17 +194,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "8.7771875\n" + ] + } + ], "source": [ "import librosa\n", "\n", - "data_en, sample_rate_en = librosa.load(\"audio_en.mp3\", sr=16000)" + "data_en, sample_rate_en = librosa.load(\"en.mp3\", sr=16000)\n", + "print(sample_rate_en)\n", + "print(int(data_en.shape[0])/16000)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -170,16 +227,31 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Inference time: xxxx s\n", + "[(1, 50259), (2, 50359), (3, 50363)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\deep\\anaconda3\\envs\\bigdl\\lib\\site-packages\\transformers\\generation\\utils.py:1353: UserWarning: Using `max_length`'s default (448) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference time: 7.836472034454346 s\n", "-------------------- English Transcription --------------------\n", - "[' Book me a reservation for mid-day at French Camp Academy.']\n" + "[' And many diseases are also caused by the various additives in food. 
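Whisper's feature extractor works on 30-second log-mel windows, so a clip longer than 30 seconds should be split before calling `generate`. A rough sketch (not part of the tutorial) that reuses `model`, `processor` and `forced_decoder_ids` from the cells above; naive fixed-size splitting can cut words at chunk boundaries, so treat this only as a starting point.

```python
# Sketch: transcribe audio longer than 30 s by processing 30-second chunks in turn.
import librosa
import torch

data, sr = librosa.load("audio_en.mp3", sr=16000)
chunk_size = 30 * sr  # 30-second windows, matching Whisper's input length

texts = []
with torch.inference_mode():
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        features = processor(chunk, sampling_rate=sr, return_tensors="pt").input_features
        ids = model.generate(features, forced_decoder_ids=forced_decoder_ids)
        texts.extend(processor.batch_decode(ids, skip_special_tokens=True))

print(" ".join(texts))
```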
We often heard that a bag of cheese may contain 120 kinds of additives.']\n" ] } ], @@ -208,7 +280,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -219,30 +290,27 @@ "\n", "## 5.2.6 转录中文音频并翻译成英文\n", "\n", - "现在把目光转向中文音频 `audio_zh.mp3`。Whisper 可以转录多语言音频,并将其翻译成英文。这里唯一的区别是通过 `forced_decoder_ids` 来定义特定的上下文 token:" + "现在把目光转向中文音频 `audio_zh.mp3`。由于Whisper的训练语料库包含了68w小时的音频以及90多种语言,因此我们可以实现他到英文的翻译。Whisper 可以转录多语言音频,并将其翻译成英文。这里唯一的区别是通过 `forced_decoder_ids` 来定义特定的上下文 token:" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Inference time: xxxx s\n", - "-------------------- Chinese Transcription --------------------\n", - "['制作时将各原料研磨']\n", - "Inference time: xxxx s\n", + "Inference time: 8.468757152557373 s\n", "-------------------- Chinese to English Translation --------------------\n", - "[' When making the dough, grind the ingredients.']\n" + "['是对经济社会发展情况的一次全面体验对于摸清家底反映发展成效具有重大而深远的意义第5次全国经济普查标准是']\n" ] } ], "source": [ "# 提取序列数据\n", - "data_zh, sample_rate_zh = librosa.load(\"audio_zh.mp3\", sr=16000)\n", + "data_zh, sample_rate_zh = librosa.load(\"zh.mp3\", sr=16000)\n", "\n", "# 定义中文转录任务\n", "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=\"transcribe\")\n", @@ -252,20 +320,6 @@ " st = time.time()\n", " predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)\n", " end = time.time()\n", - " transcribe_str = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", - "\n", - " print(f'Inference time: {end-st} s')\n", - " print('-'*20, 'Chinese Transcription', '-'*20)\n", - " print(transcribe_str)\n", - "\n", - "# 定义中文转录以及翻译任务\n", - "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=\"translate\")\n", - "\n", - "with torch.inference_mode():\n", - " input_features = processor(data_zh, sampling_rate=sample_rate_zh, return_tensors=\"pt\").input_features\n", - " st = time.time()\n", - " predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)\n", - " end = time.time()\n", " translate_str = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", "\n", " print(f'Inference time: {end-st} s')\n", @@ -274,7 +328,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -300,7 +353,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.10.0" } }, "nbformat": 4, diff --git a/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb b/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb index acdba9a..5707559 100644 --- a/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb +++ b/ch_5_AppDev_Intermediate/5_1_ChatBot.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -26,7 +25,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -51,11 +49,10 @@ "from huggingface_hub import snapshot_download\n", "\n", "model_path = snapshot_download(repo_id='meta-llama/Llama-2-7b-chat-hf',\n", - " token='hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX') # change it to your own Hugging Face access token" + " token='hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX') # change it to your own Hugging Face access token\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -63,9 +60,7 @@ ">\n", "> The model will by default be downloaded to 
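Since the note below points out that the checkpoint lands in the default Hugging Face cache, it can be useful to redirect that cache (for example to a larger disk) before downloading. A sketch; the path below is hypothetical, and `HF_HOME` must be set before any `huggingface_hub` call is made.

```python
# Optional: point the Hugging Face cache somewhere else before downloading the model.
import os
os.environ["HF_HOME"] = "/mnt/big_disk/huggingface"  # hypothetical location

from huggingface_hub import snapshot_download

model_path = snapshot_download(repo_id='meta-llama/Llama-2-7b-chat-hf',
                               token='hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX')  # your own access token
```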
`HF_HOME='~/.cache/huggingface'`.\n", "\n", - "### 5.1.2.1 Load Model in Low Precision\n", - "\n", - "One common use case is to load a Hugging Face *transformers* model in low precision, i.e. conduct **implicit** quantization while loading.\n", + " One common use case is to load a Hugging Face *transformers* model in low precision, i.e. conduct **implicit** quantization while loading.\n", "\n", "For Llama 2 (7B), you could simply import `bigdl.llm.transformers.AutoModelForCausalLM` instead of `transformers.AutoModelForCausalLM`, and specify `load_in_4bit=True` or `load_in_low_bit` parameter accordingly in the `from_pretrained` function. Compared to the Hugging Face *transformers* API, only minor code changes are required.\n", "\n", @@ -74,18 +69,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e88467d77ffa4f17a423e0339e303d1b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 * `load_in_4bit=True` is equivalent to `load_in_low_bit='sym_int4'`.\n", "\n", "\n", + "\n", "### 5.1.2.2 Load Tokenizer \n", "\n", - "A tokenizer is also needed for LLM inference. It is used to encode input texts to tensors to feed to LLMs, and decode the LLM output tensors to texts. You can use [Huggingface transformers](https://huggingface.co/docs/transformers/index) API to load the tokenizer directly. It can be used seamlessly with models loaded by BigDL-LLM. For Llama 2, the corresponding tokenizer class is `LlamaTokenizer`." + "A tokenizer is also needed for LLM inference. It is used to encode input texts to tensors to feed to LLMs, and decode the LLM output tensors to texts. You can use [Huggingface transformers](https://huggingface.co/docs/transformers/index) API to load the tokenizer directly. It can be used seamlessly with models loaded by BigDL-LLM. For Llama 2, the corresponding tokenizer class is `LlamaTokenizer`.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from transformers import LlamaTokenizer\n", "\n", - "tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path=\"meta-llama/Llama-2-7b-chat-hf\")" + "tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path=\"../chat-7b-hf/\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### 5.1.2.3 Save & Load Low-Precision Model (Optional)\n", "\n", "`from_pretrained` includes a conversion/quantization step, which can be particularly time-consuming or memory-intensive for some large models. To expedite this process, you can use `save_low_bit` API to store the converted model, after the model is loaded first-time using `from_pretrained`. In subsequent uses, you can opt to use the `load_low_bit` instead of `from_pretrained`, which allows for a direct loading of the pre-converted model and speedup the process. The saving and loading process can be done on different machines.\n", - "\n", - "\n", + " \n", "**Save Low-Precision Model**\n", "\n", "Let's take the `model_in_4bit` in section 5.1.2.1 as an example. 
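To make the tokenizer's role described above concrete, here is a tiny round-trip check, reusing the `tokenizer` loaded in the previous cell:

```python
# Encode a string into token ids, then decode the ids back into text.
ids = tokenizer.encode("What is AI?", return_tensors="pt")
print(ids)                                                 # tensor of token ids, shape (1, seq_len)
print(tokenizer.decode(ids[0], skip_special_tokens=True))  # -> "What is AI?"
```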
After we loading Llama 2 (7B) in 4 bit, we could use the `save_low_bit` function to save the optimized model:" @@ -143,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -154,7 +169,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -171,7 +185,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -182,37 +195,92 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-17 09:56:00,824 - INFO - Converting the current model to sym_int4 format......\n" + ] + } + ], "source": [ "# note that the AutoModelForCausalLM here is imported from bigdl.llm.transformers\n", + "from bigdl.llm.transformers import AutoModelForCausalLM\n", + "from transformers import LlamaTokenizer\n", + "save_directory='./llama-2-7b-bigdl-llm-4-bit'\n", "model_in_4bit = AutoModelForCausalLM.load_low_bit(save_directory)\n", - "\n", "tokenizer = LlamaTokenizer.from_pretrained(save_directory)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 5.1.3 Run Model\n", "\n", "BigDL-LLM optimized *transformers* model runs much faster than original model. [Chapter 3 Basic Application Develop](../ch_3_AppDev_Basic/) introduces some basics of using optimized model for direct text completion. In this section we will introduce some advanced usages.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Einzeln 2018-12-06 at 15:06\n", + "\n", + "Artificial intelligence (AI) is a branch of computer science that focuses on creating intelligent machines that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images, making decisions, and solving problems. AI research involves developing algorithms and statistical models that enable computers to perform these tasks, as well as creating systems that can learn from experience and improve their performance over time.\n", + "There are several subfields of AI, including:\n", + "1.\n" + ] + } + ], + "source": [ + "prompt = \"what is AI?\"\n", + "input_ids = tokenizer.encode(prompt, return_tensors=\"pt\")\n", + "# predict next tokens with stopping_criteria\n", + "output_ids = model_in_4bit.generate(input_ids,\n", + " max_new_tokens=120)\n", + "output_str = tokenizer.decode(output_ids[0][len(input_ids[0]):], # skip prompt in generated tokens\n", + " skip_special_tokens=True)\n", "\n", + "print(output_str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_ids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "### 5.1.3.1 Chat\n", "\n", - "One common application of large language models is Chatbot, where LLMs can engage in interactive conversations. Chatbot interaction is no magic - it still relies on the prediction and generation of next tokens by LLMs. To make LLMs chat, we need to properly format the prompts into a converation format, for example:\n", - "\n", + "One common application of large language models is Chatbot, where LLMs can engage in interactive conversations.\n", + "Chatbot interaction is no magic - it still relies on the prediction and generation of next tokens by LLMs. 
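The text-completion example below calls `generate` directly; when checking how fast the low-bit model runs, it is common to wrap the call in `torch.inference_mode()` and time it. A small sketch reusing `model_in_4bit` and `tokenizer` (the 32-token budget is an arbitrary example):

```python
# Sketch: time a short 4-bit generation while avoiding autograd overhead.
import time
import torch

prompt = "What is AI?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

with torch.inference_mode():
    st = time.time()
    output_ids = model_in_4bit.generate(input_ids, max_new_tokens=32)
    print(f"Inference time: {time.time() - st:.2f} s")

print(tokenizer.decode(output_ids[0][len(input_ids[0]):], skip_special_tokens=True))
```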
\n", + "To make LLMs chat, we need to properly format the prompts into a converation format, for example:\n", "```\n", "[INST] <>\n", "You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\n", "<>\n", "\n", - "What is AI? [/INST]\n", - "```\n", + "What is AI? [/INST]\n", "\n", + "```\n", "Further, to enable a multi-turn chat experience, you need to append the new dialog input to the previous conversation to make a new prompt for the model, for example: \n", "\n", "```\n", @@ -220,7 +288,7 @@ "You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\n", "<>\n", "\n", - "What is AI? [/INST] AI is a term used to describe the development of computer systems that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images. [INST] Is it dangerous? [INST]\n", + "What is AI? [/INST] AI is a term used to describe the development of computer systems that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images. [INST] Is it dangerous? [/INST]\n", "```\n", "\n", "Now we show a multi-turn chat example using official `transformers` API with BigDL-LLM optimized Llama 2 (7B) model. \n", @@ -230,13 +298,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "one plus one? [/INST]\n" + ] + } + ], "source": [ - "SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant, who always answers as helpfully as possible, while being safe.\"\n", + "SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant.\"\n", + "prompt = [f'[INST] <>\\n{SYSTEM_PROMPT}\\n<>\\n\\n']\n", + "\n", + "input_str = \"one plus one?\"\n", + "input_str = input_str.strip()\n", + "prompt.append(f'{input_str} [/INST]')\n", + "prompt = ''.join(prompt)\n", + "print((prompt))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Of course! 2 + 1 = 3. 
How can I assist you further?\n" + ] + } + ], + "source": [ + "input_ids = tokenizer.encode(prompt, return_tensors=\"pt\")\n", + "# predict next tokens with stopping_criteria\n", + "output_ids = model_in_4bit.generate(input_ids,\n", + " max_new_tokens=120)\n", + "output_str = tokenizer.decode(output_ids[0][len(input_ids[0]):], # skip prompt in generated tokens\n", + " skip_special_tokens=True)\n", "\n", + "\n", + "print(output_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ "def format_prompt(input_str, chat_history):\n", + " SYSTEM_PROMPT = \"You are a helpful, respectful and honest assistant.\"\n", " prompt = [f'[INST] <>\\n{SYSTEM_PROMPT}\\n<>\\n\\n']\n", " do_strip = False\n", " for history_input, history_response in chat_history:\n", @@ -245,6 +365,7 @@ " prompt.append(f'{history_input} [/INST] {history_response.strip()} [INST] ')\n", " input_str = input_str.strip() if do_strip else input_str\n", " prompt.append(f'{input_str} [/INST]')\n", + " #print(''.join(prompt))\n", " return ''.join(prompt)" ] }, @@ -256,7 +377,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -265,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -287,7 +407,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -300,50 +419,28 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Input: What is CPU?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Hello! I'm here to help you with your question. CPU stands for Central Processing Unit. It's the part of a computer that performs calculations and executes instructions. It's the \"brain\" of the computer, responsible for processing and executing instructions from software programs.\n", - "However, I must point out that the term \"CPU\" can be somewhat outdated, as modern computers often use more advanced processors like \"CPUs\" that are more powerful and efficient. Additionally, some computers may use other types of processors, such as \"GPUs\" (Graphics Processing Units) or \"AP\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: What is its difference between GPU?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Ah, an excellent question! GPU stands for Graphics Processing Unit, and it's a specialized type of processor designed specifically for handling graphical processing tasks.\n", - "The main difference between a CPU and a GPU is their architecture and the types of tasks they are designed to handle. A CPU (Central Processing Unit) is a general-purpose processor that can perform a wide range of tasks, including executing software instructions, managing system resources, and communicating with peripherals. It's the \"brain\" of the computer, responsible for making decisions and controlling the overall operation of the system.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: stop\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Input:hello\n", + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "hello [/INST]\n", + "Response: Hello there! It's nice to meet you. 
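One practical point the multi-turn loop glosses over is that `chat_history` keeps growing, and with it the prompt produced by `format_prompt`. A simple, illustrative way (not from the tutorial) to keep the prompt within the model's context window is to cap how many past turns are replayed; `MAX_TURNS` below is an arbitrary example value.

```python
MAX_TURNS = 5  # hypothetical cap on how many past turns are replayed into the prompt

def trimmed_history(chat_history, max_turns=MAX_TURNS):
    # Keep only the most recent turns when building the next prompt.
    return chat_history[-max_turns:]

# e.g. inside the chat loop:
# prompt = format_prompt(user_input, trimmed_history(chat_history))
```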
Is there anything I can help you with or any questions you have? I'm here to assist you in any way I can. Please let me know how I can help.\n", + "Input:one plus one?\n", + "[INST] <>\n", + "You are a helpful, respectful and honest assistant.\n", + "<>\n", + "\n", + "hello [/INST] Hello there! It's nice to meet you. Is there anything I can help you with or any questions you have? I'm here to assist you in any way I can. Please let me know how I can help. [INST] one plus one? [/INST]\n", + "Response: Great, let's do some basic arithmetic! The answer to \"one plus one\" is 2.\n", + "Input:stop\n", "Chat with Llama 2 (7B) stopped.\n" ] } @@ -357,8 +454,8 @@ " with torch.inference_mode():\n", " user_input = input(\"Input:\")\n", " if user_input == \"stop\": # let's stop the conversation when user input \"stop\"\n", - " print(\"Chat with Llama 2 (7B) stopped.\")\n", - " break\n", + " print(\"Chat with Llama 2 (7B) stopped.\")\n", + " break\n", " chat(model=model_in_4bit,\n", " tokenizer=tokenizer,\n", " input_str=user_input,\n", @@ -366,13 +463,65 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### 5.1.3.2 Stream Chat\n", "\n", - "Stream chat can be considered as an advanced function for a chatbot, where the response is generated word by word. Here, we define the `stream_chat` function with the help of `transformers.TextIteratorStreamer`:" + "Stream chat can be considered as an advanced function for a chatbot, where the response is generated word by word. Here, we define the `stream_chat` function with the help of `transformers.TextIteratorStreamer`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-17 10:21:49,596 - INFO - Converting the current model to sym_int4 format......\n" + ] + } + ], + "source": [ + "# note that the AutoModelForCausalLM here is imported from bigdl.llm.transformers\n", + "from bigdl.llm.transformers import AutoModelForCausalLM\n", + "from transformers import TextIteratorStreamer\n", + "from threading import Thread\n", + "from transformers import LlamaTokenizer\n", + "\n", + "save_directory='./llama-2-7b-bigdl-llm-4-bit'\n", + "model_in_4bit = AutoModelForCausalLM.load_low_bit(save_directory)\n", + "\n", + "token = LlamaTokenizer.from_pretrained(save_directory)\n", + "inputs = token([\"An increasing sequense: one,\"], return_tensors='pt')\n", + "streamer = TextIteratorStreamer(token)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response: An increasing sequense: one, two, three, four, five, six, seven, eight, nine, ten. Unterscheidung between \"one\" and \"on\" is not always clear-cut, but generally \"one\" refers to the number and \"on\" is an adverb meaning \"at or near\". For example: \"Can you pass me one book from the shelf?\" vs. \"The dog is running on the field.\"." 
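In the streamed output above, the prompt text is echoed back because the streamer decodes everything the model sees, special tokens included. `TextIteratorStreamer` accepts `skip_prompt` plus ordinary decode kwargs, so a variation like the following (reusing the `token` tokenizer from the previous cell) yields only the newly generated text:

```python
from transformers import TextIteratorStreamer

# Stream only new tokens, without special tokens such as <s>.
streamer = TextIteratorStreamer(token, skip_prompt=True, skip_special_tokens=True)
```

`generate` still has to run in a background thread exactly as in the surrounding cells; only the decoding options change.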
+ ] + } + ], + "source": [ + "generation = dict(inputs, streamer=streamer, max_new_tokens=120)\n", + "thread = Thread(target=model_in_4bit.generate, kwargs=generation)\n", + "thread.start() \n", + "output_str = []\n", + "\n", + "print(\"Response: \", end=\"\")\n", + "for stream_output in streamer:\n", + " output_str.append(stream_output)\n", + " print(stream_output, end=\"\")" ] }, { @@ -415,7 +564,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -432,43 +580,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: What is AI?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: Hello! I'm glad you asked! AI, or artificial intelligence, is a broad field of computer science that focuses on creating intelligent machines that can perform tasks that typically require human intelligence, such as understanding language, recognizing images, making decisions, and solving problems.\n", - "There are many types of AI, including:\n", - "1. Machine learning: This is a subset of AI that involves training machines to learn from data without being explicitly programmed.\n", - "2. Natural language processing: This is a type of AI that allows machines to understand, interpret, and generate human language.\n", - "3. Rob" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input: Is it dangerous?\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response: As a responsible and ethical AI language model, I must inform you that AI, like any other technology, can be used for both positive and negative purposes. It is important to recognize that AI is a tool, and like any tool, it can be used for good or bad.\n", - "There are several potential dangers associated with AI, including:\n", - "1. Bias and discrimination: AI systems can perpetuate and amplify existing biases and discrimination if they are trained on biased data or designed with a particular worldview.\n", - "2. 
Job displacement: AI has the" - ] - } - ], + "outputs": [], "source": [ "chat_history = []\n", "\n", @@ -476,8 +588,8 @@ " with torch.inference_mode():\n", " user_input = input(\"Input:\")\n", " if user_input == \"stop\": # let's stop the conversation when user input \"stop\"\n", - " print(\"Stream Chat with Llama 2 (7B) stopped.\")\n", - " break\n", + " print(\"Stream Chat with Llama 2 (7B) stopped.\")\n", + " break\n", " stream_chat(model=model_in_4bit,\n", " tokenizer=tokenizer,\n", " input_str=user_input,\n", @@ -485,7 +597,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -511,7 +622,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.10.0" } }, "nbformat": 4, diff --git a/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb b/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb index 612462b..ecbbc07 100644 --- a/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb +++ b/ch_5_AppDev_Intermediate/5_2_Speech_Recognition.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -28,7 +27,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -45,7 +43,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -56,20 +53,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "vscode": { "languageId": "plaintext" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'wget' 不是内部或外部命令,也不是可运行的程序\n", + "或批处理文件。\n", + "'wget' 不是内部或外部命令,也不是可运行的程序\n", + "或批处理文件。\n" + ] + } + ], "source": [ "!wget -O audio_en.mp3 https://datasets-server.huggingface.co/assets/common_voice/--/en/train/5/audio/audio.mp3\n", "!wget -O audio_zh.mp3 https://datasets-server.huggingface.co/assets/common_voice/--/zh-CN/train/2/audio/audio.mp3" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -78,18 +85,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import IPython\n", "\n", - "IPython.display.display(IPython.display.Audio(\"audio_en.mp3\"))\n", - "IPython.display.display(IPython.display.Audio(\"audio_zh.mp3\"))" + "IPython.display.display(IPython.display.Audio(\"en.mp3\"))\n", + "IPython.display.display(IPython.display.Audio(\"ch.mp3\"))" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -102,18 +143,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-11-16 11:28:56,693 - INFO - Converting the current model to sym_int4 format......\n" + ] + } + ], "source": [ "from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq\n", "\n", - "model = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-medium\",\n", + "model = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=\"./model/\",\n", " load_in_4bit=True)" ] }, { - "attachments": {}, 
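The Whisper loading cell below reads the checkpoint from a local `./model/` directory. If no local copy exists, the Hugging Face Hub id used before this change works just as well; it simply downloads `whisper-medium` first:

```python
from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq

# Equivalent load straight from the Hugging Face Hub (as in the pre-change version of this cell).
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    pretrained_model_name_or_path="openai/whisper-medium",
    load_in_4bit=True)
```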
"cell_type": "markdown", "metadata": {}, "source": [ @@ -124,17 +172,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperProcessor\n", "\n", - "processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-medium\")" + "processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path=\"./model/\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -147,17 +194,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000\n", + "8.7771875\n" + ] + } + ], "source": [ "import librosa\n", "\n", - "data_en, sample_rate_en = librosa.load(\"audio_en.mp3\", sr=16000)" + "data_en, sample_rate_en = librosa.load(\"en.mp3\", sr=16000)\n", + "print(sample_rate_en)\n", + "print(int(data_en.shape[0])/16000)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -170,16 +227,31 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Inference time: xxxx s\n", + "[(1, 50259), (2, 50359), (3, 50363)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\deep\\anaconda3\\envs\\bigdl\\lib\\site-packages\\transformers\\generation\\utils.py:1353: UserWarning: Using `max_length`'s default (448) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inference time: 7.836472034454346 s\n", "-------------------- English Transcription --------------------\n", - "[' Book me a reservation for mid-day at French Camp Academy.']\n" + "[' And many diseases are also caused by the various additives in food. We often heard that a bag of cheese may contain 120 kinds of additives.']\n" ] } ], @@ -191,7 +263,7 @@ "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"english\", task=\"transcribe\")\n", "\n", "with torch.inference_mode():\n", - " # extract input features for the Whisper model\n", + " # extract input features for the Whisper model mel-filter bank features\n", " input_features = processor(data_en, sampling_rate=sample_rate_en, return_tensors=\"pt\").input_features\n", "\n", " # predict token ids for transcription\n", @@ -208,7 +280,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -219,47 +290,30 @@ "\n", "## 5.2.6 Transcribe Chinese Audio and Translate to English\n", "\n", - "Then let's move to the Chinese audio `audio_zh.mp3`. Whisper can transcribe multilingual audio, and translate them into English. The only difference here is to define specific context token through `forced_decoder_ids`:" + "Then let's move to the Chinese audio `audio_zh.mp3`. The training corpus of Whisper includes 680,000 hours of audio and covers over 90 languages. This allows us to achieve translation from it to English. Whisper can transcribe multilingual audio, and translate them into English. 
The only difference here is to define specific context token through `forced_decoder_ids`:" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Inference time: xxxx s\n", - "-------------------- Chinese Transcription --------------------\n", - "['制作时将各原料研磨']\n", - "Inference time: xxxx s\n", + "Inference time: 8.468757152557373 s\n", "-------------------- Chinese to English Translation --------------------\n", - "[' When making the dough, grind the ingredients.']\n" + "['是对经济社会发展情况的一次全面体验对于摸清家底反映发展成效具有重大而深远的意义第5次全国经济普查标准是']\n" ] } ], "source": [ "# extract sequence data\n", - "data_zh, sample_rate_zh = librosa.load(\"audio_zh.mp3\", sr=16000)\n", - "\n", - "# define Chinese transcribe task\n", - "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=\"transcribe\")\n", - "\n", - "with torch.inference_mode():\n", - " input_features = processor(data_zh, sampling_rate=sample_rate_zh, return_tensors=\"pt\").input_features\n", - " st = time.time()\n", - " predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)\n", - " end = time.time()\n", - " transcribe_str = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", - "\n", - " print(f'Inference time: {end-st} s')\n", - " print('-'*20, 'Chinese Transcription', '-'*20)\n", - " print(transcribe_str)\n", + "data_zh, sample_rate_zh = librosa.load(\"zh.mp3\", sr=16000)\n", "\n", "# define Chinese transcribe and translation task\n", - "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=\"translate\")\n", + "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=\"transcribe\")\n", "\n", "with torch.inference_mode():\n", " input_features = processor(data_zh, sampling_rate=sample_rate_zh, return_tensors=\"pt\").input_features\n", @@ -274,7 +328,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -300,7 +353,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.10.0" } }, "nbformat": 4,
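Note that the cell above is labelled "Chinese to English Translation" but requests `task="transcribe"`, which is why the printed result is still Chinese. To actually obtain an English translation, Whisper's translate task can be selected, exactly as in the pre-change version of this cell. A sketch reusing `model`, `processor`, `data_zh` and `sample_rate_zh` from above; passing `max_new_tokens` also avoids the `max_length` deprecation warning seen earlier (128 is an arbitrary example value).

```python
import torch

# Ask Whisper to translate the Chinese audio into English.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="chinese", task="translate")

with torch.inference_mode():
    input_features = processor(data_zh, sampling_rate=sample_rate_zh, return_tensors="pt").input_features
    predicted_ids = model.generate(input_features,
                                   forced_decoder_ids=forced_decoder_ids,
                                   max_new_tokens=128)
    translate_str = processor.batch_decode(predicted_ids, skip_special_tokens=True)

print('-' * 20, 'Chinese to English Translation', '-' * 20)
print(translate_str)
```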