Skip to content

Commit

Permalink
multimodality bug fix -- add method for pd concat (#49)
Browse files Browse the repository at this point in the history
Add support for pd concatenation for Image Array datatypes
  • Loading branch information
liana313 authored Dec 13, 2024
1 parent 09d7af6 commit 745fa8e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
17 changes: 17 additions & 0 deletions .github/tests/multimodality_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ def test_topk_operation(setup_models, model):

top_2_actual = set(sorted_df["image"].values)
assert top_2_expected == top_2_actual

@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_topk_with_groupby_operation(setup_models, model):
    """Run a grouped ``sem_topk`` over a frame that has an ImageArray column.

    Every image is cross-joined with every candidate element, then the top
    row (``K=1``) is requested per ``element`` group. The check is made on the
    frame returned by ``sem_topk``: both groups must be present in the output.
    """
    # NOTE(review): ``model`` is not used to configure anything in this test;
    # presumably the ``setup_models`` fixture handles that -- confirm.
    image_url = [
        "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0",
        "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1",
        "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg",
        "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg",
    ]
    elements = ["doll", "bird"]
    image_df = pd.DataFrame({"image": ImageArray(image_url)})
    element_df = pd.DataFrame({"element": elements})

    # Pair every image with every element so each group sees all images.
    df = image_df.join(element_df, how="cross")

    # Keep the result: asserting on the input ``df`` would pass trivially,
    # because the cross join always contains both elements.
    sorted_df = df.sem_topk("the {image} is most likely an {element}", K=1, group_by=["element"])
    assert len(set(sorted_df["element"])) == 2



@pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32"))
Expand Down
17 changes: 17 additions & 0 deletions lotus/dtype_extensions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def copy(self) -> "ImageArray":
new_array._cached_images = self._cached_images.copy()
return new_array

@classmethod
def _concat_same_type(cls, to_concat: Sequence["ImageArray"]) -> "ImageArray":
    """
    Concatenate multiple ImageArray instances into a single one.

    This must be a classmethod: it is invoked on the array class as
    ``cls._concat_same_type(to_concat)``, so without the decorator
    ``to_concat`` would be bound to ``cls`` and the call would fail.

    Args:
        to_concat (Sequence[ImageArray]): A sequence of ImageArray instances to concatenate.

    Returns:
        ImageArray: A new ImageArray containing all elements from the input arrays.
    """
    # Join the underlying ``_data`` arrays in order, then rebuild through
    # ``_from_sequence`` so the result is a fresh instance of ``cls``.
    combined_data = np.concatenate([arr._data for arr in to_concat])
    return cls._from_sequence(combined_data)




@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
if copy:
Expand Down

0 comments on commit 745fa8e

Please sign in to comment.