diff --git a/src/nodes/pipeline_loader.py b/src/nodes/pipeline_loader.py index f7e8727..089e2a1 100644 --- a/src/nodes/pipeline_loader.py +++ b/src/nodes/pipeline_loader.py @@ -47,37 +47,37 @@ def load_pipeline(self, weight_dtype): WEIGHTS_PATH, subfolder="vae", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() unet = UNet2DConditionModel.from_pretrained( WEIGHTS_PATH, subfolder="unet", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() image_encoder = CLIPVisionModelWithProjection.from_pretrained( WEIGHTS_PATH, subfolder="image_encoder", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() unet_encoder = UNet2DConditionModel_ref.from_pretrained( WEIGHTS_PATH, subfolder="unet_encoder", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() text_encoder_one = CLIPTextModel.from_pretrained( WEIGHTS_PATH, subfolder="text_encoder", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() text_encoder_two = CLIPTextModelWithProjection.from_pretrained( WEIGHTS_PATH, subfolder="text_encoder_2", torch_dtype=weight_dtype - ).requires_grad_(False).eval().to(DEVICE) + ).requires_grad_(False).eval() tokenizer_one = AutoTokenizer.from_pretrained( WEIGHTS_PATH, @@ -106,8 +106,8 @@ def load_pipeline(self, weight_dtype): image_encoder=image_encoder, torch_dtype=weight_dtype, ) - pipe.unet_encoder = unet_encoder - pipe = pipe.to(DEVICE) pipe.weight_dtype = weight_dtype - + pipe.unet_encoder = unet_encoder + pipe.enable_sequential_cpu_offload() + pipe.unet_encoder.to(DEVICE) return (pipe, ) \ No newline at end of file