Skip to content

Commit

Permalink
warn
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 31, 2024
1 parent 5e9bbf7 commit 393947b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
5 changes: 4 additions & 1 deletion optimum_benchmark/backends/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
shapes["width"] = vae_config["sample_size"]

else:
warnings.warn("Could not extract shapes from the model.")
warnings.warn("Could not extract shapes [num_channels, height, width] from diffusion pipeline.")
shapes["num_channels"] = -1
shapes["height"] = -1
shapes["width"] = -1

return shapes

Expand Down
14 changes: 5 additions & 9 deletions optimum_benchmark/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,15 @@ def infer_model_type_from_model_name_or_path(

inferred_model_type = None

if library_name == "timm":
if library_name == "llama_cpp":
inferred_model_type = "llama_cpp"

elif library_name == "timm":
timm_config = get_timm_pretrained_config(model_name_or_path)
inferred_model_type = timm_config.architecture

elif library_name == "diffusers":
from diffusers import DiffusionPipeline

get_diffusers_pretrained_config
config = DiffusionPipeline.load_config(model_name_or_path)
config, _ = config if isinstance(config, tuple) else (config, None)
config = get_diffusers_pretrained_config(model_name_or_path, revision=revision, token=token)
class_name = config["_class_name"]

for task_name, model_mapping in DIFFUSERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
Expand All @@ -216,9 +215,6 @@ def infer_model_type_from_model_name_or_path(
if inferred_model_type is not None:
break

elif library_name == "llama_cpp":
inferred_model_type = "llama_cpp"

else:
transformers_config = get_transformers_pretrained_config(model_name_or_path, revision=revision, token=token)
inferred_model_type = transformers_config.model_type
Expand Down

0 comments on commit 393947b

Please sign in to comment.