diff --git a/docs_website/docs/changelog/breaking_change.mdx b/docs_website/docs/changelog/breaking_change.mdx
index 76324b21e..4ce0f97ae 100644
--- a/docs_website/docs/changelog/breaking_change.mdx
+++ b/docs_website/docs/changelog/breaking_change.mdx
@@ -7,6 +7,13 @@ slug: /changelog
 
 Here are the list of breaking changes that you should be aware of when updating Querybook:
 
+## v3.31.0
+
+Upgraded langchain to [0.1.6](https://blog.langchain.dev/langchain-v0-1-0/).
+
+- Some langchain packages are imported from different paths, e.g. `PromptTemplate` is now from `langchain.prompts`.
+- Removed `StreamingWebsocketCallbackHandler` to adopt the new streaming approach.
+
 ## v3.29.0
 
 Made below changes for `S3BaseExporter` (csv table uploader feature):
diff --git a/package.json b/package.json
index 43a574e77..8e2ba8749 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
     "name": "querybook",
-    "version": "3.30.0",
+    "version": "3.31.0",
     "description": "A Big Data Webapp",
     "private": true,
     "scripts": {
diff --git a/querybook/config/querybook_default_config.yaml b/querybook/config/querybook_default_config.yaml
index 052bdea82..265e388e2 100644
--- a/querybook/config/querybook_default_config.yaml
+++ b/querybook/config/querybook_default_config.yaml
@@ -92,6 +92,7 @@ AI_ASSISTANT_CONFIG:
         model_args:
             model_name: ~
             temperature: ~
+            streaming: ~
         reserved_tokens: ~
 
 EMBEDDINGS_PROVIDER: ~
diff --git a/querybook/server/lib/ai_assistant/ai_socket.py b/querybook/server/lib/ai_assistant/ai_socket.py
index 4fb916e96..b7d81565d 100644
--- a/querybook/server/lib/ai_assistant/ai_socket.py
+++ b/querybook/server/lib/ai_assistant/ai_socket.py
@@ -26,12 +26,6 @@ def _send(self, event_type, payload: dict = None):
     def send_data(self, data: dict):
         self._send("data", data)
 
-    def send_delta_data(self, data: str):
-        self._send("delta_data", data)
-
-    def send_delta_end(self):
-        self._send("delta_end")
-
     def send_tables_for_sql_gen(self, data: list[str]):
         self._send("tables", data)
 
diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py
index c629d1f23..2ae7ff7be 100644
--- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py
+++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py
@@ -1,7 +1,6 @@
 import openai
 import tiktoken
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models import ChatOpenAI
+from langchain_openai import ChatOpenAI
 
 from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
 from lib.logger import get_logger
@@ -46,19 +45,12 @@ def _get_token_count(self, ai_command: str, prompt: str) -> int:
         return len(encoding.encode(prompt))
 
     def _get_error_msg(self, error) -> str:
-        if isinstance(error, openai.error.AuthenticationError):
+        if isinstance(error, openai.AuthenticationError):
             return "Invalid OpenAI API key"
 
         return super()._get_error_msg(error)
 
-    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
+    def _get_llm(self, ai_command: str, prompt_length: int):
         config = self._get_llm_config(ai_command)
-        if not callback_handler:
-            # non-streaming
-            return ChatOpenAI(**config)
-
-        return ChatOpenAI(
-            **config,
-            streaming=True,
-            callback_manager=CallbackManager([callback_handler])
-        )
+        return ChatOpenAI(**config)
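Operator note (not part of the patch): streaming is now toggled by the new `streaming` key in `AI_ASSISTANT_CONFIG.model_args` (see the `querybook_default_config.yaml` hunk above) instead of per-call callback wiring. A hypothetical override, with illustrative values and the structure abbreviated to what the hunk shows:

```yaml
AI_ASSISTANT_CONFIG:
    # ... provider-level keys omitted ...
    model_args:
        model_name: gpt-3.5-turbo # illustrative model
        temperature: 0
        streaming: true # new flag: stream partial JSON responses over the websocket
    reserved_tokens: ~
```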
diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py
index fae8af071..0096c107a 100644
--- a/querybook/server/lib/ai_assistant/base_ai_assistant.py
+++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py
@@ -1,7 +1,10 @@
 import functools
-import json
 from abc import ABC, abstractmethod
 
+from langchain_core.language_models.base import BaseLanguageModel
+from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
+from pydantic import ValidationError
+
 from app.db import with_session
 from const.ai_assistant import (
     DEFAUTL_TABLE_SELECT_LIMIT,
@@ -17,16 +20,15 @@
 from logic.metastore import get_table_by_name
 from models.metastore import DataTableColumn
 from models.query_execution import QueryExecution
-from pydantic.error_wrappers import ValidationError
 
-from .ai_socket import with_ai_socket
+from .ai_socket import AIWebSocket, with_ai_socket
+from .prompts.sql_edit_prompt import SQL_EDIT_PROMPT
 from .prompts.sql_fix_prompt import SQL_FIX_PROMPT
 from .prompts.sql_summary_prompt import SQL_SUMMARY_PROMPT
 from .prompts.sql_title_prompt import SQL_TITLE_PROMPT
 from .prompts.table_select_prompt import TABLE_SELECT_PROMPT
 from .prompts.table_summary_prompt import TABLE_SUMMARY_PROMPT
 from .prompts.text_to_sql_prompt import TEXT_TO_SQL_PROMPT
-from .streaming_web_socket_callback_handler import StreamingWebsocketCallbackHandler
 from .tools.table_schema import (
     get_slimmed_table_schemas,
     get_table_schema_by_name,
@@ -57,11 +59,7 @@ def wrapper(self, *args, **kwargs):
         except Exception as e:
             LOG.error(e, exc_info=True)
             err_msg = self._get_error_msg(e)
-            callback_handler = kwargs.get("callback_handler")
-            if callback_handler:
-                callback_handler.stream.send_error(err_msg)
-            else:
-                raise Exception(err_msg) from e
+            raise Exception(err_msg) from e
 
     return wrapper
 
@@ -96,14 +94,12 @@ def _get_llm(
         self,
         ai_command: str,
         prompt_length: int,
-        callback_handler: StreamingWebsocketCallbackHandler = None,
-    ):
+    ) -> BaseLanguageModel:
         """return the large language model to use.
 
         Args:
             ai_command (str): AI command type
             prompt_length (str): The number of tokens in the prompt. Can be used to decide which model to use.
-            callback_handler (StreamingWebsocketCallbackHandler, optional): Callback handler to handle the straming result.
         """
         raise NotImplementedError()
 
@@ -112,7 +108,8 @@ def _get_sql_title_prompt(self, query):
 
     def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query):
         context_limit = self._get_usable_token_count(AICommandType.TEXT_TO_SQL.value)
-        prompt = TEXT_TO_SQL_PROMPT.format(
+        prompt_template = SQL_EDIT_PROMPT if original_query else TEXT_TO_SQL_PROMPT
+        prompt = prompt_template.format(
             dialect=dialect,
             question=question,
             table_schemas=table_schemas,
@@ -122,7 +119,7 @@
 
         if token_count > context_limit:
             # if the prompt is too long, use slimmed table schemas
-            prompt = TEXT_TO_SQL_PROMPT.format(
+            prompt = prompt_template.format(
                 dialect=dialect,
                 question=question,
                 table_schemas=get_slimmed_table_schemas(table_schemas),
@@ -184,6 +181,26 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str:
 
         return error[:1000]
 
+    def _run_prompt_and_send(
+        self,
+        socket: AIWebSocket,
+        command: AICommandType,
+        llm: BaseLanguageModel,
+        prompt_text: str,
+    ):
+        """Run the prompt and send the response to the websocket.
+        If the command is streaming, send the response in streaming mode."""
+
+        chain = llm | JsonOutputParser()
+
+        if self._get_llm_config(command.value).get("streaming", False):
+            for s in chain.stream(prompt_text):
+                socket.send_data(s)
+            socket.close()
+        else:
+            response = chain.invoke(prompt_text)
+            socket.send_data(response)
+            socket.close()
+
     @catch_error
     @with_session
     @with_ai_socket(command_type=AICommandType.TEXT_TO_SQL)
@@ -213,7 +230,9 @@ def generate_sql_query(
             # not finding any relevant tables
             # ask user to provide table names
             socket.send_data(
-                "Sorry, I can't find any relevant tables by the given context. Please provide table names above."
+                {
+                    "explanation": "Sorry, I can't find any relevant tables by the given context. Please provide table names above."
+                }
             )
             socket.close()
 
@@ -237,9 +256,14 @@
             prompt_length=self._get_token_count(
                 AICommandType.TEXT_TO_SQL.value, prompt
             ),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
-        return llm.predict(text=prompt)
+
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.TEXT_TO_SQL,
+            llm=llm,
+            prompt_text=prompt,
+        )
 
     @catch_error
     @with_ai_socket(command_type=AICommandType.SQL_TITLE)
@@ -248,16 +272,18 @@ def generate_title_from_query(self, query, socket=None):
 
         Args:
             query (str): SQL query
-            stream (bool, optional): Whether to stream the result. Defaults to True.
-            callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
         """
         prompt = self._get_sql_title_prompt(query=query)
         llm = self._get_llm(
             ai_command=AICommandType.SQL_TITLE.value,
             prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
-        return llm.predict(text=prompt)
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.SQL_TITLE,
+            llm=llm,
+            prompt_text=prompt,
+        )
 
     @catch_error
     @with_session
@@ -268,7 +294,7 @@ def query_auto_fix(
         socket=None,
         session=None,
     ):
-        """Generate title from SQL query.
+        """Fix a SQL query from the error message of a failed query execution.
 
         Args:
             query_execution_id (int): The failed query execution id
@@ -301,9 +327,13 @@
         llm = self._get_llm(
             ai_command=AICommandType.SQL_FIX.value,
             prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
         )
-        return llm.predict(text=prompt)
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.SQL_FIX,
+            llm=llm,
+            prompt_text=prompt,
+        )
 
     @catch_error
     @with_session
@@ -337,9 +367,9 @@ def summarize_table(
             prompt_length=self._get_token_count(
                 AICommandType.TABLE_SUMMARY.value, prompt
             ),
-            callback_handler=None,
         )
-        return llm.predict(text=prompt)
+        chain = llm | StrOutputParser()
+        return chain.invoke(prompt)
 
     @catch_error
     @with_session
@@ -365,9 +395,9 @@ def summarize_query(
             prompt_length=self._get_token_count(
                 AICommandType.SQL_SUMMARY.value, prompt
             ),
-            callback_handler=None,
         )
-        return llm.predict(text=prompt)
+        chain = llm | StrOutputParser()
+        return chain.invoke(prompt)
 
     @with_session
     def find_tables(self, metastore_id, question, session=None):
@@ -422,9 +452,9 @@
                 prompt_length=self._get_token_count(
                     AICommandType.TABLE_SELECT.value, prompt
                 ),
-                callback_handler=None,
             )
-            return json.loads(llm.predict(text=prompt))
+            chain = llm | JsonOutputParser()
+            return chain.invoke(prompt)
         except Exception as e:
             LOG.error(e, exc_info=True)
             return []
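For reviewers who want to try the new LCEL flow outside Querybook: `_run_prompt_and_send` pipes the chat model into `JsonOutputParser` and either streams partial dicts or invokes once. A minimal standalone sketch under langchain 0.1.6 / langchain-openai 0.0.5 (model and prompt are illustrative; requires `OPENAI_API_KEY`):

```python
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI

# model_args mirror the config keys in querybook_default_config.yaml
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)
chain = llm | JsonOutputParser()

prompt_text = 'Respond only with JSON of the form {"title": "..."} describing: SELECT 1'

# Streaming: each chunk is a best-effort parse of the JSON seen so far,
# e.g. {} -> {"title": "Sel"} -> {"title": "Select a constant"}.
for partial in chain.stream(prompt_text):
    print(partial)

# Non-streaming: a single fully parsed object.
print(chain.invoke(prompt_text))
```

This is what makes the removed `StreamingWebsocketCallbackHandler` unnecessary: the parser itself yields incrementally, so the socket layer just forwards each parsed frame.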
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_edit_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_edit_prompt.py
new file mode 100644
index 000000000..2f198d0cc
--- /dev/null
+++ b/querybook/server/lib/ai_assistant/prompts/sql_edit_prompt.py
@@ -0,0 +1,32 @@
+from langchain.prompts import PromptTemplate
+
+prompt_template = """
+You are a {dialect} expert.
+
+Please help to modify the original {dialect} query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions.
+
+===Tables
+{table_schemas}
+
+===Original Query
+{original_query}
+
+===Response Guidelines
+1. If the provided context is sufficient, please modify and generate a valid query without any explanations for the question. The query should start with a comment containing the question being asked.
+2. If the provided context is insufficient, please explain why it can't be generated.
+3. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to modify the query, and update the comment accordingly.
+4. Please use the most relevant table(s).
+5. Please format the query before responding.
+6. Please always respond with a valid well-formed JSON object with the following format
+
+===Response Format
+{{
+    "query": "A generated SQL query when context is sufficient.",
+    "explanation": "An explanation of failing to generate the query."
+}}
+
+===Question
+{question}
+"""
+
+SQL_EDIT_PROMPT = PromptTemplate.from_template(prompt_template)
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
index 535782c9c..720a734a4 100644
--- a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
@@ -1,33 +1,28 @@
-from langchain import PromptTemplate
-
-
-prompt_template = (
-    "You are a SQL expert that can help fix SQL query errors.\n\n"
-    "Please help fix the query below based on the given error message and table schemas. \n\n"
-    "===SQL dialect\n"
-    "{dialect}\n\n"
-    "===Query\n"
-    "{query}\n\n"
-    "===Error\n"
-    "{error}\n\n"
-    "===Table Schemas\n"
-    "{table_schemas}\n\n"
-    "===Response Format\n"
-    "<@key-1@>\n"
-    "value-1\n\n"
-    "<@key-2@>\n"
-    "value-2\n\n"
-    "===Example response:\n"
-    "<@explanation@>\n"
-    "This is an explanation about the error\n\n"
-    "<@fix_suggestion@>\n"
-    "This is a recommended fix for the error\n\n"
-    "<@fixed_query@>\n"
-    "The fixed SQL query\n\n"
-    "===Response Guidelines\n"
-    "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n"
-    "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n"
-    "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n"
-)
+from langchain.prompts import PromptTemplate
+
+
+prompt_template = """You are a {dialect} expert that can help fix SQL query errors.
+
+Please help fix the {dialect} query below based on the given error message and table schemas.
+
+===Query
+{query}
+
+===Error
+{error}
+
+===Table Schemas
+{table_schemas}
+
+===Response Guidelines
+1. If there is insufficient context to address the query error, please leave fixed_query blank and provide a general suggestion instead.
+2. Maintain the original query format and case for the fixed_query, including comments, except when correcting the erroneous part.
+
+===Response Format
+{{
+    "explanation": "An explanation about the error",
+    "fix_suggestion": "A recommended fix for the error",
+    "fixed_query": "A valid and well-formatted fixed query"
+}}
+"""
 
 SQL_FIX_PROMPT = PromptTemplate.from_template(prompt_template)
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_summary_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_summary_prompt.py
index 6fc19bae4..531f03439 100644
--- a/querybook/server/lib/ai_assistant/prompts/sql_summary_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/sql_summary_prompt.py
@@ -1,4 +1,4 @@
-from langchain import PromptTemplate
+from langchain.prompts import PromptTemplate
 
 
 prompt_template = """
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py
index ad22d04da..c93a07bcf 100644
--- a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py
@@ -1,17 +1,19 @@
-from langchain import PromptTemplate
+from langchain.prompts import PromptTemplate
 
-prompt_template = (
-    "You are a helpful data scientist that can summerize SQL queries.\n\n"
-    "Generate a brief 10-word-maximum title for the SQL query below. "
-    "===Query\n"
-    "{query}\n\n"
-    "===Response Guidelines\n"
-    "1. Only respond with the title without any explanation\n"
-    "2. Dont use double quotes to enclose the title\n"
-    "3. Dont add a final period to the title\n\n"
-    "===Example response\n"
-    "This is a title\n"
-)
+prompt_template = """
+You are a helpful data scientist that can summarize SQL queries.
+
+Generate a brief 10-word-maximum title for the SQL query below.
+
+===Query
+{query}
+
+===Response Format
+Please respond in below JSON format:
+{{
+    "title": "This is a title"
+}}
+"""
 
 SQL_TITLE_PROMPT = PromptTemplate.from_template(prompt_template)
diff --git a/querybook/server/lib/ai_assistant/prompts/table_select_prompt.py b/querybook/server/lib/ai_assistant/prompts/table_select_prompt.py
index e1a8febc0..1e71a1d5c 100644
--- a/querybook/server/lib/ai_assistant/prompts/table_select_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/table_select_prompt.py
@@ -1,4 +1,4 @@
-from langchain import PromptTemplate
+from langchain.prompts import PromptTemplate
 
 
 prompt_template = """
diff --git a/querybook/server/lib/ai_assistant/prompts/table_summary_prompt.py b/querybook/server/lib/ai_assistant/prompts/table_summary_prompt.py
index dbb822fa0..e03111761 100644
--- a/querybook/server/lib/ai_assistant/prompts/table_summary_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/table_summary_prompt.py
@@ -1,4 +1,4 @@
-from langchain import PromptTemplate
+from langchain.prompts import PromptTemplate
 
 
 prompt_template = """
diff --git a/querybook/server/lib/ai_assistant/prompts/text_to_sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text_to_sql_prompt.py
index 8029c2a0c..f8f7cd893 100644
--- a/querybook/server/lib/ai_assistant/prompts/text_to_sql_prompt.py
+++ b/querybook/server/lib/ai_assistant/prompts/text_to_sql_prompt.py
@@ -1,37 +1,32 @@
-from langchain import PromptTemplate
-
-
-prompt_template = (
-    "You are a SQL expert that can help generating SQL query.\n\n"
-    "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n"
-    "Please always follow the key/value pair format below for your response:\n"
-    "===Response Format\n"
-    "<@query@>\n"
-    "query\n\n"
-    "or\n\n"
-    "<@explanation@>\n"
-    "explanation\n\n"
-    "===Example Response:\n"
-    "Example 1: Sufficient Context\n"
-    "<@query@>\n"
-    "A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n"
-    "Example 2: Insufficient Context\n"
-    "<@explanation@>\n"
-    "An explanation of the missing context is provided here.\n\n"
-    "===Response Guidelines\n"
-    "1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n"
-    "2. If the provided context is insufficient, please explain what information is missing.\n"
-    "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n"
-    "4. Please use the most relevant table(s) for the query generation\n"
-    "5. The response should always start with <@query@> or <@explanation@>\n\n"
-    "===SQL Dialect\n"
-    "{dialect}\n\n"
-    "===Tables\n"
-    "{table_schemas}\n\n"
-    "===Original Query\n"
-    "{original_query}\n\n"
-    "===Question\n"
-    "{question}\n\n"
-)
+from langchain.prompts import PromptTemplate
+
+
+prompt_template = """
+You are a {dialect} expert.
+
+Please help to generate a {dialect} query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions.
+
+===Tables
+{table_schemas}
+
+===Original Query
+{original_query}
+
+===Response Guidelines
+1. If the provided context is sufficient, please generate a valid query without any explanations for the question. The query should start with a comment containing the question being asked.
+2. If the provided context is insufficient, please explain why it can't be generated.
+3. Please use the most relevant table(s).
+4. Please format the query before responding.
+5. Please always respond with a valid well-formed JSON object with the following format
+
+===Response Format
+{{
+    "query": "A generated SQL query when context is sufficient.",
+    "explanation": "An explanation of failing to generate the query."
+}}
+
+===Question
+{question}
+"""
 
 TEXT_TO_SQL_PROMPT = PromptTemplate.from_template(prompt_template)
diff --git a/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py b/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py
deleted file mode 100644
index 92206b4d1..000000000
--- a/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-
-from .ai_socket import AIWebSocket
-
-
-class StreamingWebsocketCallbackHandler(StreamingStdOutCallbackHandler):
-    """Callback handlder to stream the result through web socket."""
-
-    def __init__(self, socket: AIWebSocket):
-        super().__init__()
-        self.socket = socket
-
-    def on_llm_new_token(self, token: str, **kwargs):
-        self.socket.send_delta_data(token)
-
-    def on_llm_end(self, response, **kwargs):
-        self.socket.send_delta_end()
-        self.socket.close()
diff --git a/querybook/server/lib/vector_store/stores/opensearch.py b/querybook/server/lib/vector_store/stores/opensearch.py
index 46b3b5245..673c5ce0f 100644
--- a/querybook/server/lib/vector_store/stores/opensearch.py
+++ b/querybook/server/lib/vector_store/stores/opensearch.py
@@ -1,5 +1,5 @@
 from langchain.docstore.document import Document
-from langchain.vectorstores import OpenSearchVectorSearch
+from langchain_community.vectorstores import OpenSearchVectorSearch
 
 from lib.logger import get_logger
 from lib.vector_store.base_vector_store import VectorStoreBase
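For plugin authors tracking the import moves above: they follow langchain 0.1's package split into core, community, and per-provider packages. A quick before/after sketch, with all paths taken from the hunks in this patch:

```python
# langchain < 0.1 (before this patch)
# from langchain import PromptTemplate
# from langchain.chat_models import ChatOpenAI
# from langchain.vectorstores import OpenSearchVectorSearch

# langchain 0.1.6 (after this patch)
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI  # requires: pip install langchain-openai
from langchain_community.vectorstores import OpenSearchVectorSearch
```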
diff --git a/querybook/webapp/__tests__/lib/stream.test.ts b/querybook/webapp/__tests__/lib/stream.test.ts
deleted file mode 100644
index 3446804c4..000000000
--- a/querybook/webapp/__tests__/lib/stream.test.ts
+++ /dev/null
@@ -1,147 +0,0 @@
-import { DeltaStreamParser, trimQueryTitle, trimSQLQuery } from 'lib/stream';
-
-describe('DeltaStreamParser', () => {
-    it('Works for stream without key/value pairs', () => {
-        const parser = new DeltaStreamParser();
-        parser.parse('some data');
-        expect(parser.result).toEqual({
-            data: 'some data',
-        });
-        parser.parse('\nsome more data');
-        expect(parser.result).toEqual({
-            data: 'some data\nsome more data',
-        });
-    });
-
-    it('Works for stream ending with non empty buffer', () => {
-        const parser = new DeltaStreamParser();
-        parser.parse('201');
-        parser.parse('9');
-        expect(parser.result).toEqual({
-            data: '201',
-        });
-        parser.close();
-        expect(parser.result).toEqual({
-            data: '2019',
-        });
-    });
-
-    it('Works for stream with both data and key/value pairs', () => {
-        const parser = new DeltaStreamParser();
-        parser.parse('some data');
-        parser.parse('\n<@some_key@>\nsome value');
-        expect(parser.result).toEqual({
-            data: 'some data\n',
-            some_key: 'some value',
-        });
-    });
-
-    it('Works for stream with only key/value pairs', () => {
-        const parser = new DeltaStreamParser();
-
-        parser.parse('<@some_key@>');
-        expect(parser.result).toEqual({
-            data: '',
-            some_key: '',
-        });
-
-        parser.parse('\nsome value\n');
-        expect(parser.result).toEqual({
-            data: '',
-            some_key: 'some value\n',
-        });
-
-        parser.parse('<@another_key@>\nanother value');
-        expect(parser.result).toEqual({
-            data: '',
-            some_key: 'some value\n',
-            another_key: 'another value',
-        });
-    });
-
-    it('Works for partial stream', () => {
-        const parser = new DeltaStreamParser();
-        parser.parse('some da');
-        expect(parser.result).toEqual({
-            data: 'some da',
-        });
-        // wait for <@ to be complete before parsing
-        parser.parse('ta<');
-        expect(parser.result).toEqual({
-            data: 'some da',
-        });
-        // the next char is not @, so it will be treated as data
-        parser.parse('ta');
-        expect(parser.result).toEqual({
-            data: 'some data<ta',
-        });
-        // wait for @> to be complete before parsing
-        parser.parse('<@som');
-        expect(parser.result).toEqual({
-            data: 'some data<ta',
-        });
-        parser.parse('e_key@');
-        expect(parser.result).toEqual({
-            data: 'some data<ta',
-        });
-        parser.parse('>\nsome value');
-        expect(parser.result).toEqual({
-            data: 'some data<ta',
-            some_key: 'some value',
-        });
-    });
-});
-
-describe('trimQueryTitle', () => {
-    it('Works for query title with quotes', () => {
-        expect(trimQueryTitle('"some title')).toEqual('some title');
-        expect(trimQueryTitle('"some title"')).toEqual('some title');
-        expect(trimQueryTitle("'some title")).toEqual('some title');
-        expect(trimQueryTitle("'some title'")).toEqual('some title');
-    });
-
-    it('Works for query title with trailing period', () => {
-        expect(trimQueryTitle('some title.')).toEqual('some title');
-    });
-
-    it('Works for query title with both', () => {
-        expect(trimQueryTitle('"some title.')).toEqual('some title');
-        expect(trimQueryTitle('"some title."')).toEqual('some title');
-        expect(trimQueryTitle("'some title.'")).toEqual('some title');
-    });
-});
-
-describe('trimSQLQuery', () => {
-    it('Works for query with ``` ', () => {
-        expect(trimSQLQuery('```\nsome query')).toEqual('some query');
-        expect(trimSQLQuery('```\nsome query```')).toEqual('some query');
-    });
-
-    it('Works for query with ```sql ', () => {
-        expect(trimSQLQuery('```sql\nsome query')).toEqual('some query');
-        expect(trimSQLQuery('```sql\nsome query```')).toEqual('some query');
-    });
-});
diff --git a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
index aeb4f7726..32ecc19eb 100644
--- a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
+++ b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
@@ -1,11 +1,11 @@
-import React, { useState } from 'react';
+import React, { useEffect, useState } from 'react';
 
 import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison';
 import { AICommandType } from 'const/aiAssistant';
 import { ComponentType, ElementType } from 'const/analytics';
 import { useAISocket } from 'hooks/useAISocket';
+import useNonEmptyState from 'hooks/useNonEmptyState';
 import { trackClick } from 'lib/analytics';
-import { trimSQLQuery } from 'lib/stream';
 import { Button } from 'ui/Button/Button';
 import { Message } from 'ui/Message/Message';
 import { Modal } from 'ui/Modal/Modal';
@@ -21,19 +21,22 @@ interface IProps {
 
 const useSQLFix = () => {
     const [data, setData] = useState<{ [key: string]: string }>({});
+    const [fixedQuery, setFixedQuery] = useNonEmptyState('');
 
     const socket = useAISocket(AICommandType.SQL_FIX, ({ data }) => {
-        setData(data);
+        setData(data as { [key: string]: string });
     });
 
     const {
         data: unformattedData,
         explanation,
         fix_suggestion: suggestion,
-        fixed_query: rawFixedQuery,
+        fixed_query: newFixedQuery,
     } = data;
 
-    const fixedQuery = trimSQLQuery(rawFixedQuery);
+    useEffect(() => {
+        setFixedQuery(newFixedQuery);
+    }, [newFixedQuery]);
 
     return {
         socket,
diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
index 052d35b27..ddbb62faf 100644
--- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
+++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
@@ -11,7 +11,6 @@ import { useSurveyTrigger } from 'hooks/ui/useSurveyTrigger';
 import { useAISocket } from 'hooks/useAISocket';
 import { trackClick } from 'lib/analytics';
 import { TableToken } from 'lib/sql-helper/sql-lexer';
-import { trimSQLQuery } from 'lib/stream';
 import { matchKeyPress } from 'lib/utils/keyboard';
 import { analyzeCode } from 'lib/web-worker';
 import { Button } from 'ui/Button/Button';
@@ -124,7 +123,9 @@ export const QueryGenerationModal = ({
     const { explanation, query: rawNewQuery, data } = streamData;
 
     useEffect(() => {
-        setNewQuery(trimSQLQuery(rawNewQuery));
+        if (rawNewQuery) {
+            setNewQuery(rawNewQuery);
+        }
     }, [rawNewQuery]);
 
     const triggerSurvey = useSurveyTrigger();
@@ -146,7 +147,7 @@
             query_engine_id: engineId,
             tables,
             question,
-            original_query: query,
+            original_query: textToSQLMode === TextToSQLMode.EDIT ? query : null,
         });
         trackClick({
             component: ComponentType.AI_ASSISTANT,
@@ -172,6 +173,23 @@
         [onGenerate]
     );
 
+    const handleKeepQuery = useCallback(() => {
+        onUpdateQuery(newQuery, false);
+        setTextToSQLMode(TextToSQLMode.EDIT);
+        setQuestion('');
+        setNewQuery('');
+        trackClick({
+            component: ComponentType.AI_ASSISTANT,
+            element: ElementType.QUERY_GENERATION_KEEP_BUTTON,
+            aux: {
+                mode: textToSQLMode,
+                question,
+                tables,
+                query: newQuery,
+            },
+        });
+    }, [newQuery, onUpdateQuery, textToSQLMode, question, tables]);
+
     const questionBarDOM = (
@@ -380,26 +398,7 @@ export const QueryGenerationModal = ({
                                 {New Query}
diff --git a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx
index 91f15ec96..d21398508 100644
--- a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx
+++ b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx
@@ -5,7 +5,6 @@ import { AICommandType } from 'const/aiAssistant';
 import { ComponentType, ElementType } from 'const/analytics';
 import { useAISocket } from 'hooks/useAISocket';
 import { trackClick } from 'lib/analytics';
-import { trimQueryTitle } from 'lib/stream';
 
 import { IconButton } from 'ui/Button/IconButton';
 import { ResizableTextArea } from 'ui/ResizableTextArea/ResizableTextArea';
@@ -37,12 +36,12 @@ export const QueryCellTitle: React.FC = ({
     const [title, setTitle] = useState('');
 
     const socket = useAISocket(AICommandType.SQL_TITLE, ({ data }) => {
-        setTitle(data.data);
+        setTitle(data.title);
    });
 
     useEffect(() => {
         if (title) {
-            onChange(trimQueryTitle(title));
+            onChange(title);
         }
     }, [title]);
 
diff --git a/querybook/webapp/hooks/useAISocket.ts b/querybook/webapp/hooks/useAISocket.ts
index 410ed3caf..5d58558e4 100644
--- a/querybook/webapp/hooks/useAISocket.ts
+++ b/querybook/webapp/hooks/useAISocket.ts
@@ -3,7 +3,6 @@ import toast from 'react-hot-toast';
 
 import { AICommandType, AISocketEvent } from 'const/aiAssistant';
 import aiAssistantSocket from 'lib/ai-assistant/ai-assistant-socketio';
-import { DeltaStreamParser } from 'lib/stream';
 
 export interface AISocket {
     loading: boolean;
@@ -17,26 +16,11 @@ export function useAISocket(
 ): AISocket {
     const [loading, setLoading] = useState(false);
 
-    const deltaStreamParserRef = useRef<DeltaStreamParser>(
-        new DeltaStreamParser()
-    );
-
     const eventHandler = useCallback(
         (event, payload) => {
-            const parser = deltaStreamParserRef.current;
             switch (event) {
                 case AISocketEvent.DATA:
-                    onData({ data: { data: payload } });
-                    break;
-
-                case AISocketEvent.DELTA_DATA:
-                    parser.parse(payload);
-                    onData({ data: parser.result });
-                    break;
-
-                case AISocketEvent.DELTA_END:
-                    parser.close();
-                    onData({ data: parser.result });
+                    onData({ type: 'data', data: payload });
                     break;
 
                 case AISocketEvent.TABLES:
@@ -62,7 +46,6 @@
         aiAssistantSocket.removeListener(commandType, eventHandler);
         aiAssistantSocket.removeListener('error', onError);
         setLoading(false);
-        deltaStreamParserRef.current.reset();
     }, [aiAssistantSocket, commandType, eventHandler]);
 
     const onError = useCallback(
diff --git a/querybook/webapp/hooks/useNonEmptyState.ts b/querybook/webapp/hooks/useNonEmptyState.ts
new file mode 100644
index 000000000..0ad8a15dc
--- /dev/null
+++ b/querybook/webapp/hooks/useNonEmptyState.ts
@@ -0,0 +1,30 @@
+import { useCallback, useState } from 'react';
+
+/**
+ * Use this when you need a state from the Gen AI streaming that fits the following description:
+ *
+ * - The state may have a break between two consecutive streaming updates.
+ *
+ * e.g.
+ * {"query": "SELECT\n Country"}
+ * {"query": ""} // Here somehow the state is empty
+ * {"query": "SELECT\n Country,\n"}
+ * {"query": "SELECT\n Country,\n Rank"}
+ *
+ * We'd like to return the old value if the new value is empty.
+ *
+ * @param initValue
+ */
+export default function useNonEmptyState<T>(initValue: T | (() => T)) {
+    const [state, _setState] = useState<T>(initValue);
+    const setState = useCallback((newValOrFunc: T | ((old: T) => T)) => {
+        _setState((oldVal) => {
+            const newVal =
+                typeof newValOrFunc === 'function'
+                    ? (newValOrFunc as (old: T) => T)(oldVal)
+                    : newValOrFunc;
+            return newVal || oldVal;
+        });
+    }, []);
+    return [state, setState] as const;
+}
diff --git a/querybook/webapp/lib/stream.ts b/querybook/webapp/lib/stream.ts
deleted file mode 100644
index 1435b0cf5..000000000
--- a/querybook/webapp/lib/stream.ts
+++ /dev/null
@@ -1,139 +0,0 @@
-/**
- * This is a parser for the delta stream from AI assistant streaming. The stream format is as follows:
- *
- * ```
- * some data
- * <@key1@>
- * value1
- * <@key2@>
- * value2
- * ```
- *
- * Key names are wrapped in <@ and @>. The parser will parse the stream into a JSON object:
- * ```
- * {
- *    key1: 'value1',
- *    key2: 'value2',
- *    data: 'some data'
- * }
- * ```
- *
- * "some data" and key/value pairs are all optional. E.g.
- * - Without any key/value pairs, the stream will be parsed into: { data: 'some data' }
- * - Without any data before the first key/value pair, the stream will be parsed into: { key1: 'value1', key2: 'value2' }
- *
- * As it is a streaming parser, it will parse the stream incrementally. E.g. if a partial stream is:
- *    some data
- *    <@key
- * The parser will parse the stream into: { data: 'some data' }. A partial key will not be put into the result.
- */
-export class DeltaStreamParser {
-    private _buffer: string;
-    private _result: { [key: string]: string };
-    private _currentKey: string;
-    private _currentValue: string;
-    private _isPartialKey: boolean;
-
-    public constructor() {
-        this.reset();
-    }
-
-    public get result() {
-        // make a copy of the result to avoid modifying the original result by the caller
-        return { ...this._result };
-    }
-
-    public reset() {
-        this._buffer = '';
-        this._result = {};
-        this._currentKey = 'data';
-        this._currentValue = '';
-        this._isPartialKey = false;
-    }
-
-    public parse(delta: string) {
-        this._buffer += delta;
-        // This is to make sure we always have complete <@ and @> in the buffer
-        if (
-            this._buffer.length < 2 ||
-            this._buffer.endsWith('<') ||
-            this._buffer.endsWith('@')
-        ) {
-            return;
-        }
-
-        let i = 0;
-        while (i < this._buffer.length - 1) {
-            const nextTwoChars = this._buffer.slice(i, i + 2);
-            if (nextTwoChars === '<@') {
-                this._result[this._currentKey] = this._currentValue.trimStart();
-                this._currentKey = '';
-                this._currentValue = '';
-                this._isPartialKey = true;
-                i += 1; // skip the next two chars
-            } else if (this._isPartialKey && nextTwoChars === '@>') {
-                this._isPartialKey = false;
-                i += 1; // skip the next two chars
-            } else {
-                if (this._isPartialKey) {
-                    this._currentKey += this._buffer[i];
-                } else {
-                    this._currentValue += this._buffer[i];
-                }
-            }
-            i += 1;
-        }
-
-        // handle the last char
-        if (i < this._buffer.length) {
-            if (this._isPartialKey) {
-                this._currentKey += this._buffer[i];
-            } else {
-                this._currentValue += this._buffer[i];
-            }
-        }
-
-        if (!this._isPartialKey) {
-            this._result[this._currentKey] = this._currentValue.trimStart();
-        }
-
-        this._buffer = '';
-    }
-
-    public close() {
-        // flush the buffer if the stream has ended
-        if (this._buffer.length) {
-            if (!this._isPartialKey) {
-                this._currentValue += this._buffer;
-                this._result[this._currentKey] = this._currentValue.trimStart();
-                this._buffer = '';
-            }
-        }
-    }
-}
-
-/**
- * Trim the title of a query to remove the quotes and trailing period
- *
- * e.g.
- * "some title" => some title
- * "some title." => some title
- * some title. => some title
- */
-export function trimQueryTitle(title: string | null | undefined) {
-    return title
-        ?.replace(/^["']|["']$/g, '')
-        .replace(/\.$/, '')
-        .trim();
-}
-
-/**
- * Trim the SQL query to remove the wraping ```
- *
- * e.g.
- * ```\nsome query``` => some query
- * ```sql\nsome query``` => some query
- */
-export function trimSQLQuery(query: string | null | undefined) {
-    return query?.replace(/^```(sql)?|```$/g, '').trim();
-}
diff --git a/requirements/ai/langchain.txt b/requirements/ai/langchain.txt
index 735e5e1ef..79045f2a6 100644
--- a/requirements/ai/langchain.txt
+++ b/requirements/ai/langchain.txt
@@ -1,2 +1,3 @@
-langchain[openai]==0.0.266
+langchain==0.1.6
+langchain-openai==0.0.5
 opensearch-py==2.3.0
diff --git a/requirements/base.txt b/requirements/base.txt
index b5975d040..3c39b64bc 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -39,7 +39,7 @@ markdown2
 
 # Utils
 pandas==1.3.5
-typing-extensions==3.10.0.0
+typing-extensions==4.9.0
 setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
 numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
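Finally, a note on why the webapp gained `useNonEmptyState`: with JSON-mode streaming, the parser re-parses the accumulated text on every chunk, so a field can be empty in one frame and populated again in the next (exactly the sequence shown in the hook's doc comment). A pure-Python illustration of the client-side keep-last-non-empty rule; the frame values below are made up, not captured output:

```python
# Simulated streamed frames for a SQL_FIX response. Each frame is a
# best-effort parse of the partial JSON, so fields may transiently be empty.
frames = [
    {},
    {"explanation": ""},
    {"explanation": "The column name is wrong."},
    {"explanation": "The column name is wrong.", "fixed_query": ""},
    {"explanation": "The column name is wrong.", "fixed_query": "SELECT id FROM users"},
]

latest: dict = {}
for frame in frames:
    for key, value in frame.items():
        if value:  # drop empty updates, mirroring useNonEmptyState
            latest[key] = value

print(latest)
# {'explanation': 'The column name is wrong.', 'fixed_query': 'SELECT id FROM users'}
```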