-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
59 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,88 @@ | ||
import subprocess | ||
import unittest | ||
import logging | ||
|
||
# Set up logging configuration | ||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | ||
|
||
class TestTasks(unittest.TestCase): | ||
def __init__(self, *args, **kwargs): | ||
super(TestTasks, self).__init__(*args, **kwargs) | ||
self.model_name = "Qwen/Qwen2-0.5B-Instruct" | ||
self.ptemplate = "chatglm" | ||
self.wrapper_type = "vllm" | ||
self.lang = "vi" # Set the lang argument to "vi" | ||
self.seed = 42 # Set the seed to 42 | ||
self.smoke_test = True # Set the smoke_test argument to True | ||
"""Test various tasks using the melt command.""" | ||
|
||
def setUp(self): | ||
"""Set up the test environment.""" | ||
self.config = { | ||
"model_name": "Qwen/Qwen2-0.5B-Instruct", | ||
"ptemplate": "chatglm", | ||
"wrapper_type": "vllm", | ||
"lang": "vi", | ||
"seed": 42, | ||
"smoke_test": True | ||
} | ||
|
||
def run_melt_command(self, dataset_name): | ||
result = subprocess.run(["melt", "--wtype", self.wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True) | ||
self.assertEqual(result.returncode, 0) | ||
"""Run the melt command and assert success. | ||
Args: | ||
dataset_name (str): The name of the dataset to test. | ||
""" | ||
cmd_args = [ | ||
"melt", "--wtype", self.config["wrapper_type"], | ||
"--model_name", self.config["model_name"], | ||
"--dataset_name", dataset_name, | ||
"--ptemplate", self.config["ptemplate"], | ||
"--lang", self.config["lang"], | ||
"--seed", str(self.config["seed"]), | ||
"--smoke_test", str(self.config["smoke_test"]) | ||
] | ||
|
||
result = subprocess.run(cmd_args, capture_output=True, text=True) | ||
self.assertEqual(result.returncode, 0, | ||
"Command failed for dataset %s with output: %s\n%s" % | ||
(dataset_name, result.stdout, result.stderr)) | ||
|
||
def test_sentiment_analysis(self): | ||
# Test sentiment analysis task | ||
dataset_name = "UIT-VSFC" | ||
self.run_melt_command(dataset_name) | ||
"""Test sentiment analysis task.""" | ||
self.run_melt_command("UIT-VSFC") | ||
|
||
def test_text_classification(self): | ||
# Test text classification task | ||
dataset_name = "UIT-VSMEC" | ||
self.run_melt_command(dataset_name) | ||
"""Test text classification task.""" | ||
self.run_melt_command("UIT-VSMEC") | ||
|
||
def test_toxic_detection(self): | ||
# Test toxic detection task | ||
dataset_name = "ViHSD" | ||
self.run_melt_command(dataset_name) | ||
|
||
"""Test toxic detection task.""" | ||
self.run_melt_command("ViHSD") | ||
|
||
def test_reasoning(self): | ||
# Test reasoning task | ||
dataset_name = "synthetic_natural_azr" | ||
self.run_melt_command(dataset_name) | ||
"""Test reasoning task.""" | ||
self.run_melt_command("synthetic_natural_azr") | ||
|
||
def test_open_ended_knowledge(self): | ||
# Test open-ended knowledge task | ||
dataset_name = "zalo_e2eqa" | ||
self.run_melt_command(dataset_name) | ||
"""Test open-ended knowledge task.""" | ||
self.run_melt_command("zalo_e2eqa") | ||
|
||
def test_multiple_choice_knowledge(self): | ||
# Test multiple choice knowledge task | ||
dataset_name = "ViMMRC" | ||
self.run_melt_command(dataset_name) | ||
"""Test multiple choice knowledge task.""" | ||
self.run_melt_command("ViMMRC") | ||
|
||
def test_math(self): | ||
# Test math task | ||
dataset_name = "math_level1_azr" | ||
self.run_melt_command(dataset_name) | ||
"""Test math task.""" | ||
self.run_melt_command("math_level1_azr") | ||
|
||
def test_translation(self): | ||
# Test translation task | ||
dataset_name = "opus100_envi" | ||
self.run_melt_command(dataset_name) | ||
"""Test translation task.""" | ||
self.run_melt_command("opus100_envi") | ||
|
||
def test_summarization(self): | ||
# Test summarization task | ||
dataset_name = "wiki_lingua" | ||
self.run_melt_command(dataset_name) | ||
"""Test summarization task.""" | ||
self.run_melt_command("wiki_lingua") | ||
|
||
def test_question_answering(self): | ||
# Test question answering task | ||
dataset_name = "xquad_xtreme" | ||
self.run_melt_command(dataset_name) | ||
"""Test question answering task.""" | ||
self.run_melt_command("xquad_xtreme") | ||
|
||
def test_information_retrieval(self): | ||
# Test information retrieval task | ||
dataset_name = "mmarco" | ||
self.run_melt_command(dataset_name) | ||
"""Test information retrieval task.""" | ||
self.run_melt_command("mmarco") | ||
|
||
if __name__ == '__main__': | ||
unittest.main() | ||
unittest.main() |