half and mixed precision inference (#442)
* implement half and mixed precision inference in Pipeline

* cast probabilities to float before converting to numpy

* add test cases for parameters half_precision_model and half_precision_ops to test_re_text_classification

* add docs for new params

* fix spellings in docs

* fix documentation

* fix documentation

* add todo

* move get_autocast_dtype out of Pipeline class
ArneBinder authored Jan 16, 2025
1 parent 600d4b7 commit 9ab36d8
Showing 7 changed files with 50 additions and 17 deletions.
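For orientation, here is a minimal sketch of how the two new options fit together. It is based on the test case at the bottom of this diff; the model name and call pattern come from that test, while the import path and the `document` object are assumptions rather than code from this commit.

```python
from pytorch_ie.auto import AutoPipeline  # import path assumed, not shown in this diff

# Load the example relation-extraction model used in the test below and cast its
# weights to the device's half-precision dtype (float16 on CUDA, bfloat16 on CPU).
pipeline = AutoPipeline.from_pretrained(
    "pie/example-re-textclf-tacred",
    taskmodule_kwargs={"create_relation_candidates": True},
    half_precision_model=True,
)

# `document` is assumed to be a document with annotated entities, as prepared in the test.
# half_precision_ops=True additionally runs the forward pass under torch.autocast.
pipeline(document, batch_size=2, half_precision_ops=True)
```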
40 changes: 32 additions & 8 deletions src/pytorch_ie/pipeline.py
@@ -25,6 +25,16 @@
logger = logging.getLogger(__name__)


# TODO: use torch.get_autocast_dtype when available
def get_autocast_dtype(device_type: str):
    if device_type == "cuda":
        return torch.float16
    elif device_type == "cpu":
        return torch.bfloat16
    else:
        raise ValueError(f"Unsupported device type for half precision autocast: {device_type}")


class Pipeline:
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
@@ -46,6 +56,10 @@ class Pipeline:
The device to run the pipeline on. This can be a CPU device (:obj:`"cpu"`), a GPU
device (:obj:`"cuda"`) or a specific GPU device (:obj:`"cuda:X"`, where :obj:`X`
is the index of the GPU).
half_precision_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use half precision model. This can be set to :obj:`True` to reduce
the memory usage of the model. If set to :obj:`True`, the model will be cast to
:obj:`torch.float16` on supported devices.
"""

default_input_names = None
@@ -55,6 +69,7 @@ def __init__(
model: PyTorchIEModel,
taskmodule: TaskModule,
device: Union[int, str] = "cpu",
half_precision_model: bool = False,
**kwargs,
):
self.taskmodule = taskmodule
@@ -66,6 +81,8 @@ def __init__(
# Module.to() returns just self, but moved to the device. This is not correctly
# reflected in typing of PyTorch.
self.model: PyTorchIEModel = model.to(self.device) # type: ignore
if half_precision_model:
    self.model = self.model.to(dtype=get_autocast_dtype(self.device.type))
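To illustrate the effect of the cast above (a sketch under the assumption that a loaded `model`/`taskmodule` pair is available, as in the test at the end of this diff): with `half_precision_model=True` the weights themselves are stored in the reduced-precision dtype of the target device.

```python
# Illustrative only; `model` and `taskmodule` are assumed to be a pretrained
# PyTorchIEModel and TaskModule, e.g. loaded as in the test file below.
pipeline = Pipeline(
    model=model,
    taskmodule=taskmodule,
    device="cpu",
    half_precision_model=True,
)
# On CPU the parameters are cast to bfloat16; on a CUDA device they would be float16.
print(next(pipeline.model.parameters()).dtype)  # torch.bfloat16
```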

self.call_count = 0
(
@@ -190,7 +207,7 @@ def _sanitize_parameters(
preprocess_parameters[p_name] = pipeline_parameters[p_name]

# set forward parameters
for p_name in ["show_progress_bar", "fast_dev_run"]:
for p_name in ["show_progress_bar", "fast_dev_run", "half_precision_ops"]:
if p_name in pipeline_parameters:
forward_parameters[p_name] = pipeline_parameters[p_name]

@@ -213,7 +230,7 @@ def preprocess(
**preprocess_parameters: Dict,
) -> Sequence[TaskEncoding]:
"""
Preprocess will take the `input_` of a specific pipeline and return a dictionnary of everything necessary for
Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
`_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
"""

@@ -231,7 +248,7 @@ def _forward(
self, input_tensors: Tuple[Dict[str, Tensor], Any, Any, Any], **forward_parameters: Dict
) -> Dict:
"""
_forward will receive the prepared dictionnary from `preprocess` and run it on the model. This method might
_forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
and `postprocess` to exist, so that the hot path, this method generally can run as fast as possible.
@@ -308,7 +325,7 @@ def __call__(
1. Encode the documents
2. Run the model forward pass(es) on the encodings
3. Combine the model outputs with the inputs encodings and integrate them back into the documents
3. Combine the model outputs with the input encodings and integrate them back into the documents
Args:
documents (:obj:`Union[Document, Sequence[Document]]`): The documents to process. If a single document is
Expand All @@ -320,6 +337,9 @@ def __call__(
during inference.
fast_dev_run (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to run a fast development
run. If set to :obj:`True`, only the first two model inputs will be processed.
half_precision_ops (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to use half
precision operations. If set to :obj:`True`, the model will be run with half precision operations
via :obj:`torch.autocast`.
batch_size (:obj:`int`, `optional`, defaults to :obj:`1`): The batch size to use for the dataloader. If not
provided, a batch size of 1 will be used.
num_workers (:obj:`int`, `optional`, defaults to :obj:`8`): The number of workers to use for the dataloader.
@@ -378,12 +398,16 @@ def __call__(
dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)

show_progress_bar = forward_params.pop("show_progress_bar", False)
half_precision_ops = forward_params.pop("half_precision_ops", False)
model_outputs: List = []
with torch.no_grad():
    for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
        output = self.forward(batch, **forward_params)
        processed_output = self.taskmodule.unbatch_output(output)
        model_outputs.extend(processed_output)
    with torch.autocast(device_type=self.device.type, enabled=half_precision_ops):
        for batch in tqdm.tqdm(
            dataloader, desc="inference", disable=not show_progress_bar
        ):
            output = self.forward(batch, **forward_params)
            processed_output = self.taskmodule.unbatch_output(output)
            model_outputs.extend(processed_output)
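As a standalone illustration of the autocast pattern used in the loop above (a toy module, not the pipeline's own code): the weights stay in their stored precision, while eligible operations inside the context run in the device's reduced-precision dtype when `enabled` is true.

```python
import torch

model = torch.nn.Linear(4, 2)  # toy stand-in for the model, weights remain float32
batch = torch.randn(8, 4)

with torch.no_grad():
    with torch.autocast(device_type="cpu", enabled=True):
        out = model(batch)

print(next(model.parameters()).dtype)  # torch.float32, the weights are untouched
print(out.dtype)                       # typically torch.bfloat16 under CPU autocast
```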

assert len(model_inputs) == len(
model_outputs
@@ -182,7 +182,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
logits = model_output["logits"]

# convert the logits to "probabilities"
probabilities = logits.softmax(dim=-1).detach().cpu().numpy()
probabilities = logits.softmax(dim=-1).detach().cpu().float().numpy()

# get the max class index per example
max_label_ids = np.argmax(probabilities, axis=-1)
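The added `.float()` call matters because NumPy has no bfloat16 dtype, so converting reduced-precision probabilities directly would fail; casting to float32 first keeps the conversion safe regardless of the precision the model ran in. A minimal reproduction, assuming nothing beyond PyTorch itself:

```python
import torch

probs = torch.tensor([0.25, 0.75], dtype=torch.bfloat16)
# probs.numpy() raises a TypeError, since NumPy has no bfloat16 dtype.
arr = probs.float().numpy()  # cast to float32 first, then convert
print(arr.dtype)  # float32
```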
@@ -556,7 +556,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
logits = model_output["logits"]

output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
output_label_probs = output_label_probs.detach().cpu().numpy()
output_label_probs = output_label_probs.detach().cpu().float().numpy()

unbatched_output = []
if self.multi_label:
@@ -227,7 +227,7 @@ def encode_target(

def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
logits = model_output["logits"]
probs = F.softmax(logits, dim=-1).detach().cpu().numpy()
probs = F.softmax(logits, dim=-1).detach().cpu().float().numpy()
label_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()

start_indices = model_output["start_indices"].detach().cpu().numpy()
@@ -196,7 +196,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
logits = model_output["logits"]

output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
output_label_probs = output_label_probs.detach().cpu().numpy()
output_label_probs = output_label_probs.detach().cpu().float().numpy()

if self.multi_label:
raise NotImplementedError()
@@ -290,7 +290,7 @@ def encode_target(

def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
logits = model_output["logits"]
probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
probabilities = F.softmax(logits, dim=-1).detach().cpu().float().numpy()
indices = torch.argmax(logits, dim=-1).detach().cpu().numpy()
tags = [[self.id_to_label[e] for e in b] for b in indices]
return [{"tags": t, "probabilities": p} for t, p in zip(tags, probabilities)]
17 changes: 13 additions & 4 deletions tests/pipeline/test_re_text_classification.py
@@ -20,19 +20,28 @@ class ExampleDocument(TextDocument):

@pytest.mark.slow
@pytest.mark.parametrize("use_auto", [False, True])
def test_re_text_classification(use_auto):
@pytest.mark.parametrize("half_precision_model", [False, True])
@pytest.mark.parametrize("half_precision_ops", [False, True])
def test_re_text_classification(use_auto, half_precision_model, half_precision_ops):
model_name_or_path = "pie/example-re-textclf-tacred"
if use_auto:
pipeline = AutoPipeline.from_pretrained(
model_name_or_path, taskmodule_kwargs={"create_relation_candidates": True}
model_name_or_path,
taskmodule_kwargs={"create_relation_candidates": True},
half_precision_model=half_precision_model,
)
else:
re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(
model_name_or_path,
create_relation_candidates=True,
)
re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
pipeline = Pipeline(
model=re_model,
taskmodule=re_taskmodule,
device=-1,
half_precision_model=half_precision_model,
)
assert pipeline.taskmodule.is_from_pretrained
assert pipeline.model.is_from_pretrained

@@ -44,7 +53,7 @@ def test_re_text_classification(use_auto):
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
document.entities.append(LabeledSpan(start=start, end=end, label=label))

pipeline(document, batch_size=2)
pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
relations: Sequence[BinaryRelation] = document["relations"].predictions
assert len(relations) == 3
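Since the two new flags are added as separate `parametrize` decorators on top of the existing `use_auto` one, the test now runs as a full 2 x 2 x 2 matrix; a small sketch of the resulting combinations (a plain loop for illustration, not part of the test):

```python
import itertools

# Eight parameter combinations, matching the three stacked parametrize decorators above.
for use_auto, half_model, half_ops in itertools.product([False, True], repeat=3):
    print(f"use_auto={use_auto}, half_precision_model={half_model}, half_precision_ops={half_ops}")
```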

