Faster DeepSpeed engine initialization #140

Merged (1 commit) on Feb 27, 2024
Makefile (2 additions, 3 deletions)
@@ -68,7 +68,7 @@ run_docker_cuda:
     --rm \
     --pid host \
     --shm-size 64G \
-    --gpus '"device=0,1"' \
+    --gpus all \
     --entrypoint /bin/bash \
     --volume $(PWD):/workspace \
     --workdir /workspace \
@@ -81,8 +81,7 @@ run_docker_rocm:
     --pid host \
     --shm-size 64G \
     --device /dev/kfd \
-    --device /dev/dri/renderD128 \
-    --device /dev/dri/renderD129 \
+    --device /dev/dri/ \
     --entrypoint /bin/bash \
     --volume $(PWD):/workspace \
     --workdir /workspace \
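
Both containers now expose every device instead of a hard-coded pair: `--gpus all` for NVIDIA GPUs, and the whole `/dev/dri/` directory (which holds the per-GPU render nodes) for ROCm. A quick sanity check from inside either container confirms what PyTorch actually sees; a minimal sketch, relying on the fact that ROCm builds of PyTorch report devices through the same `torch.cuda` namespace:

```python
import torch

# Works on both CUDA and ROCm builds of PyTorch, which share the torch.cuda API.
print(f"visible devices: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"  device {i}: {torch.cuda.get_device_name(i)}")
```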
optimum_benchmark/backends/pytorch/backend.py (22 additions, 11 deletions)
@@ -142,18 +142,33 @@ def load_model_from_pretrained(self) -> None:
             LOGGER.info(f"\t+ Moving pipeline to device: {self.config.device}")
             self.pretrained_model.to(self.config.device)
         elif self.config.deepspeed_inference:
-            with torch.device("cpu"):
-                LOGGER.info("\t+ Loading DeepSpeed model directly on CPU to avoid OOM")
-                self.pretrained_model = self.automodel_class.from_pretrained(
-                    pretrained_model_name_or_path=self.config.model, **self.config.hub_kwargs, **self.automodel_kwargs
-                )
+            if self.config.no_weights:
+                with torch.device("meta"):
+                    LOGGER.info("\t+ Loading model on meta device for fast initialization")
+                    self.pretrained_model = self.automodel_class.from_pretrained(
+                        pretrained_model_name_or_path=self.config.model,
+                        **self.config.hub_kwargs,
+                        **self.automodel_kwargs,
+                    )
+                LOGGER.info("\t+ Materializing model on CPU")
+                self.pretrained_model.to_empty(device="cpu")
+                LOGGER.info("\t+ Tying model weights")
+                self.pretrained_model.tie_weights()
+            else:
+                LOGGER.info("\t+ Loading model on cpu to avoid OOM")
+                with torch.device("cpu"):
+                    self.pretrained_model = self.automodel_class.from_pretrained(
+                        pretrained_model_name_or_path=self.config.model,
+                        **self.config.hub_kwargs,
+                        **self.automodel_kwargs,
+                    )
 
             torch.distributed.barrier()  # better safe than hanging
-            LOGGER.info("\t+ Initializing DeepSpeed Inference")
+            LOGGER.info("\t+ Initializing DeepSpeed Inference Engine")
             self.pretrained_model = init_inference(self.pretrained_model, config=self.config.deepspeed_inference_config)
             torch.distributed.barrier()  # better safe than hanging
         elif self.is_quantized:
-            # we can't use device context manager since the model is quantized
+            # we can't use device context manager on quantized models
             LOGGER.info("\t+ Loading Quantized model")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 pretrained_model_name_or_path=self.config.model,
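
The speedup in the `no_weights` branch comes from PyTorch's meta device: tensors created there carry only shape and dtype metadata, no storage, so building even a very large model is nearly free; `to_empty()` then allocates real (uninitialized) CPU buffers. A minimal standalone sketch of the same pattern, assuming a Hugging Face `transformers` model; the model id `gpt2` is illustrative, and `from_config` is used here instead of `from_pretrained` to avoid touching checkpoint weights:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")  # illustrative model id

# Instantiate on the meta device: no memory is allocated and no random
# weight initialization runs, so construction is almost instant.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# Give every parameter and buffer real (uninitialized) CPU storage.
model.to_empty(device="cpu")

# to_empty() allocates each tensor separately, which silently breaks weight
# tying (e.g. input embeddings shared with the LM head); re-tie explicitly,
# which is exactly why the diff calls tie_weights() right after to_empty().
model.tie_weights()
```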
@@ -218,10 +233,6 @@ def load_model_with_no_weights(self) -> None:
         self.load_model_from_pretrained()
         self.config.model = original_model
 
-        # dunno how necessary this is
-        LOGGER.info("\t+ Tying model weights")
-        self.pretrained_model.tie_weights()
-
     def process_quantization_config(self) -> None:
         if self.is_gptq_quantized:
             LOGGER.info("\t+ Processing GPTQ config")