Skip to content

Commit

Permalink
sem map and sem join done
Browse files Browse the repository at this point in the history
  • Loading branch information
harshitgupta412 committed Nov 17, 2024
1 parent 466c3f9 commit ac8fd82
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 127 deletions.
10 changes: 4 additions & 6 deletions examples/multimodal_data/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

mnist_data = datasets.MNIST(root='mnist_data', train=True, download=True, transform=None)
mnist_data = datasets.MNIST(root="mnist_data", train=True, download=True, transform=None)

images = [image for image, _ in mnist_data]
labels = [label for _, label in mnist_data]

df = pd.DataFrame({
"image": ImageArray(images),
"label": labels
})
df = pd.DataFrame({"image": ImageArray(images), "label": labels})

df.sem_filter("{image} represents number 1")
df = df.sem_filter("{image} represents number 1")
print(df)
22 changes: 22 additions & 0 deletions examples/multimodal_data/join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pandas as pd
from torchvision import datasets

import lotus
from lotus.dtype_extensions import ImageArray
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

mnist_data = datasets.MNIST(root="mnist_data", train=True, download=True, transform=None)

images = [image for image, _ in mnist_data]
labels = [label for _, label in mnist_data]

df = pd.DataFrame({"image": ImageArray(images[:5]), "label": labels[:5]})

df2 = pd.DataFrame({"image": ImageArray(images[5:10]), "label": labels[5:10]})

df = df.sem_join(df2, "{image:left} represents the same number as {image:right}", strategy="zs-cot")

print(df)
19 changes: 19 additions & 0 deletions examples/multimodal_data/map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pandas as pd
from torchvision import datasets

import lotus
from lotus.dtype_extensions import ImageArray
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

mnist_data = datasets.MNIST(root="mnist_data", train=True, download=True, transform=None)

images = [image for image, _ in mnist_data]
labels = [label for _, label in mnist_data]

df = pd.DataFrame({"image": ImageArray(images[:5]), "label": labels[:5]})

df = df.sem_map("convert {image} to the number it represents")
print(df)
35 changes: 27 additions & 8 deletions lotus/dtype_extensions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def construct_array_type(cls):


class ImageArray(ExtensionArray):
def __init__(self, values: np.ndarray):
def __init__(self, values):
self._data = np.asarray(values, dtype=object)
self._dtype = ImageDtype()
self._cached_images: dict[int, str | Image.Image | None] = {} # Cache for loaded images
self.allowed_image_types = ["Image", "base64"]
self._cached_images: dict[tuple[int, str], str | Image.Image | None] = {} # Cache for loaded images

def __getitem__(self, item: int | slice | Sequence[int]) -> np.ndarray:
result = self._data[item]
Expand Down Expand Up @@ -55,22 +56,23 @@ def __setitem__(self, key, value) -> None:

def _invalidate_cache(self, idx: int) -> None:
"""Remove an item from the cache."""
if idx in self._cached_images:
del self._cached_images[idx]
for image_type in self.allowed_image_types:
if (idx, image_type) in self._cached_images:
del self._cached_images[(idx, image_type)]

def get_image(self, idx: int, image_type: str = "Image") -> Union[Image.Image, str, None]:
"""Explicit method to fetch and return the actual image"""
if idx not in self._cached_images:
if (idx, image_type) not in self._cached_images:
image_result = fetch_image(self._data[idx], image_type)
assert image_result is None or isinstance(image_result, (Image.Image, str))
self._cached_images[idx] = image_result
return self._cached_images[idx]
self._cached_images[(idx, image_type)] = image_result
return self._cached_images[(idx, image_type)]

def isna(self) -> np.ndarray:
return pd.isna(self._data)

def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value=None) -> "ImageArray":
result = self._data.take(indices)
result = self._data.take(indices, axis=0)
if allow_fill and fill_value is not None:
result[indices == -1] = fill_value
return ImageArray(result)
Expand Down Expand Up @@ -113,6 +115,23 @@ def __repr__(self) -> str:
def _formatter(self, boxed: bool = False):
return lambda x: f"<Image: {type(x)}>" if x is not None else "None"

def to_numpy(self, dtype=None, copy=False, na_value=None) -> np.ndarray:
"""Convert the ImageArray to a numpy array."""
pil_images = []
for i, img_data in enumerate(self._data):
if isinstance(img_data, np.ndarray):
image = self.get_image(i)
pil_images.append(image)
else:
pil_images.append(img_data)
result = np.empty(len(self), dtype=object)
result[:] = pil_images
return result

def __array__(self, dtype=None) -> np.ndarray:
"""Numpy array interface."""
return self.to_numpy(dtype=dtype)


def _compare_images(img1, img2) -> bool:
if img1 is None or img2 is None:
Expand Down
1 change: 0 additions & 1 deletion lotus/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
"LiteLLMRM",
"SentenceTransformersRM",
"ColBERTv2RM",
"Qwen2Model",
]
12 changes: 6 additions & 6 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@


def sem_filter(
docs: list[dict[str, Any]] | list[str],
docs: list[dict[str, Any]],
model: lotus.models.LM,
user_instruction: str,
default: bool = True,
examples_multimodal_data: list[dict[str, Any]] | list[str] | None = None,
examples_multimodal_data: list[dict[str, Any]] | None = None,
examples_answers: list[bool] | None = None,
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
Expand All @@ -27,11 +27,11 @@ def sem_filter(
Filters a list of documents based on a given user instruction using a language model.
Args:
docs (list[dict[str, Any]] | list[str]): The list of documents to filter. Each document is a tuple of text and images.
docs (list[dict[str, Any]]): The list of documents to filter. Each document is a tuple of text and images.
model (lotus.models.LM): The language model used for filtering.
user_instruction (str): The user instruction for filtering.
default (bool): The default value for filtering in case of parsing errors. Defaults to True.
examples_multimodal_data (list[dict[str, Any]] | list[str] | None): The text for examples. Defaults to None.
examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None.
examples_answers (list[bool] | None): The answers for examples. Defaults to None.
cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None.
logprobs (bool): Whether to return log probabilities. Defaults to False.
Expand Down Expand Up @@ -60,7 +60,7 @@ def sem_filter(


def learn_filter_cascade_thresholds(
sample_multimodal_data: list[dict[str, Any]] | list[str],
sample_multimodal_data: list[dict[str, Any]],
lm: lotus.models.LM,
formatted_usr_instr: str,
default: bool,
Expand All @@ -69,7 +69,7 @@ def learn_filter_cascade_thresholds(
delta: float,
helper_true_probs: list[float],
sample_correction_factors: NDArray[np.float64],
examples_multimodal_data: list[dict[str, Any]] | list[str] | None = None,
examples_multimodal_data: list[dict[str, Any]] | None = None,
examples_answers: list[bool] | None = None,
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
Expand Down
Loading

0 comments on commit ac8fd82

Please sign in to comment.