Skip to content

Commit

Permalink
add test cases for parameters half_precision_model and half_precision…
Browse files Browse the repository at this point in the history
…_ops to test_re_text_classification
  • Loading branch information
ArneBinder committed Jan 13, 2025
1 parent 4b78a86 commit 7e11234
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/pipeline/test_re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,28 @@ class ExampleDocument(TextDocument):

@pytest.mark.slow
@pytest.mark.parametrize("use_auto", [False, True])
def test_re_text_classification(use_auto):
@pytest.mark.parametrize("half_precision_model", [False, True])
@pytest.mark.parametrize("half_precision_ops", [False, True])
def test_re_text_classification(use_auto, half_precision_model, half_precision_ops):
model_name_or_path = "pie/example-re-textclf-tacred"
if use_auto:
pipeline = AutoPipeline.from_pretrained(
model_name_or_path, taskmodule_kwargs={"create_relation_candidates": True}
model_name_or_path,
taskmodule_kwargs={"create_relation_candidates": True},
half_precision_model=half_precision_model,
)
else:
re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(
model_name_or_path,
create_relation_candidates=True,
)
re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
pipeline = Pipeline(
model=re_model,
taskmodule=re_taskmodule,
device=-1,
half_precision_model=half_precision_model,
)
assert pipeline.taskmodule.is_from_pretrained
assert pipeline.model.is_from_pretrained

Expand All @@ -44,7 +53,7 @@ def test_re_text_classification(use_auto):
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
document.entities.append(LabeledSpan(start=start, end=end, label=label))

pipeline(document, batch_size=2)
pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
relations: Sequence[BinaryRelation] = document["relations"].predictions
assert len(relations) == 3

Expand Down

0 comments on commit 7e11234

Please sign in to comment.