disable timm no weights
IlyasMoutawwakil committed Feb 22, 2024
1 parent 984a689 commit 1fe23d5
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 8 additions & 4 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -1,4 +1,5 @@
 import gc
+import json
 import os
 from collections import OrderedDict
 from logging import getLogger
@@ -70,8 +71,8 @@ def __init__(self, config: PyTorchConfig):
         LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
 
-        if self.config.no_weights and self.config.library == "diffusers":
-            raise ValueError("Diffusion pipelines are not supported with no_weights=True")
+        if self.config.no_weights and (self.config.library == "diffusers" or self.config.library == "timm"):
+            raise ValueError("Diffusion pipelines and Timm models don't support no weights")
         elif self.config.no_weights:
             LOGGER.info("\t+ Loading model with random weights")
             self.load_model_with_no_weights()
@@ -179,7 +180,10 @@ def load_model_from_pretrained(self) -> None:
         )
 
     def create_no_weights_model(self) -> None:
-        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+        if self.pretrained_config is None:
+            raise ValueError("Can't create no weights model without a pretrained config")
+
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
         LOGGER.info("\t+ Creating no weights model directory")
         os.makedirs(self.no_weights_model, exist_ok=True)
         LOGGER.info("\t+ Creating no weights model state dict")
@@ -202,8 +206,8 @@ def create_no_weights_model(self) -> None:
             self.pretrained_config.quantization_config = self.quantization_config.to_dict()
             # tricking from_pretrained to load the model as if it was quantized
 
-        LOGGER.info("\t+ Saving no weights model pretrained config")
         if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_model_with_no_weights(self) -> None:
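For context, the no-weights path works by writing a placeholder checkpoint plus the model's config to a temporary directory and then letting from_pretrained materialize the architecture with randomly initialized weights. A minimal sketch of that idea, assuming a transformers model; the model id and tensor name below are illustrative, not the repository's exact code:

import os
from tempfile import TemporaryDirectory

import torch
from safetensors.torch import save_file
from transformers import AutoConfig, AutoModel

model_id = "bert-base-uncased"  # illustrative model id

with TemporaryDirectory() as tmpdir:
    no_weights_dir = os.path.join(tmpdir, "no_weights_model")
    os.makedirs(no_weights_dir, exist_ok=True)

    # Save a placeholder state dict; it only needs to parse as a checkpoint.
    save_file(
        tensors={"placeholder": torch.empty(0)},
        filename=os.path.join(no_weights_dir, "model.safetensors"),
        metadata={"format": "pt"},
    )

    # Save the architecture's config next to it (the step guarded by the
    # library == "transformers" check in the diff above).
    config = AutoConfig.from_pretrained(model_id)
    config.save_pretrained(save_directory=no_weights_dir)

    # from_pretrained builds the model and randomly initializes every weight
    # missing from the placeholder checkpoint (it warns about this).
    model = AutoModel.from_pretrained(no_weights_dir)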
5 changes: 0 additions & 5 deletions tests/configs/_timm_.yaml
@@ -1,7 +1,2 @@
 backend:
   model: timm/tiny_vit_21m_224.in1k
-
-hydra:
-  sweeper:
-    params:
-      backend.no_weights: true,false
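With the sweeper block removed, the timm test no longer sweeps backend.no_weights and only runs with weights loaded. As an aside, timm can instantiate an architecture with random weights through its own API; a minimal sketch, assuming the timm package and the model from the config above:

import timm

# Build the architecture with randomly initialized weights; nothing is downloaded.
model = timm.create_model("tiny_vit_21m_224.in1k", pretrained=False)

The backend's no_weights path, by contrast, relies on a transformers-style save_pretrained/from_pretrained round trip, which is presumably why timm models are now rejected when no_weights=True.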
