trigger super class' __post_init__
IlyasMoutawwakil committed Nov 24, 2023
1 parent 48a4410 commit 6b0c247
Showing 4 changed files with 12 additions and 3 deletions.
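
The fix addresses a standard Python dataclass pitfall: the generated __init__ calls self.__post_init__(), which resolves to the most-derived override, so a base class's __post_init__ is silently skipped whenever a subclass defines its own. Each backend config below therefore has to chain the call explicitly. A minimal, self-contained sketch of the behavior (toy class names, not taken from the repository):

from dataclasses import dataclass


@dataclass
class Base:
    threads: int = -1

    def __post_init__(self):
        # Base-class normalization that subclasses are expected to inherit.
        if self.threads == -1:
            self.threads = 8


@dataclass
class Broken(Base):
    def __post_init__(self):
        # Overrides Base.__post_init__ without chaining, so Base's logic never runs.
        pass


@dataclass
class Fixed(Base):
    def __post_init__(self):
        super().__post_init__()  # explicitly trigger the super class' __post_init__


print(Broken().threads)  # -1: base normalization was silently skipped
print(Fixed().threads)   # 8: base normalization ran via the explicit super() call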
9 changes: 6 additions & 3 deletions optimum_benchmark/backends/config.py
@@ -1,9 +1,12 @@
from abc import ABC
from logging import getLogger
from dataclasses import dataclass
from typing import Optional, TypeVar

from psutil import cpu_count

LOGGER = getLogger("backend")


@dataclass
class BackendConfig(ABC):
@@ -18,7 +21,7 @@ class BackendConfig(ABC):

    # device isolation options
    continuous_isolation: bool = True
-    isolation_check_interval: Optional[int] = None
+    isolation_check_interval: Optional[float] = None

    # clean up options
    delete_cache: bool = False
@@ -32,8 +35,8 @@ def __post_init__(self):
        if self.intra_op_num_threads == -1:
            self.intra_op_num_threads = cpu_count()

-        if self.isolation_check_interval is None:
-            self.isolation_check_interval = 1  # 1 second
+        if self.continuous_isolation and self.isolation_check_interval is None:
+            self.isolation_check_interval = 1


BackendConfigT = TypeVar("BackendConfigT", bound=BackendConfig)
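
The base-class change does two things: the annotation is widened to Optional[float] (presumably to allow sub-second intervals) and the 1-second default is now only applied when continuous_isolation is enabled, leaving the field None otherwise. A condensed, runnable sketch of just that logic, with the other BackendConfig fields omitted:

from dataclasses import dataclass
from typing import Optional


@dataclass
class IsolationOptions:  # stand-in for the relevant slice of BackendConfig
    continuous_isolation: bool = True
    isolation_check_interval: Optional[float] = None

    def __post_init__(self):
        # Only pick a default interval when isolation checking is actually enabled.
        if self.continuous_isolation and self.isolation_check_interval is None:
            self.isolation_check_interval = 1  # 1 second


print(IsolationOptions().isolation_check_interval)                              # 1
print(IsolationOptions(continuous_isolation=False).isolation_check_interval)    # None
print(IsolationOptions(isolation_check_interval=0.5).isolation_check_interval)  # 0.5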
2 changes: 2 additions & 0 deletions optimum_benchmark/backends/neural_compressor/config.py
@@ -72,6 +72,8 @@ class INCConfig(BackendConfig):
    calibration_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
+        super().__post_init__()
+
        if self.ptq_quantization:
            self.ptq_quantization_config = OmegaConf.to_object(
                OmegaConf.merge(PTQ_QUANTIZATION_CONFIG, self.ptq_quantization_config)
2 changes: 2 additions & 0 deletions optimum_benchmark/backends/onnxruntime/config.py
@@ -130,6 +130,8 @@ class ORTConfig(BackendConfig):
    peft_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
+        super().__post_init__()
+
        if not self.no_weights and not self.export and self.torch_dtype is not None:
            raise NotImplementedError("Can't convert an exported model's weights to a different dtype.")

2 changes: 2 additions & 0 deletions optimum_benchmark/backends/pytorch/config.py
@@ -72,6 +72,8 @@ class PyTorchConfig(BackendConfig):
    peft_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
+        super().__post_init__()
+
        if self.torch_compile:
            self.torch_compile_config = OmegaConf.to_object(OmegaConf.merge(COMPILE_CONFIG, self.torch_compile_config))

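The three backend configs get the same one-line fix: their __post_init__ now chains to BackendConfig.__post_init__ before doing any backend-specific work, so inherited defaults (thread counts, the isolation check interval) are resolved first. A hypothetical subclass following the same pattern; ToyBackendConfig and its extra_config field are illustrative and not part of the repository:

from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class BackendConfig:  # trimmed stand-in for optimum_benchmark.backends.config.BackendConfig
    continuous_isolation: bool = True
    isolation_check_interval: Optional[float] = None

    def __post_init__(self):
        if self.continuous_isolation and self.isolation_check_interval is None:
            self.isolation_check_interval = 1  # 1 second


@dataclass
class ToyBackendConfig(BackendConfig):
    extra_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()  # resolve inherited defaults first

        # Backend-specific handling can now rely on the base fields being populated.
        self.extra_config.setdefault("check_interval", self.isolation_check_interval)


cfg = ToyBackendConfig()
print(cfg.isolation_check_interval)        # 1
print(cfg.extra_config["check_interval"])  # 1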