add utils
avdravid authored Jun 13, 2024
1 parent d72f365 commit 9ea2047
Showing 1 changed file with 161 additions and 0 deletions: utils.py
@@ -0,0 +1,161 @@
import torch
import torchvision
import os
import shutil
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from lora_w2w import LoRAw2w
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from safetensors.torch import save_file
from PIL import Image

import warnings
warnings.filterwarnings("ignore")



######## Basic utilities

### load base models
def load_models(device):
    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"

    revision = None
    weight_dtype = torch.bfloat16

    # Load a full pipeline only to grab its scheduler, then discard it.
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    noise_scheduler = pipe.scheduler
    del pipe

    # Load the tokenizer and the individual model components.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    # Freeze all base components; only the LoRA network is trained elsewhere.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)

    return unet, vae, text_encoder, tokenizer, noise_scheduler
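

### Example usage (a minimal sketch; assumes a CUDA device is available
### and the checkpoint named above is reachable):
#
#   unet, vae, text_encoder, tokenizer, noise_scheduler = load_models("cuda")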



### basic inference to generate images conditioned on text prompts
@torch.no_grad()
def inference(network, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, seed, generator, device):
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the conditional prompt.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Encode the unconditional (negative) prompt for classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate latents so one UNet pass covers both prompts.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

        # Classifier-free guidance: push the prediction away from the unconditional branch.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the SD v1 VAE scaling factor and decode to pixel space in [0, 1].
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)

    return image
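

### Example usage (a hedged sketch; the prompts and step count are
### illustrative, and `network` is assumed to be a trained LoRAw2w instance):
#
#   generator = torch.Generator(device="cuda")
#   image = inference(network, unet, vae, text_encoder, tokenizer,
#                     prompt="a photo of a person", negative_prompt="blurry",
#                     guidance_scale=7.5, noise_scheduler=noise_scheduler,
#                     ddim_steps=50, seed=0, generator=generator, device="cuda")
#   transforms.ToPILImage()(image[0].float().cpu()).save("sample.png")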



### save model in w2w space (principal component representation)
def save_model_w2w(network, path):
    proj = network.proj.clone().detach().float()

    if not os.path.exists(path):
        os.makedirs(path)

    torch.save(proj, os.path.join(path, "w2wmodel.pt"))
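
### The saved projection can be reloaded later (a minimal sketch):
#
#   proj = torch.load(os.path.join(path, "w2wmodel.pt"))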


### save model in format compatible with Diffusers
def save_model_for_diffusers(network, std, mean, v, weight_dimensions, path):
    proj = network.proj.clone().detach()
    # Map the w2w-space coefficients back to flattened LoRA weights.
    unproj = torch.matmul(proj, v[:, :].T) * std + mean

    # Slice the flat vector back into per-layer tensors.
    final_weights0 = {}
    counter = 0
    for key in weight_dimensions.keys():
        final_weights0[key] = unproj[0, counter:counter + weight_dimensions[key][0][0]].unflatten(0, weight_dimensions[key][1])
        counter += weight_dimensions[key][0][0]

    # Rename keys to be compatible with Diffusers/PEFT.
    for key in list(final_weights0.keys()):
        new_key = (
            key.replace("lora_unet_", "base_model.model.")
            .replace("A", "down")
            .replace("B", "up")
            .replace("weight", "identity1.weight")
            .replace("_lora", ".lora")
            .replace("lora_down", "lora_A")
            .replace("lora_up", "lora_B")
        )
        final_weights0[new_key] = final_weights0.pop(key)

    final_weights = {key: final_weights0[key] for key in sorted(final_weights0.keys())}

    os.makedirs(os.path.join(path, "unet"), exist_ok=True)

    # Add config for PeftConfig.
    shutil.copyfile("../files/adapter_config.json", os.path.join(path, "unet/adapter_config.json"))

    save_file(final_weights, os.path.join(path, "unet/adapter_model.safetensors"))
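

### Example of inspecting the saved adapter file (a hedged sketch;
### `path` is whatever directory was passed to the function above):
#
#   from safetensors.torch import load_file
#   weights = load_file(os.path.join(path, "unet/adapter_model.safetensors"))
#   for name in sorted(weights)[:5]:
#       print(name, tuple(weights[name].shape))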





