Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model): Add support for FalAI and Stability API #22

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

import fire
import wandb
import weave

from hemm.eval_pipelines import StabilityAPIModel, EvaluationPipeline
from hemm.metrics.vqa import MultiModalLLMEvaluationMetric
from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory


def main(
project="mllm-eval",
entity="hemm-eval",
dataset_ref: Optional[str] = "Dataset:v0",
dataset_limit: Optional[int] = None,
model_name: str = "sd3-large",
):
wandb.init(project=project, entity=entity, job_type="evaluation")
weave.init(project_name=f"{entity}/{project}")

dataset = weave.ref(dataset_ref).get()
dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset

stability_model = StabilityAPIModel(model_name=model_name)
evaluation_pipeline = EvaluationPipeline(model=stability_model)

judge = OpenAIJudge(prompt_property=PromptCategory.action)
metric = MultiModalLLMEvaluationMetric(judge=judge)
evaluation_pipeline.add_metric(metric)

evaluation_pipeline(dataset=dataset)


if __name__ == "__main__":
fire.Fire(main)
9 changes: 7 additions & 2 deletions hemm/eval_pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .eval_pipeline import EvaluationPipeline
from .model import BaseDiffusionModel
from .model import BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel

__all__ = ["BaseDiffusionModel", "EvaluationPipeline"]
__all__ = [
"BaseDiffusionModel",
"EvaluationPipeline",
"FalDiffusionModel",
"StabilityAPIModel",
]
67 changes: 45 additions & 22 deletions hemm/eval_pipelines/eval_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
from abc import ABC
from typing import Dict, List, Union

import wandb
import weave

import wandb

from ..metrics.base import BaseMetric
from .model import BaseDiffusionModel
from .model import BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel


class EvaluationPipeline(ABC):
"""Evaluation pipeline to evaluate the a multi-modal generative model.

Args:
model (BaseDiffusionModel): The model to evaluate.
model (Union[BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel]): The model to evaluate.
seed (int): Seed value for the random number generator.
"""

def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
def __init__(
self,
model: Union[BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel],
seed: int = 42,
) -> None:
super().__init__()
self.model = model

self.image_size = (self.model.image_height, self.model.image_width)
self.seed = seed

self.inference_counter = 1
Expand All @@ -30,17 +33,24 @@ def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
self.evaluation_table: wandb.Table = None
self.metric_functions: List[BaseMetric] = []

self.evaluation_configs = {
"pretrained_model_name_or_path": self.model.diffusion_model_name_or_path,
"torch_dtype": str(self.model._torch_dtype),
"enable_cpu_offfload": self.model.enable_cpu_offfload,
"image_size": {
"height": self.image_size[0],
"width": self.image_size[1],
},
"seed": seed,
"diffusion_pipeline": dict(self.model._pipeline.config),
}
if isinstance(self.model, BaseDiffusionModel):
self.image_size = (self.model.image_height, self.model.image_width)
self.evaluation_configs = {
"pretrained_model_name_or_path": self.model.diffusion_model_name_or_path,
"torch_dtype": str(self.model._torch_dtype),
"enable_cpu_offfload": self.model.enable_cpu_offfload,
"image_size": {
"height": self.image_size[0],
"width": self.image_size[1],
},
"seed": seed,
"diffusion_pipeline": dict(self.model._pipeline.config),
}
elif isinstance(self.model, StabilityAPIModel):
self.evaluation_configs = {
"model_name": self.model.model_name,
"aspect_ratio": self.model.aspect_ratio,
}

def add_metric(self, metric_fn: BaseMetric):
"""Add a metric function to the evaluation pipeline.
Expand All @@ -67,9 +77,16 @@ def infer(self, prompt: str) -> Dict[str, str]:
self.evaluation_table = wandb.Table(columns=self.table_columns)
self.inference_counter += 1
output = self.model.predict(prompt, seed=self.seed)
self.table_rows.append(
[self.model.diffusion_model_name_or_path, prompt, output["image"]]
)
inference_row = []
if isinstance(self.model, BaseDiffusionModel):
inference_row = [
self.model.diffusion_model_name_or_path,
prompt,
output["image"],
]
elif isinstance(self.model, StabilityAPIModel):
inference_row = [self.model.model_name, prompt, output["image"]]
self.table_rows.append(inference_row)
return output

@weave.op()
Expand Down Expand Up @@ -104,17 +121,23 @@ def log_summary(self, summary: Dict[str, float]) -> None:
}
)

def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
def __call__(
self, dataset: Union[List[Dict], str], evaluation_in_async: bool = True
) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.

Args:
dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
passed, it is assumed to be a Weave dataset reference.
evaluation_in_async (bool): Whether to evaluate the metrics in async mode.
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
evaluation = weave.Evaluation(
dataset=dataset,
scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
scorers=[
metric_fn.evaluate_async if evaluation_in_async else metric_fn.evaluate
for metric_fn in self.metric_functions
],
)
self.model.configs.update(self.evaluation_configs)
summary = asyncio.run(evaluation.evaluate(self.infer_async))
Expand Down
111 changes: 110 additions & 1 deletion hemm/eval_pipelines/model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
import io
import os
from typing import Any, Dict

import fal_client
import requests
import torch
import weave
from diffusers import DiffusionPipeline
from diffusers.utils.loading_utils import load_image
from PIL import Image

from ..utils import custom_weave_wrapper


STABILITY_MODEL_HOST = {
"sd3-large": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
"sd3-large-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
}


class BaseDiffusionModel(weave.Model):
"""Base `weave.Model` wrapping `diffusers.DiffusionPipeline`.
"""`weave.Model` wrapping `diffusers.DiffusionPipeline`.

Args:
diffusion_model_name_or_path (str): The name or path of the diffusion model.
enable_cpu_offfload (bool): Enable CPU offload for the diffusion model.
image_height (int): The height of the generated image.
image_width (int): The width of the generated image.
num_inference_steps (int): The number of inference steps.
disable_safety_checker (bool): Disable safety checker for the diffusion model.
configs (Dict[str, Any]): Additional configs.
pipeline_configs (Dict[str, Any]): Diffusion pipeline configs.
inference_kwargs (Dict[str, Any]): Inference kwargs.
"""

diffusion_model_name_or_path: str
Expand Down Expand Up @@ -78,3 +96,94 @@ def predict(self, prompt: str, seed: int) -> Dict[str, Any]:
**self.inference_kwargs,
)
return {"image": pipeline_output.images[0]}


class FalDiffusionModel(weave.Model):
"""`weave.Model` wrapping [FalAI](https://fal.ai/) calls.

Args:
model_name (str): FalAI model name.
inference_kwargs (Dict[str, Any]): Inference kwargs.
"""

model_name: str
inference_kwargs: Dict[str, Any] = {}

@weave.op()
def generate_image(self, prompt: str, seed: int) -> Image.Image:
result = custom_weave_wrapper(name="fal_client.submit.get")(
fal_client.submit(
self.model_name,
arguments={"prompt": prompt, "seed": seed, **self.inference_kwargs},
).get
)()
return load_image(result["images"][0]["url"])

@weave.op()
def predict(self, prompt: str, seed: int) -> Image.Image:
return {"image": self.generate_image(prompt=prompt, seed=seed)}


class StabilityAPIModel(weave.Model):
"""`weave.Model` wrapping Stability API calls.

Args:
model_name (str): Stability model name.
aspect_ratio (str): Aspect ratio of the generated image.
creativity (float): Creativity of the generated image.
"""

model_name: str
aspect_ratio: str = "1:1"
creativity: float = 0.35
configs: Dict[str, Any] = {}

def __init__(
self,
model_name: str,
aspect_ratio: str = "1:1",
creativity: float = 0.35,
) -> None:
assert aspect_ratio in [
"1:1",
"16:9",
"21:9",
"2:3",
"3:2",
"4:5",
"5:4",
"9:16",
"9:21",
], "Invalid aspect ratio"
super().__init__(
model_name=model_name, aspect_ratio=aspect_ratio, creativity=creativity
)

@weave.op()
def send_generation_request(self, prompt: str, seed: int):
api_key = os.environ["STABILITY_KEY"]
headers = {"Accept": "image/*", "Authorization": f"Bearer {api_key}"}
response = requests.post(
STABILITY_MODEL_HOST[self.model_name],
headers=headers,
files={"none": ""},
data={
"prompt": prompt,
"negative_prompt": "",
"aspect_ratio": self.aspect_ratio,
"seed": seed,
"output_format": "png",
"model": self.model_name,
"mode": "text-to-image",
"creativity": self.creativity,
},
)
if not response.ok:
raise Exception(f"HTTP {response.status_code}: {response.text}")
return response

@weave.op()
def predict(self, prompt: str, seed: int) -> Image.Image:
response = self.send_generation_request(prompt=prompt, seed=seed)
image = Image.open(io.BytesIO(response.content))
return {"image": image}
9 changes: 9 additions & 0 deletions hemm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,12 @@ def autogenerate_seed(set_to_max: bool = False) -> int:
seed = -seed if seed < 0 else seed
seed = seed % max_seed
return seed


def custom_weave_wrapper(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
return op

return wrapper
48 changes: 46 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading