-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from Filimoa/torch-device
Global PyTorch Config
- Loading branch information
Showing
9 changed files
with
114 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
If you're using the ml dependencies, the config class manages the computational device settings for your project, | ||
|
||
### Config | ||
|
||
This is a simple wrapper around pytorch. Setting the device will fail if pytorch is not installed. | ||
|
||
```python | ||
import openparse | ||
|
||
openparse.config.set_device("cpu") | ||
``` | ||
|
||
Note if you're on apple silicon, setting this to `mps` runs significantly slower than on `cpu`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Literal | ||
|
||
|
||
TorchDevice = Literal["cuda", "cpu", "mps"] | ||
|
||
|
||
class Config: | ||
def __init__(self): | ||
self._device = "cpu" # Default to CPU | ||
self._torch_available = False | ||
self._cuda_available = False | ||
try: | ||
import torch | ||
|
||
self._torch_available = True | ||
if torch.cuda.is_available(): | ||
self._device = "cuda" | ||
self._cuda_available = True | ||
except ImportError: | ||
pass | ||
|
||
def set_device(self, device: TorchDevice) -> None: | ||
if not self._torch_available and device == "cuda": | ||
raise RuntimeError( | ||
"CUDA device requested but torch is not available. Have you installed ml dependencies?" | ||
) | ||
if not self._cuda_available and device == "cuda": | ||
raise RuntimeError("CUDA device requested but CUDA is not available") | ||
if device not in ["cuda", "cpu", "mps"]: | ||
raise ValueError("Device must be 'cuda', 'cpu' or 'mps'") | ||
self._device = device | ||
|
||
def get_device(self): | ||
if self._torch_available: | ||
import torch | ||
|
||
return torch.device(self._device) | ||
else: | ||
return self._device | ||
|
||
|
||
config = Config() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
OPEN_PARSE_VERSION = "0.5.1" | ||
|
||
|
||
def version_info() -> str: | ||
"""Return complete version information for OpenParse and its dependencies.""" | ||
import importlib.metadata as importlib_metadata | ||
import platform | ||
import sys | ||
from pathlib import Path | ||
|
||
python_version = sys.version.split()[0] | ||
operating_system = platform.system() | ||
os_version = platform.release() | ||
|
||
package_names = { | ||
"email-validator", | ||
"torch", | ||
"torchvision", | ||
"transformers", | ||
"tokenizers", | ||
"PyMuPDF", | ||
"pydantic", | ||
} | ||
related_packages = [] | ||
|
||
for dist in importlib_metadata.distributions(): | ||
name = dist.metadata["Name"] | ||
if name in package_names: | ||
related_packages.append(f"{name}-{dist.version}") | ||
|
||
info = { | ||
"python_version": python_version, | ||
"operating_system": operating_system, | ||
"os_version": os_version, | ||
"open-parse version": OPEN_PARSE_VERSION, | ||
"install path": Path(__file__).resolve().parent, | ||
"python version": sys.version, | ||
"platform": platform.platform(), | ||
"related packages": " ".join(related_packages), | ||
} | ||
return "\n".join( | ||
"{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items() | ||
) |