Add device parameter to convert functions and update usage model de… (#351)

* Add `device` parameter to convert functions and update usage model device in clip_similarity_scores.

* Add device parameter to clip_similarity_scores and handle device assignment.

* add `_parameters` to conftest Model
ayasyrev authored Aug 27, 2024
1 parent bb2b4b6 commit f99a63d
Showing 4 changed files with 32 additions and 6 deletions.
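
Taken together, the change lets callers pin CLIP scoring and the convert helpers to an explicit device instead of relying on wherever the model happens to live. A minimal usage sketch (not part of the commit): it assumes an open_clip checkpoint, a hypothetical local image file, and keyword names inferred from the signatures in the diffs below.

```python
import open_clip
import torch
from PIL import Image

from datachain.lib.clip import clip_similarity_scores

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")
device = "cuda" if torch.cuda.is_available() else "cpu"

scores = clip_similarity_scores(
    images=Image.open("cat.jpg"),  # hypothetical local file
    text=["a cat", "a dog"],
    model=model,
    preprocess=preprocess,
    tokenizer=tokenizer,
    prob=True,
    device=device,  # new in this commit; None keeps the model's own device
)
print(scores)  # e.g. [[0.9, 0.1]] with prob=True and image_to_text=True
```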
17 changes: 14 additions & 3 deletions src/datachain/lib/clip.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 import torch
 from transformers.modeling_utils import PreTrainedModel
@@ -39,6 +39,7 @@ def clip_similarity_scores(
     tokenizer: Callable,
     prob: bool = False,
     image_to_text: bool = True,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> list[list[float]]:
     """
     Calculate CLIP similarity scores between one or more images and/or text.
@@ -52,6 +53,7 @@
         prob : Compute softmax probabilities.
         image_to_text : Whether to compute for image-to-text or text-to-image. Ignored
             if only one of images or text provided.
+        device : Device to use. Default is None, which uses the model's device.
 
     Example:
@@ -130,17 +132,26 @@ def clip_similarity_scores(
         ```
     """
 
+    if device is None:
+        if hasattr(model, "device"):
+            device = model.device
+        else:
+            device = next(model.parameters()).device
+    else:
+        model = model.to(device)
     with torch.no_grad():
         if images is not None:
             encoder = _get_encoder(model, "image")
             image_features = convert_images(
-                images, transform=preprocess, encoder=encoder
+                images, transform=preprocess, encoder=encoder, device=device
             )
             image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
 
         if text is not None:
             encoder = _get_encoder(model, "text")
-            text_features = convert_text(text, tokenizer, encoder=encoder)
+            text_features = convert_text(
+                text, tokenizer, encoder=encoder, device=device
+            )
             text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
 
         if images is not None and text is not None:
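
The new block prefers an explicit `device`, then a `device` attribute on the model (Hugging Face `PreTrainedModel` exposes one), and finally falls back to the device of the first parameter. A small illustration of that fallback for a plain `torch.nn.Module` (a sketch, not code from the commit):

```python
import torch

linear = torch.nn.Linear(4, 4)
print(hasattr(linear, "device"))         # False: plain modules expose no .device attribute
print(next(linear.parameters()).device)  # device of the first parameter, e.g. cpu
```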
10 changes: 9 additions & 1 deletion src/datachain/lib/image.py
@@ -10,6 +10,7 @@ def convert_image(
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
@@ -20,6 +21,7 @@
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
         transform (Callable): Torchvision transform or huggingface processor to apply.
         encoder (Callable): Encode image using model.
+        device (str or torch.device): Device to use.
     """
     if mode:
         img = img.convert(mode)
@@ -35,6 +37,8 @@
                 img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
         except ImportError:
             pass
+        if device:
+            img = img.to(device)  # type: ignore[attr-defined]
         if encoder:
             img = img.unsqueeze(0)  # type: ignore[attr-defined]
     if encoder:
@@ -48,6 +52,7 @@ def convert_images(
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> Union[list[Image.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.
@@ -58,11 +63,14 @@
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
         transform (Callable): Torchvision transform or huggingface processor to apply.
         encoder (Callable): Encode image using model.
+        device (str or torch.device): Device to use.
     """
     if isinstance(images, Image.Image):
         images = [images]
 
-    converted = [convert_image(img, mode, size, transform) for img in images]
+    converted = [
+        convert_image(img, mode, size, transform, device=device) for img in images
+    ]
 
     if isinstance(converted[0], torch.Tensor):
         converted = torch.stack(converted)  # type: ignore[assignment,arg-type]
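
A hedged sketch of how the pass-through behaves end to end — the torchvision transform and placeholder images are illustrative, and the defaults for the parameters not shown are assumed from the signatures above:

```python
import torch
from PIL import Image
from torchvision import transforms

from datachain.lib.image import convert_images

to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
imgs = [Image.new("RGB", (64, 64)), Image.new("RGB", (32, 48))]  # placeholder images

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = convert_images(imgs, transform=to_tensor, device=device)
print(batch.shape, batch.device)  # torch.Size([2, 3, 224, 224]) on the chosen device
```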
4 changes: 4 additions & 0 deletions src/datachain/lib/text.py
@@ -9,6 +9,7 @@ def convert_text(
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
     encoder: Optional[Callable] = None,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> Union[str, list[str], torch.Tensor]:
     """
     Tokenize and otherwise transform text.
@@ -18,6 +19,7 @@
         tokenizer (Callable): Tokenizer to use to tokenize objects.
         tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
         encoder (Callable): Encode text using model.
+        device (str or torch.device): Device to use.
     """
     if not tokenizer:
         return text
@@ -32,6 +34,8 @@
 
     tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
     tokens = torch.tensor(tokens)
+    if device:
+        tokens = tokens.to(device)
 
     if not encoder:
         return tokens
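
The analogous sketch for text; the Hugging Face tokenizer here is only an example of a callable tokenizer, not something the commit requires:

```python
import torch
from transformers import AutoTokenizer

from datachain.lib.text import convert_text

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
device = "cuda" if torch.cuda.is_available() else "cpu"

tokens = convert_text("hello world", tokenizer=tokenizer, device=device)
print(tokens.device)  # tokens are moved to `device` before any encoder sees them
```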
7 changes: 5 additions & 2 deletions tests/unit/lib/conftest.py
@@ -15,7 +15,10 @@
 
 @pytest.fixture()
 def fake_clip_model():
-    class Model:
+    class Model(torch.nn.Module):
+        def __init__(self, *args, **kwargs):
+            self._parameters = {"p_1": torch.nn.Parameter(torch.tensor(1.0))}
+
         def encode_image(self, tensor):
             return torch.randn(len(tensor), 512)
 
@@ -34,7 +37,7 @@ def tokenizer(text, context_length=77):
 def fake_hf_model():
     class Model(PreTrainedModel):
         def __init__(self, *args, **kwargs):
-            pass
+            self._parameters = {"p_1": torch.nn.Parameter(torch.tensor(1.0))}
 
         def get_text_features(self, tensor):
             return torch.randn(len(tensor), 512)
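
Why a bare `_parameters` dict is enough here: `next(model.parameters())` only pulls one item from a lazy generator, so the fixtures can skip a full `super().__init__()` and still let the new device lookup find a parameter. A sketch of that mechanism with plain torch (it mirrors the fixtures above; not the actual test code):

```python
import torch


class Fake(torch.nn.Module):
    def __init__(self):
        # Deliberately no super().__init__(): a populated _parameters dict is
        # all that parameters() needs in order to yield a first item.
        self._parameters = {"p_1": torch.nn.Parameter(torch.tensor(1.0))}


fake = Fake()
print(next(fake.parameters()).device)  # cpu
```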
