Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

half and mixed precision inference #442

Merged
merged 9 commits into from
Jan 16, 2025
Prev Previous commit
move get_autocast_dtype out of Pipeline class
ArneBinder committed Jan 14, 2025
commit ca3112da6c9ba0bd0a41a4e2e7e745b44500505c
23 changes: 11 additions & 12 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,16 @@
logger = logging.getLogger(__name__)


# TODO: use torch.get_autocast_dtype when available
def get_autocast_dtype(device_type: str):
    """Return the dtype to use for half-precision autocast on a device type.

    Args:
        device_type: The device type string, either ``"cuda"`` or ``"cpu"``.

    Returns:
        ``torch.float16`` for CUDA devices, ``torch.bfloat16`` for CPU.

    Raises:
        ValueError: If ``device_type`` is neither ``"cuda"`` nor ``"cpu"``.
    """
    half_precision_dtypes = {
        "cuda": torch.float16,
        "cpu": torch.bfloat16,
    }
    try:
        return half_precision_dtypes[device_type]
    except KeyError:
        raise ValueError(
            f"Unsupported device type for half precision autocast: {device_type}"
        ) from None


class Pipeline:
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
@@ -72,8 +82,7 @@ def __init__(
# reflected in typing of PyTorch.
self.model: PyTorchIEModel = model.to(self.device) # type: ignore
if half_precision_model:
# TODO: use torch.get_autocast_dtype(self.device.type) when available
self.model = self.model.to(dtype=self.get_autocast_dtype())
self.model = self.model.to(dtype=get_autocast_dtype(self.device.type))

self.call_count = 0
(
@@ -83,16 +92,6 @@ def __init__(
self._postprocess_params,
) = self._sanitize_parameters(**kwargs)

def get_autocast_dtype(self):
    """Return the dtype to use for half-precision autocast on this pipeline's device.

    NOTE(review): this method is removed by this commit in favor of a
    module-level function with the same logic.

    Returns:
        ``torch.float16`` when ``self.device.type`` is ``"cuda"``,
        ``torch.bfloat16`` when it is ``"cpu"``.

    Raises:
        ValueError: For any other device type.
    """
    if self.device.type == "cuda":
        return torch.float16
    elif self.device.type == "cpu":
        return torch.bfloat16
    else:
        raise ValueError(
            f"Unsupported device type for half precision autocast: {self.device.type}"
        )

def save_pretrained(self, save_directory: str):
"""
Save the pipeline's model and taskmodule.