From 1fe23d55100f2aee11a9ff64500b33803e10d92e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 22 Feb 2024 12:00:52 +0000 Subject: [PATCH] disable timm no weights --- optimum_benchmark/backends/pytorch/backend.py | 12 ++++++++---- tests/configs/_timm_.yaml | 5 ----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 8ebac51a..45a7d460 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -1,4 +1,5 @@ import gc +import json import os from collections import OrderedDict from logging import getLogger @@ -70,8 +71,8 @@ def __init__(self, config: PyTorchConfig): LOGGER.info("\t+ Creating backend temporary directory") self.tmpdir = TemporaryDirectory() - if self.config.no_weights and self.config.library == "diffusers": - raise ValueError("Diffusion pipelines are not supported with no_weights=True") + if self.config.no_weights and (self.config.library == "diffusers" or self.config.library == "timm"): + raise ValueError("Diffusion pipelines and Timm models don't support no weights") elif self.config.no_weights: LOGGER.info("\t+ Loading model with random weights") self.load_model_with_no_weights() @@ -179,7 +180,10 @@ def load_model_from_pretrained(self) -> None: ) def create_no_weights_model(self) -> None: - self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights") + if self.pretrained_config is None: + raise ValueError("Can't create no weights model without a pretrained config") + + self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model") LOGGER.info("\t+ Creating no weights model directory") os.makedirs(self.no_weights_model, exist_ok=True) LOGGER.info("\t+ Creating no weights model state dict") @@ -202,8 +206,8 @@ def create_no_weights_model(self) -> None: self.pretrained_config.quantization_config = self.quantization_config.to_dict() # tricking from_pretrained to load the model as if it was quantized + LOGGER.info("\t+ Saving no weights model pretrained config") if self.config.library == "transformers": - LOGGER.info("\t+ Saving no weights model pretrained config") self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) def load_model_with_no_weights(self) -> None: diff --git a/tests/configs/_timm_.yaml b/tests/configs/_timm_.yaml index c1087829..22d47cdd 100644 --- a/tests/configs/_timm_.yaml +++ b/tests/configs/_timm_.yaml @@ -1,7 +1,2 @@ backend: model: timm/tiny_vit_21m_224.in1k - -hydra: - sweeper: - params: - backend.no_weights: true,false