diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 26985077..4617de00 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -25,6 +25,16 @@
 logger = logging.getLogger(__name__)
 
 
+# TODO: use torch.get_autocast_dtype when available
+def get_autocast_dtype(device_type: str):
+    if device_type == "cuda":
+        return torch.float16
+    elif device_type == "cpu":
+        return torch.bfloat16
+    else:
+        raise ValueError(f"Unsupported device type for half precision autocast: {device_type}")
+
+
 class Pipeline:
     """
     The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
@@ -46,6 +56,10 @@ class Pipeline:
            The device to run the pipeline on. This can be a CPU device (:obj:`"cpu"`), a GPU
            device (:obj:`"cuda"`) or a specific GPU device (:obj:`"cuda:X"`, where :obj:`X` is
            the index of the GPU).
+        half_precision_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to use half precision model. This can be set to :obj:`True` to reduce
+            the memory usage of the model. If set to :obj:`True`, the model will be cast to
+            :obj:`torch.float16` on supported devices.
     """
 
     default_input_names = None
@@ -55,6 +69,7 @@ def __init__(
         self,
         model: PyTorchIEModel,
         taskmodule: TaskModule,
         device: Union[int, str] = "cpu",
+        half_precision_model: bool = False,
         **kwargs,
     ):
         self.taskmodule = taskmodule
@@ -66,6 +81,8 @@
         # Module.to() returns just self, but moved to the device. This is not correctly
         # reflected in typing of PyTorch.
         self.model: PyTorchIEModel = model.to(self.device)  # type: ignore
+        if half_precision_model:
+            self.model = self.model.to(dtype=get_autocast_dtype(self.device.type))
 
         self.call_count = 0
         (
@@ -190,7 +207,7 @@
                 preprocess_parameters[p_name] = pipeline_parameters[p_name]
 
         # set forward parameters
-        for p_name in ["show_progress_bar", "fast_dev_run"]:
+        for p_name in ["show_progress_bar", "fast_dev_run", "half_precision_ops"]:
             if p_name in pipeline_parameters:
                 forward_parameters[p_name] = pipeline_parameters[p_name]
 
@@ -213,7 +230,7 @@ def preprocess(
         **preprocess_parameters: Dict,
     ) -> Sequence[TaskEncoding]:
         """
-        Preprocess will take the `input_` of a specific pipeline and return a dictionnary of everything necessary for
+        Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
         `_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
         """
 
@@ -231,7 +248,7 @@ def _forward(
         self, input_tensors: Tuple[Dict[str, Tensor], Any, Any, Any], **forward_parameters: Dict
     ) -> Dict:
         """
-        _forward will receive the prepared dictionnary from `preprocess` and run it on the model. This method might
+        _forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
         involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
         and `postprocess` to exist, so that the hot path, this method generally can run as fast as possible.
 
@@ -308,7 +325,7 @@ def __call__(
 
         1. Encode the documents
         2. Run the model forward pass(es) on the encodings
-        3. Combine the model outputs with the inputs encodings and integrate them back into the documents
+        3. Combine the model outputs with the input encodings and integrate them back into the documents
 
         Args:
            documents (:obj:`Union[Document, Sequence[Document]]`): The documents to process. If a single document is
@@ -320,6 +337,9 @@
                during inference.
            fast_dev_run (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to run a fast development
                run. If set to :obj:`True`, only the first two model inputs will be processed.
+           half_precision_ops (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to use half
+               precision operations. If set to :obj:`True`, the model will be run with half precision operations
+               via :obj:`torch.autocast`.
            batch_size (:obj:`int`, `optional`, defaults to :obj:`1`): The batch size to use for the dataloader.
                If not provided, a batch size of 1 will be used.
            num_workers (:obj:`int`, `optional`, defaults to :obj:`8`): The number of workers to use for the dataloader.
@@ -378,12 +398,16 @@
         dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
 
         show_progress_bar = forward_params.pop("show_progress_bar", False)
+        half_precision_ops = forward_params.pop("half_precision_ops", False)
         model_outputs: List = []
         with torch.no_grad():
-            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
-                output = self.forward(batch, **forward_params)
-                processed_output = self.taskmodule.unbatch_output(output)
-                model_outputs.extend(processed_output)
+            with torch.autocast(device_type=self.device.type, enabled=half_precision_ops):
+                for batch in tqdm.tqdm(
+                    dataloader, desc="inference", disable=not show_progress_bar
+                ):
+                    output = self.forward(batch, **forward_params)
+                    processed_output = self.taskmodule.unbatch_output(output)
+                    model_outputs.extend(processed_output)
 
         assert len(model_inputs) == len(
             model_outputs
diff --git a/src/pytorch_ie/taskmodules/simple_transformer_text_classification.py b/src/pytorch_ie/taskmodules/simple_transformer_text_classification.py
index 171d6532..734e0f85 100644
--- a/src/pytorch_ie/taskmodules/simple_transformer_text_classification.py
+++ b/src/pytorch_ie/taskmodules/simple_transformer_text_classification.py
@@ -182,7 +182,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputTy
         logits = model_output["logits"]
 
         # convert the logits to "probabilities"
-        probabilities = logits.softmax(dim=-1).detach().cpu().numpy()
+        probabilities = logits.softmax(dim=-1).detach().cpu().float().numpy()
 
         # get the max class index per example
         max_label_ids = np.argmax(probabilities, axis=-1)
diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
index 99cdacda..36123aa8 100644
--- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
+++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py
@@ -556,7 +556,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputTy
         logits = model_output["logits"]
 
         output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
-        output_label_probs = output_label_probs.detach().cpu().numpy()
+        output_label_probs = output_label_probs.detach().cpu().float().numpy()
 
         unbatched_output = []
         if self.multi_label:
diff --git a/src/pytorch_ie/taskmodules/transformer_span_classification.py b/src/pytorch_ie/taskmodules/transformer_span_classification.py
index bbeff94e..74671b93 100644
--- a/src/pytorch_ie/taskmodules/transformer_span_classification.py
+++ b/src/pytorch_ie/taskmodules/transformer_span_classification.py
@@ -227,7 +227,7 @@ def encode_target(
 
     def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
         logits = model_output["logits"]
-        probs = F.softmax(logits, dim=-1).detach().cpu().numpy()
+        probs = F.softmax(logits, dim=-1).detach().cpu().float().numpy()
         label_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()
 
         start_indices = model_output["start_indices"].detach().cpu().numpy()
diff --git a/src/pytorch_ie/taskmodules/transformer_text_classification.py b/src/pytorch_ie/taskmodules/transformer_text_classification.py
index 8b624aaf..ead014d7 100644
--- a/src/pytorch_ie/taskmodules/transformer_text_classification.py
+++ b/src/pytorch_ie/taskmodules/transformer_text_classification.py
@@ -196,7 +196,7 @@ def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputTy
         logits = model_output["logits"]
 
         output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
-        output_label_probs = output_label_probs.detach().cpu().numpy()
+        output_label_probs = output_label_probs.detach().cpu().float().numpy()
 
         if self.multi_label:
             raise NotImplementedError()
diff --git a/src/pytorch_ie/taskmodules/transformer_token_classification.py b/src/pytorch_ie/taskmodules/transformer_token_classification.py
index 46709394..55a5649f 100644
--- a/src/pytorch_ie/taskmodules/transformer_token_classification.py
+++ b/src/pytorch_ie/taskmodules/transformer_token_classification.py
@@ -290,7 +290,7 @@ def encode_target(
 
     def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
         logits = model_output["logits"]
-        probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
+        probabilities = F.softmax(logits, dim=-1).detach().cpu().float().numpy()
         indices = torch.argmax(logits, dim=-1).detach().cpu().numpy()
         tags = [[self.id_to_label[e] for e in b] for b in indices]
         return [{"tags": t, "probabilities": p} for t, p in zip(tags, probabilities)]
diff --git a/tests/pipeline/test_re_text_classification.py b/tests/pipeline/test_re_text_classification.py
index 35dfb82e..8eed463a 100644
--- a/tests/pipeline/test_re_text_classification.py
+++ b/tests/pipeline/test_re_text_classification.py
@@ -20,11 +20,15 @@ class ExampleDocument(TextDocument):
 
 @pytest.mark.slow
 @pytest.mark.parametrize("use_auto", [False, True])
-def test_re_text_classification(use_auto):
+@pytest.mark.parametrize("half_precision_model", [False, True])
+@pytest.mark.parametrize("half_precision_ops", [False, True])
+def test_re_text_classification(use_auto, half_precision_model, half_precision_ops):
     model_name_or_path = "pie/example-re-textclf-tacred"
     if use_auto:
         pipeline = AutoPipeline.from_pretrained(
-            model_name_or_path, taskmodule_kwargs={"create_relation_candidates": True}
+            model_name_or_path,
+            taskmodule_kwargs={"create_relation_candidates": True},
+            half_precision_model=half_precision_model,
         )
     else:
         re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(
@@ -32,7 +36,12 @@
             create_relation_candidates=True,
         )
         re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
-        pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
+        pipeline = Pipeline(
+            model=re_model,
+            taskmodule=re_taskmodule,
+            device=-1,
+            half_precision_model=half_precision_model,
+        )
 
     assert pipeline.taskmodule.is_from_pretrained
     assert pipeline.model.is_from_pretrained
@@ -44,7 +53,7 @@
     for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
         document.entities.append(LabeledSpan(start=start, end=end, label=label))
 
-    pipeline(document, batch_size=2)
+    pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
 
     relations: Sequence[BinaryRelation] = document["relations"].predictions
     assert len(relations) == 3
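
For orientation, a minimal usage sketch of the two options added here, closely following the test above. It is not part of the diff; the import paths and the example document text are assumptions based on the pytorch-ie layout and the entity offsets used in the test, so treat them as illustrative rather than authoritative:

```python
from dataclasses import dataclass

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument


@dataclass
class ExampleDocument(TextDocument):
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


# half_precision_model=True casts the model weights once at load time:
# torch.float16 on CUDA, torch.bfloat16 on CPU (see get_autocast_dtype above).
pipeline = AutoPipeline.from_pretrained(
    "pie/example-re-textclf-tacred",
    taskmodule_kwargs={"create_relation_candidates": True},
    half_precision_model=True,
)

# Example text whose character offsets match the spans used in the test (assumed here).
document = ExampleDocument(
    text="“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, "
    "general partner at SOSV and managing director of IndieBio."
)
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
    document.entities.append(LabeledSpan(start=start, end=end, label=label))

# half_precision_ops=True is routed to the forward parameters by _sanitize_parameters
# and wraps the batched forward passes in torch.autocast.
pipeline(document, batch_size=2, half_precision_ops=True)

for relation in document["relations"].predictions:
    print(relation)
```

The two flags are independent: `half_precision_model` changes the stored parameter dtype once, while `half_precision_ops` only enables autocast around the forward passes, which is why the test parametrizes them separately and exercises all four combinations. The `.float()` calls added to the task modules' `unbatch_output` methods ensure that outputs which may now be float16 or bfloat16 convert cleanly to NumPy arrays.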