Add --load-in-4bit and --load-in-8bit for HF eval backend (#332)
Allows using bitsandbytes quantization in `mergekit-evolve` when a) not
using vLLM and b) not using in-memory mode.
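
A hypothetical invocation using the new flags (config and storage paths are placeholders; `--storage-path` is the existing output-location option):

    mergekit-evolve ./evol_merge_config.yml --storage-path ./evol_storage --load-in-4bit

The two flags are mutually exclusive, and both require the Hugging Face eval backend: no vLLM and no in-memory merging, as the guards added in mergekit/scripts/evolve.py below enforce.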
cg123 authored Jan 25, 2025
1 parent 84c83f8 · commit 269eb63
Showing 4 changed files with 64 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mergekit/evo/actors.py
@@ -62,6 +62,7 @@ def __init__(
         vllm: bool = False,
         batch_size: Optional[int] = None,
         task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -72,6 +73,7 @@ def __init__(
         self.vllm = vllm
         self.batch_size = batch_size
         self.task_manager = task_manager
+        self.quantization_config = quantization_config
 
         if config.shuffle:
             monkeypatch_lmeval_shuffle()
@@ -105,6 +107,9 @@ def evaluate_genotype(
             logging.error("Model merge failed")
             return {"score": None, "results": None}
 
+        kwargs = {}
+        if self.quantization_config is not None:
+            kwargs["quantization_config"] = self.quantization_config
         logging.info(f"Model merged to {merged_path}")
         return evaluate_model(
             merged_path,
@@ -114,6 +119,7 @@
             vllm=self.vllm,
             batch_size=self.batch_size,
             task_manager=self.task_manager,
+            **kwargs,
         )
 
 
2 changes: 2 additions & 0 deletions mergekit/evo/helpers.py
@@ -67,13 +67,15 @@ def evaluate_model(
     vllm: bool,
     batch_size: Optional[int] = None,
     task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+    **kwargs,
 ) -> dict:
     # monkeypatch_tqdm()
     monkeypatch_lmeval_vllm()
     try:
         model_args = {
             "pretrained": merged_path,
             "dtype": "bfloat16",
+            **kwargs,
         }
         if vllm:
             model_args["gpu_memory_utilization"] = 0.8
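
Extra keys in `model_args` are handed to lm-eval's Hugging Face model wrapper, which forwards them on to `AutoModelForCausalLM.from_pretrained` — that is what makes the pass-through work. A minimal sketch of the equivalent direct load, assuming `transformers` and `bitsandbytes` are installed (the checkpoint path is hypothetical):

    import transformers

    # Mirrors the config built in mergekit/scripts/evolve.py below.
    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "/tmp/merged-model",     # hypothetical merged checkpoint path
        torch_dtype="bfloat16",  # matches the "dtype" entry above
        quantization_config=bnb_config,
    )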
18 changes: 18 additions & 0 deletions mergekit/evo/strategy.py
@@ -25,6 +25,7 @@
 import ray.util.queue
 import ray.util.scheduling_strategies
 import torch
+import transformers
 
 from mergekit.evo.actors import InMemoryMergeEvaluator, OnDiskMergeEvaluator
 from mergekit.evo.config import EvolMergeConfiguration
@@ -43,6 +44,7 @@ def __init__(
         batch_size: Optional[int] = None,
         task_search_path: Union[str, List[str], None] = None,
         model_storage_path: Optional[str] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -51,6 +53,7 @@
         self.batch_size = batch_size
         self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path)
         self.model_storage_path = model_storage_path
+        self.quantization_config = quantization_config
         if self.model_storage_path:
             os.makedirs(self.model_storage_path, exist_ok=True)
 
@@ -91,6 +94,7 @@ def __init__(
                 vllm=vllm,
                 batch_size=self.batch_size,
                 task_manager=self.task_manager,
+                quantization_config=self.quantization_config,
             )
             for _ in range(self.num_gpus)
         ]
@@ -120,6 +124,7 @@ def __init__(
         batch_size: Optional[int] = None,
         task_manager: Optional[lm_eval.tasks.TaskManager] = None,
         model_storage_path: Optional[str] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -130,6 +135,7 @@
         self.batch_size = batch_size
         self.task_manager = task_manager
         self.model_storage_path = model_storage_path
+        self.quantization_config = quantization_config
         self._shutdown = False
 
     async def evaluate_genotype(self, genotype: np.ndarray):
@@ -159,6 +165,9 @@ async def process_queue(self):
 
             while merged and len(evaluating) < self.num_gpus:
                 future_result, merged_path = merged.pop()
+                kwargs = {}
+                if self.quantization_config is not None:
+                    kwargs["quantization_config"] = self.quantization_config
                 evaluating[
                     evaluate_model_ray.remote(
                         merged_path,
@@ -168,6 +177,7 @@
                         vllm=self.vllm,
                         batch_size=self.batch_size,
                         task_manager=self.task_manager,
+                        **kwargs,
                     )
                 ] = future_result
 
@@ -222,6 +232,8 @@ def __init__(
             vllm=vllm,
             num_gpus=self.num_gpus,
             task_manager=self.task_manager,
+            batch_size=self.batch_size,
+            quantization_config=self.quantization_config,
         )
         self.actor.process_queue.remote()
 
@@ -242,6 +254,7 @@ def evaluate_genotype_serial(
     vllm: bool = False,
     batch_size: Optional[int] = None,
     task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+    quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
 ):
     pg = ray.util.placement_group([{"CPU": 1, "GPU": 1}], strategy="STRICT_PACK")
     strat = ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy(
@@ -252,6 +265,9 @@
     )
     if not merged_path:
         return {"score": None, "results": None}
+    kwargs = {}
+    if quantization_config is not None:
+        kwargs["quantization_config"] = quantization_config
     res = ray.get(
         evaluate_model_ray.options(scheduling_strategy=strat).remote(
             merged_path,
@@ -261,6 +277,7 @@
             vllm=vllm,
             batch_size=batch_size,
             task_manager=task_manager,
+            **kwargs,
         )
     )
     ray.util.remove_placement_group(pg)
@@ -292,6 +309,7 @@ def evaluate_genotypes(self, genotypes: List[np.ndarray]) -> List[dict]:
                 vllm=self.vllm,
                 batch_size=self.batch_size,
                 task_manager=self.task_manager,
+                quantization_config=self.quantization_config,
             )
             for x in genotypes
         ]
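
Every strategy threads the config through the same way: store it once, then forward `quantization_config` only when it is set, so non-quantized runs see identical behavior (the @@ -222 hunk also starts forwarding `batch_size`, which this call previously dropped). A condensed, hypothetical sketch of the pattern — `SomeStrategy` and the stub `evaluate_model` stand in for the real classes:

    from typing import Optional
    import transformers

    def evaluate_model(merged_path: str, **kwargs) -> dict:
        # Stub standing in for mergekit.evo.helpers.evaluate_model.
        return {"path": merged_path, "kwargs": sorted(kwargs)}

    class SomeStrategy:
        def __init__(
            self,
            quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
        ):
            # Stored once; forwarded verbatim to every evaluation this strategy runs.
            self.quantization_config = quantization_config

        def evaluate(self, merged_path: str) -> dict:
            kwargs = {}
            if self.quantization_config is not None:
                # Pass the key only when set, so evaluate_model's defaults
                # apply unchanged in the non-quantized case.
                kwargs["quantization_config"] = self.quantization_config
            return evaluate_model(merged_path, **kwargs)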
38 changes: 38 additions & 0 deletions mergekit/scripts/evolve.py
@@ -114,6 +114,18 @@
     default=None,
     help="Maximum time to run the optimization in seconds",
 )
+@click.option(
+    "--load-in-8bit",
+    is_flag=True,
+    default=False,
+    help="Evaluate models at 8-bit precision",
+)
+@click.option(
+    "--load-in-4bit",
+    is_flag=True,
+    default=False,
+    help="Evaluate models at 4-bit precision",
+)
 @click.option(
     "--force-population-size",
     type=int,
@@ -142,6 +154,8 @@ def main(
     save_final_model: bool,
     reshard: bool,
     timeout: Optional[float],
+    load_in_8bit: bool,
+    load_in_4bit: bool,
     force_population_size: Optional[int],
 ):
     config = EvolMergeConfiguration.model_validate(
@@ -150,6 +164,29 @@
 
     check_for_naughty_config(config, allow=allow_benchmark_tasks)
 
+    if load_in_4bit and load_in_8bit:
+        raise ValueError("Cannot load models in both 4-bit and 8-bit")
+
+    if load_in_4bit or load_in_8bit:
+        if vllm:
+            raise ValueError("Cannot use vLLM with 4-bit or 8-bit models")
+        if in_memory:
+            raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models")
+        try:
+            import bitsandbytes
+        except ImportError:
+            raise RuntimeError("bitsandbytes is not installed")
+
+        bnb_config = transformers.BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            load_in_4bit=load_in_4bit,
+            bnb_4bit_compute_dtype="bfloat16",
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+        )
+    else:
+        bnb_config = None
+
     if use_wandb:
         if not wandb:
             raise RuntimeError("wandb is not installed")
@@ -235,6 +272,7 @@ def main(
         model_storage_path=os.path.join(storage_path, "merged"),
         batch_size=batch_size,
         task_search_path=task_search_path,
+        quantization_config=bnb_config,
     )
 
     x0 = genome.initial_genotype(random=config.random_init).view(-1).numpy()
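
The guards above fail fast, before any merge work is scheduled. Hypothetical invocations of the unsupported combinations and the errors they raise (flag spellings for vLLM and in-memory mode assumed from `main()`'s parameters):

    mergekit-evolve cfg.yml --vllm --load-in-4bit           # ValueError: Cannot use vLLM with 4-bit or 8-bit models
    mergekit-evolve cfg.yml --in-memory --load-in-8bit      # ValueError: Cannot use in-memory mode with 4-bit or 8-bit models
    mergekit-evolve cfg.yml --load-in-4bit --load-in-8bit   # ValueError: Cannot load models in both 4-bit and 8-bit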
