Skip to content

Commit

Permalink
Update test_execution.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hieuphmm authored Sep 1, 2024
1 parent 9f1ba54 commit 8e274ec
Showing 1 changed file with 59 additions and 45 deletions.
104 changes: 59 additions & 45 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,88 @@
import subprocess
import unittest
import logging

# Set up logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class TestTasks(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestTasks, self).__init__(*args, **kwargs)
self.model_name = "Qwen/Qwen2-0.5B-Instruct"
self.ptemplate = "chatglm"
self.wrapper_type = "vllm"
self.lang = "vi" # Set the lang argument to "vi"
self.seed = 42 # Set the seed to 42
self.smoke_test = True # Set the smoke_test argument to True
"""Test various tasks using the melt command."""

def setUp(self):
"""Set up the test environment."""
self.config = {
"model_name": "Qwen/Qwen2-0.5B-Instruct",
"ptemplate": "chatglm",
"wrapper_type": "vllm",
"lang": "vi",
"seed": 42,
"smoke_test": True
}

def run_melt_command(self, dataset_name):
result = subprocess.run(["melt", "--wtype", self.wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True)
self.assertEqual(result.returncode, 0)
"""Run the melt command and assert success.
Args:
dataset_name (str): The name of the dataset to test.
"""
cmd_args = [
"melt", "--wtype", self.config["wrapper_type"],
"--model_name", self.config["model_name"],
"--dataset_name", dataset_name,
"--ptemplate", self.config["ptemplate"],
"--lang", self.config["lang"],
"--seed", str(self.config["seed"]),
"--smoke_test", str(self.config["smoke_test"])
]

result = subprocess.run(cmd_args, capture_output=True, text=True)
self.assertEqual(result.returncode, 0,
"Command failed for dataset %s with output: %s\n%s" %
(dataset_name, result.stdout, result.stderr))

def test_sentiment_analysis(self):
# Test sentiment analysis task
dataset_name = "UIT-VSFC"
self.run_melt_command(dataset_name)
"""Test sentiment analysis task."""
self.run_melt_command("UIT-VSFC")

def test_text_classification(self):
# Test text classification task
dataset_name = "UIT-VSMEC"
self.run_melt_command(dataset_name)
"""Test text classification task."""
self.run_melt_command("UIT-VSMEC")

def test_toxic_detection(self):
# Test toxic detection task
dataset_name = "ViHSD"
self.run_melt_command(dataset_name)

"""Test toxic detection task."""
self.run_melt_command("ViHSD")

def test_reasoning(self):
# Test reasoning task
dataset_name = "synthetic_natural_azr"
self.run_melt_command(dataset_name)
"""Test reasoning task."""
self.run_melt_command("synthetic_natural_azr")

def test_open_ended_knowledge(self):
# Test open-ended knowledge task
dataset_name = "zalo_e2eqa"
self.run_melt_command(dataset_name)
"""Test open-ended knowledge task."""
self.run_melt_command("zalo_e2eqa")

def test_multiple_choice_knowledge(self):
# Test multiple choice knowledge task
dataset_name = "ViMMRC"
self.run_melt_command(dataset_name)
"""Test multiple choice knowledge task."""
self.run_melt_command("ViMMRC")

def test_math(self):
# Test math task
dataset_name = "math_level1_azr"
self.run_melt_command(dataset_name)
"""Test math task."""
self.run_melt_command("math_level1_azr")

def test_translation(self):
# Test translation task
dataset_name = "opus100_envi"
self.run_melt_command(dataset_name)
"""Test translation task."""
self.run_melt_command("opus100_envi")

def test_summarization(self):
# Test summarization task
dataset_name = "wiki_lingua"
self.run_melt_command(dataset_name)
"""Test summarization task."""
self.run_melt_command("wiki_lingua")

def test_question_answering(self):
# Test question answering task
dataset_name = "xquad_xtreme"
self.run_melt_command(dataset_name)
"""Test question answering task."""
self.run_melt_command("xquad_xtreme")

def test_information_retrieval(self):
# Test information retrieval task
dataset_name = "mmarco"
self.run_melt_command(dataset_name)
"""Test information retrieval task."""
self.run_melt_command("mmarco")

if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit 8e274ec

Please sign in to comment.