From 941ddaf2a74c775011013d9277f22ec7758f1d51 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 16:58:39 +0800 Subject: [PATCH 1/9] Upgrade to latest vllm --- .github/workflows/main.yml | 9 +++++---- requirements.txt | 3 ++- run_model_cot.sh | 1 - utils/api_server.py | 21 ++++++++++++++++++--- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1d2ceca..a301086 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,20 +6,21 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: psf/black@stable test: runs-on: ubuntu-latest needs: lint steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: '3.11' + python-version: '3.10' cache: 'pip' - name: Install pip dependencies run: | + pip install --upgrade pip setuptools pip install -r requirements.txt pip install pytest - name: Download spaCy model diff --git a/requirements.txt b/requirements.txt index 982f6b0..f80fb4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ argparse func_timeout mistralai mysql-connector-python +numpy openai>=1.1.0 pandas pandas-gbq @@ -15,7 +16,7 @@ sentence-transformers snowflake-connector-python spacy sqlalchemy -tiktoken==0.7.0 +tiktoken together torch tqdm diff --git a/run_model_cot.sh b/run_model_cot.sh index 35d62b5..3c09d67 100755 --- a/run_model_cot.sh +++ b/run_model_cot.sh @@ -49,7 +49,6 @@ for model_name in "${model_names[@]}"; do --api_url "http://localhost:${PORT}/generate" \ --api_type "vllm" \ -p 10 \ - --cot_table_alias "prealias" \ --logprobs # finally, kill the api server pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" diff --git a/utils/api_server.py b/utils/api_server.py index ea1d009..a2d21a5 100644 --- a/utils/api_server.py +++ b/utils/api_server.py @@ -55,17 +55,32 @@ async def generate(request: Request) -> Response: sql_lora_path = request_dict.pop("sql_lora_path", None) request_dict.pop("sql_lora_name", None) lora_request = ( - LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None + LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path) + if sql_lora_path + else None ) + if vllm_version >= "0.6.2": + # remove use_beam_search if present as it's no longer supported + # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2 + if "use_beam_search" in request_dict: + request_dict.pop("use_beam_search") sampling_params = SamplingParams(**request_dict) request_id = random_uuid() tokenizer = await engine.get_tokenizer() prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False) - # print(f"prompt_token_ids: {prompt_token_ids}") if prompt_token_ids[0] != tokenizer.bos_token_id: prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids - if vllm_version >= "0.4.2": + if vllm_version >= "0.6.3": + from vllm import TokensPrompt + + results_generator = engine.generate( + prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + ) + elif vllm_version >= "0.4.2": results_generator = engine.generate( inputs={"prompt_token_ids": prompt_token_ids}, sampling_params=sampling_params, From 272452b43261a488666a967dff3ce3263361047f Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:07:01 +0800 Subject: [PATCH 2/9] pin numpy/pandas --- requirements.txt | 
4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index f80fb4e..5bcbb7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,9 @@ argparse func_timeout mistralai mysql-connector-python -numpy +numpy==2.1.2 openai>=1.1.0 -pandas +pandas==2.2.3 pandas-gbq peft psycopg2-binary From 30215a2a791a6f1463ae2406fecadd1d67ad605a Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:11:34 +0800 Subject: [PATCH 3/9] pin vllm/torch --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5bcbb7a..31360ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,9 +18,9 @@ spacy sqlalchemy tiktoken together -torch +torch==2.4.0 tqdm transformers sqlparse sqlglot -vllm; sys_platform != 'darwin' +vllm==0.6.3.post1; sys_platform != 'darwin' From 748ef1ebdea478c643448b79aa1f81c95abcdee5 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:17:02 +0800 Subject: [PATCH 4/9] separate requirements_test.txt --- .github/workflows/main.yml | 2 +- requirements_test.txt | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 requirements_test.txt diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a301086..e1c5488 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: - name: Install pip dependencies run: | pip install --upgrade pip setuptools - pip install -r requirements.txt + pip install -r requirements_test.txt pip install pytest - name: Download spaCy model run: python -m spacy download en_core_web_sm diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000..7181382 --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1,11 @@ +func_timeout +pandas +numpy +openai +sentence_transformers +spacy +sqlalchemy +sqlglot +sqlite3 +torch +tqdm \ No newline at end of file From 12208afb3c838d3da27ec69705ec53ba2d487358 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:19:15 +0800 Subject: [PATCH 5/9] fix sqlite3 --- requirements_test.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements_test.txt b/requirements_test.txt index 7181382..c426803 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,11 +1,11 @@ func_timeout -pandas numpy openai +pandas +pysqlite3 sentence_transformers spacy sqlalchemy sqlglot -sqlite3 torch tqdm \ No newline at end of file From a34f12ba332519c5809c510bff4ef3d1f20780e8 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:28:17 +0800 Subject: [PATCH 6/9] fix tests add more deps --- requirements_test.txt | 2 ++ tests/test_utils_pruning.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements_test.txt b/requirements_test.txt index c426803..7723466 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -2,8 +2,10 @@ func_timeout numpy openai pandas +psycopg2-binary pysqlite3 sentence_transformers +snowflake-connector-python spacy sqlalchemy sqlglot diff --git a/tests/test_utils_pruning.py b/tests/test_utils_pruning.py index 63cb564..6f8c0c9 100644 --- a/tests/test_utils_pruning.py +++ b/tests/test_utils_pruning.py @@ -80,7 +80,7 @@ def test_metadata_diff_coldesc(): def test_get_md_emb_no_shuffle(test_metadata): column_emb, column_csv, column_ner, column_join = test_metadata question = "How many flights start from Los Angeles Airport (LAX)?" 
- assert get_entity_types(question) == {"GPE", "ORG"} + assert get_entity_types(question) == {"FAC"} k = 3 threshold = 0.0 @@ -124,7 +124,7 @@ def test_get_md_emb_no_shuffle(test_metadata): def test_get_md_emb_shuffle(test_metadata): column_emb, column_csv, column_ner, column_join = test_metadata question = "How many flights start from Los Angeles Airport (LAX)?" - assert get_entity_types(question) == {"GPE", "ORG"} + assert get_entity_types(question) == {"FAC"} k = 3 threshold = 0.0 @@ -190,7 +190,7 @@ def test_get_md_emb_sql_emb_empty(test_metadata): def test_get_md_emb_coldesc(test_metadata_diff_coldesc): column_emb, column_csv, column_ner, column_join = test_metadata_diff_coldesc question = "How many flights start from Los Angeles Airport (LAX)?" - assert get_entity_types(question) == {"GPE", "ORG"} + assert get_entity_types(question) == {"FAC"} k = 3 threshold = 0.0 From ff37ab1ce3a0a4494181133047ae586ca0b148f1 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 1 Nov 2024 17:40:52 +0800 Subject: [PATCH 7/9] add FAC to column_ner --- tests/test_utils_pruning.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_utils_pruning.py b/tests/test_utils_pruning.py index 6f8c0c9..d8ab29e 100644 --- a/tests/test_utils_pruning.py +++ b/tests/test_utils_pruning.py @@ -33,6 +33,11 @@ def test_metadata(): "airport.airport_name,text,name of airport", "flight.airport_name,text,name of the airport", ], + "FAC": [ + "country.name,text,country name", + "airport.airport_name,text,name of airport", + "flight.airport_name,text,name of the airport", + ], "PERSON": ["flight.pilot_name,text,name of the pilot"], } column_join = {("airport", "country"): [("airport.country_id", "country.id")]} From a986ba46c2af2a54225e058154e077487f0f66cb Mon Sep 17 00:00:00 2001 From: JP Date: Fri, 1 Nov 2024 10:01:04 +0000 Subject: [PATCH 8/9] pin spacy --- requirements_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_test.txt b/requirements_test.txt index 7723466..8aeada3 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -6,7 +6,7 @@ psycopg2-binary pysqlite3 sentence_transformers snowflake-connector-python -spacy +spacy==3.7.2 sqlalchemy sqlglot torch From 1c2e240118f4de45adb5b4faa6df1cd689c1ea85 Mon Sep 17 00:00:00 2001 From: JP Date: Fri, 1 Nov 2024 10:02:39 +0000 Subject: [PATCH 9/9] fix entities --- tests/test_utils_pruning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils_pruning.py b/tests/test_utils_pruning.py index d8ab29e..468548d 100644 --- a/tests/test_utils_pruning.py +++ b/tests/test_utils_pruning.py @@ -85,7 +85,7 @@ def test_metadata_diff_coldesc(): def test_get_md_emb_no_shuffle(test_metadata): column_emb, column_csv, column_ner, column_join = test_metadata question = "How many flights start from Los Angeles Airport (LAX)?" - assert get_entity_types(question) == {"FAC"} + assert get_entity_types(question) == {"GPE", "ORG"} k = 3 threshold = 0.0 @@ -129,7 +129,7 @@ def test_get_md_emb_no_shuffle(test_metadata): def test_get_md_emb_shuffle(test_metadata): column_emb, column_csv, column_ner, column_join = test_metadata question = "How many flights start from Los Angeles Airport (LAX)?" 
- assert get_entity_types(question) == {"FAC"} + assert get_entity_types(question) == {"GPE", "ORG"} k = 3 threshold = 0.0 @@ -195,7 +195,7 @@ def test_get_md_emb_sql_emb_empty(test_metadata): def test_get_md_emb_coldesc(test_metadata_diff_coldesc): column_emb, column_csv, column_ner, column_join = test_metadata_diff_coldesc question = "How many flights start from Los Angeles Airport (LAX)?" - assert get_entity_types(question) == {"FAC"} + assert get_entity_types(question) == {"GPE", "ORG"} k = 3 threshold = 0.0
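
For reference, a minimal client-side sketch of how a call to the upgraded /generate endpoint might look once this series is applied. The URL mirrors the --api_url used in run_model_cot.sh; the port, the sampling fields, and the response handling are illustrative assumptions -- everything other than "prompt" and "sql_lora_path" is forwarded straight into vllm.SamplingParams on the server, so the exact field set depends on the vLLM version in use.

    # Illustrative only -- not part of the patch series above.
    import requests

    payload = {
        "prompt": "SELECT",                 # tokenized server-side; BOS is prepended if missing
        "temperature": 0.0,                 # assumed SamplingParams fields, passed through unchanged
        "max_tokens": 512,
        "logprobs": 1,
        # "use_beam_search": True,          # dropped by the server on vllm >= 0.6.2
        # "sql_lora_path": "/path/to/lora", # optional: the server wraps this in a LoRARequest
    }

    # Placeholder port; run_model_cot.sh starts the server on ${PORT}.
    resp = requests.post("http://localhost:8080/generate", json=payload, timeout=600)
    resp.raise_for_status()
    print(resp.json())                      # response shape depends on the server's Response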