
Commit

remove: to(cpu) code
michaelfeil committed Nov 4, 2023
1 parent a635bdf commit 49b6108
Showing 2 changed files with 3 additions and 57 deletions.
54 changes: 2 additions & 52 deletions libs/infinity_emb/infinity_emb/transformer/sentence_transformer.py
@@ -11,12 +11,12 @@
 try:
     import torch
     from sentence_transformers import SentenceTransformer, util  # type: ignore
-    from torch import Tensor, device, dtype
+    from torch import Tensor
     from torch.nn import Module
 
     TORCH_AVAILABLE = True
 except ImportError:
-    torch, Tensor, device, dtype = None, None, None, None
+    torch, Tensor = None, None
 
     class SentenceTransformer:
         pass
@@ -233,56 +233,6 @@ def children(self):
         # child module so that it will stay on the CPU.
         return []
 
-    def half(self):
-        self.to(dtype="float16")
-        return self
-
-    def to(
-        self,
-        device: int | device | None = None,
-        dtype: dtype | str | None = None,
-        non_blocking: bool = False,
-    ) -> "CT2Transformer":
-        if not isinstance(device, int):
-            raise ValueError("param `dtype` needs to be of type int")
-        if not isinstance(dtype, str) or dtype is not None:
-            raise ValueError("param `dtype` needs to be of type str")
-
-        if dtype and not ("float" in dtype or "int" in dtype):
-            raise ValueError(
-                "dtype should be one of `int8`, `float16`, `int8_float16`, `float32`"
-            )
-        elif dtype:
-            new_dtype = True
-            self.compute_type = new_dtype
-        else:
-            new_dtype = False
-
-        if device and (device.startswith("cuda") or device.startswith("cpu")):
-            raise ValueError(
-                "for param `device`, f'cuda:{index}' or f'cpu:{index}' are supported"
-            )
-        elif device:
-            if ":" in device:
-                new_device = device.split(":")[0]
-                new_index = device.split(":")[1]
-            else:
-                new_device = device
-                new_index = "0"
-        else:
-            new_device = ""
-            new_index = ""
-
-        if new_device or new_dtype or new_index:
-            self.encoder = self._ctranslate2_encoder_cls(
-                self.ct2_model_dir,
-                device=new_device,
-                device_index=new_index,
-                intra_threads=torch.get_num_threads(),
-                compute_type=self.compute_type,
-            )
-        return self
-
     def forward(self, features):
         """overwrites torch forward method with CTranslate model"""
         device = features["input_ids"].device
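
Note that the removed to() override could never complete successfully: `if not isinstance(device, int)` raised an error about `dtype`, `if not isinstance(dtype, str) or dtype is not None` rejected every value, and the device check raised for exactly the "cuda"/"cpu" prefixes its own message says are supported, so dropping the method is the simplest fix. Because a CTranslate2 model cannot be moved in place, the override re-created the encoder from ct2_model_dir instead of delegating to torch's Module.to. For reference only, a corrected sketch of such an override follows; the attribute names (_ctranslate2_encoder_cls, ct2_model_dir, compute_type) are taken from the removed code, while the defaults and the device_index cast are assumptions, not the project's API.

    # Hypothetical corrected sketch of the removed CT2Transformer.to() override.
    # Attribute names come from the removed code; the "cpu" fallback and the
    # int cast for device_index are assumptions.
    def to(
        self,
        device: str | None = None,
        dtype: str | None = None,
        non_blocking: bool = False,  # accepted for torch API compatibility, ignored
    ) -> "CT2Transformer":
        if device is not None and not isinstance(device, str):
            raise ValueError("param `device` needs to be of type str")
        if dtype is not None and not isinstance(dtype, str):
            raise ValueError("param `dtype` needs to be of type str")

        if dtype:
            if not ("float" in dtype or "int" in dtype):
                raise ValueError(
                    "dtype should be one of `int8`, `float16`, `int8_float16`, `float32`"
                )
            # store the requested compute type itself, not a boolean flag
            self.compute_type = dtype

        new_device, new_index = "", "0"
        if device:
            if not (device.startswith("cuda") or device.startswith("cpu")):
                raise ValueError(
                    "for param `device`, f'cuda:{index}' or f'cpu:{index}' are supported"
                )
            new_device, _, index = device.partition(":")
            new_index = index or "0"

        if new_device or dtype:
            # a CTranslate2 model cannot be moved in place, so rebuild the encoder
            self.encoder = self._ctranslate2_encoder_cls(
                self.ct2_model_dir,
                device=new_device or "cpu",
                device_index=int(new_index),
                intra_threads=torch.get_num_threads(),
                compute_type=self.compute_type,
            )
        return self
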
6 changes: 1 addition & 5 deletions libs/infinity_emb/tests/unit_test/inference/test_models.py
@@ -22,16 +22,12 @@ def _pretrained_model_score(
     model_name,
     expected_score,
     ct2_compute_type: str = "",
-    device: str = "cuda",
 ):
     test_samples = dataset[::3]
 
     if ct2_compute_type:
         model = CT2SentenceTransformer(model_name, compute_type=ct2_compute_type)
-        if not torch.cuda.is_available() or device == "cpu":
-            model.to("cpu")
-        else:
-            model.to("cuda")
+
 
     else:
         model = SentenceTransformerPatched(model_name)
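
With the explicit model.to("cpu") / model.to("cuda") calls gone, device placement is left to CT2SentenceTransformer itself. If a CPU-only machine still needs to avoid GPU-oriented compute types, a skip marker at the call site is one option. The sketch below is hypothetical: the test names, the model name, the expected score, and the shape of the helper call are illustrative assumptions, not code from the repository.

import pytest
import torch

# Hypothetical call sites for _pretrained_model_score; the model name and the
# score are placeholders, and the helper's full signature is assumed from the hunk.
def test_ct2_minilm_int8():
    _pretrained_model_score(
        "sentence-transformers/all-MiniLM-L6-v2",
        expected_score=0.75,
        ct2_compute_type="int8",
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="float16 is GPU-oriented")
def test_ct2_minilm_float16():
    _pretrained_model_score(
        "sentence-transformers/all-MiniLM-L6-v2",
        expected_score=0.75,
        ct2_compute_type="float16",
    )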
