diff --git a/tests/pipeline/test_re_text_classification.py b/tests/pipeline/test_re_text_classification.py index 35dfb82e..8eed463a 100644 --- a/tests/pipeline/test_re_text_classification.py +++ b/tests/pipeline/test_re_text_classification.py @@ -20,11 +20,15 @@ class ExampleDocument(TextDocument): @pytest.mark.slow @pytest.mark.parametrize("use_auto", [False, True]) -def test_re_text_classification(use_auto): +@pytest.mark.parametrize("half_precision_model", [False, True]) +@pytest.mark.parametrize("half_precision_ops", [False, True]) +def test_re_text_classification(use_auto, half_precision_model, half_precision_ops): model_name_or_path = "pie/example-re-textclf-tacred" if use_auto: pipeline = AutoPipeline.from_pretrained( - model_name_or_path, taskmodule_kwargs={"create_relation_candidates": True} + model_name_or_path, + taskmodule_kwargs={"create_relation_candidates": True}, + half_precision_model=half_precision_model, ) else: re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained( @@ -32,7 +36,12 @@ def test_re_text_classification(use_auto): create_relation_candidates=True, ) re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path) - pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1) + pipeline = Pipeline( + model=re_model, + taskmodule=re_taskmodule, + device=-1, + half_precision_model=half_precision_model, + ) assert pipeline.taskmodule.is_from_pretrained assert pipeline.model.is_from_pretrained @@ -44,7 +53,7 @@ def test_re_text_classification(use_auto): for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]: document.entities.append(LabeledSpan(start=start, end=end, label=label)) - pipeline(document, batch_size=2) + pipeline(document, batch_size=2, half_precision_ops=half_precision_ops) relations: Sequence[BinaryRelation] = document["relations"].predictions assert len(relations) == 3