diff --git a/examples/multimodal_llm_eval/evaluate_mllm_metric_complex_stability.py b/examples/multimodal_llm_eval/evaluate_mllm_metric_complex_stability.py
new file mode 100644
index 0000000..3d597be
--- /dev/null
+++ b/examples/multimodal_llm_eval/evaluate_mllm_metric_complex_stability.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import fire
+import wandb
+import weave
+
+from hemm.eval_pipelines import StabilityAPIModel, EvaluationPipeline
+from hemm.metrics.vqa import MultiModalLLMEvaluationMetric
+from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory
+
+
+def main(
+    project="mllm-eval",
+    entity="hemm-eval",
+    dataset_ref: Optional[str] = "Dataset:v0",
+    dataset_limit: Optional[int] = None,
+    model_name: str = "sd3-large",
+):
+    wandb.init(project=project, entity=entity, job_type="evaluation")
+    weave.init(project_name=f"{entity}/{project}")
+
+    dataset = weave.ref(dataset_ref).get()
+    dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset
+
+    stability_model = StabilityAPIModel(model_name=model_name)
+    evaluation_pipeline = EvaluationPipeline(model=stability_model)
+
+    judge = OpenAIJudge(prompt_property=PromptCategory.action)
+    metric = MultiModalLLMEvaluationMetric(judge=judge)
+    evaluation_pipeline.add_metric(metric)
+
+    evaluation_pipeline(dataset=dataset)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/hemm/eval_pipelines/__init__.py b/hemm/eval_pipelines/__init__.py
index e95cb96..f33f6ea 100644
--- a/hemm/eval_pipelines/__init__.py
+++ b/hemm/eval_pipelines/__init__.py
@@ -1,4 +1,9 @@
 from .eval_pipeline import EvaluationPipeline
-from .model import BaseDiffusionModel
+from .model import BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel
 
-__all__ = ["BaseDiffusionModel", "EvaluationPipeline"]
+__all__ = [
+    "BaseDiffusionModel",
+    "EvaluationPipeline",
+    "FalDiffusionModel",
+    "StabilityAPIModel",
+]
""" - def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None: + def __init__( + self, + model: Union[BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel], + seed: int = 42, + ) -> None: super().__init__() self.model = model - - self.image_size = (self.model.image_height, self.model.image_width) self.seed = seed self.inference_counter = 1 @@ -30,17 +33,24 @@ def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None: self.evaluation_table: wandb.Table = None self.metric_functions: List[BaseMetric] = [] - self.evaluation_configs = { - "pretrained_model_name_or_path": self.model.diffusion_model_name_or_path, - "torch_dtype": str(self.model._torch_dtype), - "enable_cpu_offfload": self.model.enable_cpu_offfload, - "image_size": { - "height": self.image_size[0], - "width": self.image_size[1], - }, - "seed": seed, - "diffusion_pipeline": dict(self.model._pipeline.config), - } + if isinstance(self.model, BaseDiffusionModel): + self.image_size = (self.model.image_height, self.model.image_width) + self.evaluation_configs = { + "pretrained_model_name_or_path": self.model.diffusion_model_name_or_path, + "torch_dtype": str(self.model._torch_dtype), + "enable_cpu_offfload": self.model.enable_cpu_offfload, + "image_size": { + "height": self.image_size[0], + "width": self.image_size[1], + }, + "seed": seed, + "diffusion_pipeline": dict(self.model._pipeline.config), + } + elif isinstance(self.model, StabilityAPIModel): + self.evaluation_configs = { + "model_name": self.model.model_name, + "aspect_ratio": self.model.aspect_ratio, + } def add_metric(self, metric_fn: BaseMetric): """Add a metric function to the evaluation pipeline. @@ -67,9 +77,16 @@ def infer(self, prompt: str) -> Dict[str, str]: self.evaluation_table = wandb.Table(columns=self.table_columns) self.inference_counter += 1 output = self.model.predict(prompt, seed=self.seed) - self.table_rows.append( - [self.model.diffusion_model_name_or_path, prompt, output["image"]] - ) + inference_row = [] + if isinstance(self.model, BaseDiffusionModel): + inference_row = [ + self.model.diffusion_model_name_or_path, + prompt, + output["image"], + ] + elif isinstance(self.model, StabilityAPIModel): + inference_row = [self.model.model_name, prompt, output["image"]] + self.table_rows.append(inference_row) return output @weave.op() @@ -104,17 +121,23 @@ def log_summary(self, summary: Dict[str, float]) -> None: } ) - def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]: + def __call__( + self, dataset: Union[List[Dict], str], evaluation_in_async: bool = True + ) -> Dict[str, float]: """Evaluate the Stable Diffusion model on the given dataset. Args: dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is passed, it is assumed to be a Weave dataset reference. + evaluation_in_async (bool): Whether to evaluate the metrics in async mode. 
""" dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset evaluation = weave.Evaluation( dataset=dataset, - scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions], + scorers=[ + metric_fn.evaluate_async if evaluation_in_async else metric_fn.evaluate + for metric_fn in self.metric_functions + ], ) self.model.configs.update(self.evaluation_configs) summary = asyncio.run(evaluation.evaluate(self.infer_async)) diff --git a/hemm/eval_pipelines/model.py b/hemm/eval_pipelines/model.py index c0d3a3a..e0d77cf 100644 --- a/hemm/eval_pipelines/model.py +++ b/hemm/eval_pipelines/model.py @@ -1,19 +1,37 @@ +import io +import os from typing import Any, Dict +import fal_client +import requests import torch import weave from diffusers import DiffusionPipeline +from diffusers.utils.loading_utils import load_image +from PIL import Image + +from ..utils import custom_weave_wrapper + + +STABILITY_MODEL_HOST = { + "sd3-large": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-large-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", +} class BaseDiffusionModel(weave.Model): - """Base `weave.Model` wrapping `diffusers.DiffusionPipeline`. + """`weave.Model` wrapping `diffusers.DiffusionPipeline`. Args: diffusion_model_name_or_path (str): The name or path of the diffusion model. enable_cpu_offfload (bool): Enable CPU offload for the diffusion model. image_height (int): The height of the generated image. image_width (int): The width of the generated image. + num_inference_steps (int): The number of inference steps. disable_safety_checker (bool): Disable safety checker for the diffusion model. + configs (Dict[str, Any]): Additional configs. + pipeline_configs (Dict[str, Any]): Diffusion pipeline configs. + inference_kwargs (Dict[str, Any]): Inference kwargs. """ diffusion_model_name_or_path: str @@ -78,3 +96,94 @@ def predict(self, prompt: str, seed: int) -> Dict[str, Any]: **self.inference_kwargs, ) return {"image": pipeline_output.images[0]} + + +class FalDiffusionModel(weave.Model): + """`weave.Model` wrapping [FalAI](https://fal.ai/) calls. + + Args: + model_name (str): FalAI model name. + inference_kwargs (Dict[str, Any]): Inference kwargs. + """ + + model_name: str + inference_kwargs: Dict[str, Any] = {} + + @weave.op() + def generate_image(self, prompt: str, seed: int) -> Image.Image: + result = custom_weave_wrapper(name="fal_client.submit.get")( + fal_client.submit( + self.model_name, + arguments={"prompt": prompt, "seed": seed, **self.inference_kwargs}, + ).get + )() + return load_image(result["images"][0]["url"]) + + @weave.op() + def predict(self, prompt: str, seed: int) -> Image.Image: + return {"image": self.generate_image(prompt=prompt, seed=seed)} + + +class StabilityAPIModel(weave.Model): + """`weave.Model` wrapping Stability API calls. + + Args: + model_name (str): Stability model name. + aspect_ratio (str): Aspect ratio of the generated image. + creativity (float): Creativity of the generated image. 
+ """ + + model_name: str + aspect_ratio: str = "1:1" + creativity: float = 0.35 + configs: Dict[str, Any] = {} + + def __init__( + self, + model_name: str, + aspect_ratio: str = "1:1", + creativity: float = 0.35, + ) -> None: + assert aspect_ratio in [ + "1:1", + "16:9", + "21:9", + "2:3", + "3:2", + "4:5", + "5:4", + "9:16", + "9:21", + ], "Invalid aspect ratio" + super().__init__( + model_name=model_name, aspect_ratio=aspect_ratio, creativity=creativity + ) + + @weave.op() + def send_generation_request(self, prompt: str, seed: int): + api_key = os.environ["STABILITY_KEY"] + headers = {"Accept": "image/*", "Authorization": f"Bearer {api_key}"} + response = requests.post( + STABILITY_MODEL_HOST[self.model_name], + headers=headers, + files={"none": ""}, + data={ + "prompt": prompt, + "negative_prompt": "", + "aspect_ratio": self.aspect_ratio, + "seed": seed, + "output_format": "png", + "model": self.model_name, + "mode": "text-to-image", + "creativity": self.creativity, + }, + ) + if not response.ok: + raise Exception(f"HTTP {response.status_code}: {response.text}") + return response + + @weave.op() + def predict(self, prompt: str, seed: int) -> Image.Image: + response = self.send_generation_request(prompt=prompt, seed=seed) + image = Image.open(io.BytesIO(response.content)) + return {"image": image} diff --git a/hemm/utils.py b/hemm/utils.py index 447150c..e9dc091 100644 --- a/hemm/utils.py +++ b/hemm/utils.py @@ -155,3 +155,12 @@ def autogenerate_seed(set_to_max: bool = False) -> int: seed = -seed if seed < 0 else seed seed = seed % max_seed return seed + + +def custom_weave_wrapper(name: str) -> Callable[[Callable], Callable]: + def wrapper(fn: Callable) -> Callable: + op = weave.op()(fn) + op.name = name # type: ignore + return op + + return wrapper diff --git a/poetry.lock b/poetry.lock index e33bbe5..239504f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1281,6 +1281,25 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "fal-client" +version = "0.4.1" +description = "Python client for fal.ai" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fal_client-0.4.1-py3-none-any.whl", hash = "sha256:3fe13ac5108a02c1c27e146e52dcbb0b10d33694d870c8a1769966af18ff4a3f"}, + {file = "fal_client-0.4.1.tar.gz", hash = "sha256:3121cdbf4be8a47226e6df8e782340c1a603b17ec04942a131c2929a32aedff3"}, +] + +[package.dependencies] +httpx = ">=0.21.0,<1" +httpx-sse = ">=0.4.0,<0.5" + +[package.extras] +dev = ["fal-client[test]"] +test = ["pillow", "pytest", "pytest-asyncio"] + [[package]] name = "fastjsonschema" version = "2.19.1" @@ -1630,6 +1649,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." 
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
+    {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.23.5"
@@ -4314,6 +4344,20 @@ files = [
 [package.dependencies]
 six = ">=1.5"
 
+[[package]]
+name = "python-dotenv"
+version = "1.0.1"
+description = "Read key-value pairs from a .env file and set them as environment variables"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
+    {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
+]
+
+[package.extras]
+cli = ["click (>=5.0)"]
+
 [[package]]
 name = "python-json-logger"
 version = "2.0.7"
@@ -7035,10 +7079,10 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
 
 [extras]
-core = ["accelerate", "datasets", "diffusers", "fire", "huggingface-hub", "instructor", "jsonlines", "poetry", "spacy", "torchmetrics", "transformers", "wandb", "weave"]
+core = ["accelerate", "datasets", "diffusers", "fal-client", "fire", "huggingface-hub", "instructor", "jsonlines", "poetry", "sentencepiece", "spacy", "torchmetrics", "transformers", "wandb", "weave"]
 docs = ["jupyter", "mkdocs", "mkdocs-glightbox", "mkdocs-jupyter", "mkdocs-material", "mkdocs-minify-plugin", "mkdocstrings"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "ea40234059583011ea809df1b374a1e4fe7d7b37ee0210b70eaa34ac02f21001"
+content-hash = "7ec1630f54026b46531236ebdbb4df5477984927166728e5314b5671ab62a8e2"
diff --git a/pyproject.toml b/pyproject.toml
index 1f7f906..30bb231 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,8 @@ instructor = "^1.3.4"
 torchmetrics = { extras = ["multimodal"], version = "^1.4.1" }
 mkdocstrings = {version = "^0.25.2", extras = ["python"]}
 sentencepiece = "^0.2.0"
+fal-client = "^0.4.1"
+python-dotenv = "^1.0.1"
 
 [tool.poetry.extras]
 core = [
@@ -45,7 +47,9 @@ core = [
     "huggingface-hub",
     "datasets",
     "fire",
+    "fal-client",
     "jsonlines",
+    "python-dotenv",
     "spacy",
     "instructor",
     "torchmetrics",
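Note: `custom_weave_wrapper` (added to hemm/utils.py above) wraps an arbitrary callable as a `weave.op` and overrides its logged op name; model.py uses it so the fal_client result fetch is traced as "fal_client.submit.get" in Weave. A minimal sketch with an illustrative placeholder function:

    from hemm.utils import custom_weave_wrapper

    def fetch_caption(image_url: str) -> str:
        # Placeholder standing in for any third-party call worth tracing.
        return f"caption for {image_url}"

    # The returned callable is a weave.op whose display name is overridden,
    # so traces show "external.fetch_caption" instead of the raw function name.
    traced_fetch = custom_weave_wrapper(name="external.fetch_caption")(fetch_caption)
    traced_fetch("https://example.com/cat.png")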