diff --git a/sampling.py b/sampling.py new file mode 100644 index 0000000..f07341a --- /dev/null +++ b/sampling.py @@ -0,0 +1,56 @@ +import torch +import torchvision +import os +import gc +import tqdm +import matplotlib.pyplot as plt +import torchvision.transforms as transforms +from transformers import CLIPTextModel +from peft import PeftModel, LoraConfig +from lora_w2w import LoRAw2w +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler +from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict +from transformers import AutoTokenizer, PretrainedConfig +from PIL import Image +import warnings +warnings.filterwarnings("ignore") +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, + PNDMScheduler, + StableDiffusionPipeline +) + + + +######## Sampling utilities + + +def sample_weights(unet, proj, mean, std, v, device, factor = 1.0): + # get mean and standard deviation for each principal component + m = torch.mean(proj, 0) + standev = torch.std(proj, 0) + del proj + torch.cuda.empty_cache() + # sample + sample = torch.zeros([1, 1000]).to(device) + for i in range(1000): + sample[0, i] = torch.normal(m[i], factor*standev[i], (1,1)) + + # load weights into network + network = LoRAw2w( sample, mean, std, v, + unet, + rank=1, + multiplier=1.0, + alpha=27.0, + train_method="xattn-strict" + ).to(device, torch.bfloat16) + + return network + + + +