-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDPMs_Training.py
95 lines (77 loc) · 3.53 KB
/
DPMs_Training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Training
import random
from dataclasses import dataclass

import torch
import torch.functional as F  # NOTE: immediately shadowed by torch.nn.functional below; kept as-is
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment
from diffusers import DPMSolverMultistepScheduler, get_cosine_schedule_with_warmup, DPMSolverSinglestepScheduler
from tqdm import tqdm
from tqdm import tqdm_notebook

from utils import Unet_Conditional, US8K
@dataclass
class TrainingConfig:
    """Hyper-parameters for diffusion-model training and sampling.

    Fields carry type annotations so that ``@dataclass`` actually registers
    them: without annotations the decorator is a no-op (no generated
    ``__init__``/``__repr__``/``__eq__``) and the class is a plain namespace.
    Attribute access (``config.image_size`` etc.) is unchanged for callers,
    and per-run overrides via keyword arguments now work too.
    """
    image_size: int = 224  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 500
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'ddpm-butterflies-128'  # the model name locally and on the HF Hub
    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0
# ---- Experiment setup (module-level state consumed by the training loop) ----
config = TrainingConfig()
# Two variants of the UrbanSound8K spectrogram dataset: raw preprocessed and
# an augmented version.  NOTE(review): transform_size=128 differs from
# config.image_size=224 — confirm which resolution the model actually trains on.
dataset1=US8K(transform_size=128,train=True,root="Preprocessing_us8k")
dataset2=US8K(transform_size=128,train=True,root="Preprocessing_us8k_augmentation")
# The 10 UrbanSound8K class names (list index == class id used for conditioning).
urbansound8k=["air_conditioner","car_horn","children_playing","dog_bark","drilling","engine_idling","gun_shot","jackhammer","siren","street_music"]
# One scheduler serves both roles: add_noise() during training (1000 train
# timesteps) and the 20-step reverse process during sampling.
noise_scheduler=DPMSolverMultistepScheduler(num_train_timesteps=1000)
noise_scheduler.set_timesteps(num_inference_steps=20)
loss=nn.MSELoss()
device='cuda'
# Build the conditional U-Net, then immediately replace it with a checkpoint
# to resume training.  NOTE(review): torch.load discards the freshly built
# model above, and unpickling a full model requires the checkpoint's class
# definitions to be importable — consider saving/loading state_dict instead.
model = Unet_Conditional(labels_dim=10,dim=64).to(device)
model=torch.load('./Diffusion_us8k_dim64/Diffusion_us8k_dim64_90augmentation_175.pt')
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Trainable-parameter count, printed as a quick sanity check.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
def forward(model, scheduler, config, batch_size=16, sample_class=1, device='cuda'):
    """Sample a batch of images from the diffusion model.

    Starts from pure Gaussian noise and iteratively denoises it with the
    scheduler's reverse process, conditioning every step on ``sample_class``.

    Args:
        model: conditional noise-prediction network, called as
            ``model(x, t=..., label=...)``.  NOTE(review): the training loop
            in this file calls the model with ``time=`` instead of ``t=`` —
            confirm which keyword Unet_Conditional actually accepts.
        scheduler: diffusers scheduler whose ``timesteps`` have been set
            (e.g. via ``set_timesteps``); must provide ``step``.
        config: object exposing ``image_size``.
        batch_size: number of images to generate.
        sample_class: class id applied to every sample in the batch.
        device: torch device string (defaults to 'cuda').

    Returns:
        Tensor of shape ``(batch_size, 3, image_size, image_size)``.
    """
    sample = torch.randn(batch_size, 3, config.image_size, config.image_size).to(device)
    # The conditioning labels are loop-invariant: build them once, directly on
    # the target device, instead of re-allocating on CPU every timestep.
    labels = sample_class * torch.ones(batch_size, dtype=torch.long, device=device)
    for t in tqdm(scheduler.timesteps):
        with torch.no_grad():
            # Broadcast the scalar timestep to a per-sample batch of longs.
            t_batch = t * torch.ones(batch_size, dtype=torch.long, device=device)
            residual = model(sample, t=t_batch, label=labels)
        sample = scheduler.step(residual, t, sample).prev_sample
    return sample
# ---- Training loop ----
# NOTE(review): 'epoches' overrides config.num_epochs; also DataLoader is used
# below but never imported in this file — it needs
# `from torch.utils.data import DataLoader` at the top.
epoches=3500
for epoch in tqdm_notebook(range(0,epoches)):
    # Roughly 1-in-21 epochs train on the raw dataset; the rest use the
    # augmented one.
    if(random.randint(0,20)==0):
        dataloader=DataLoader(dataset1, batch_size=30, shuffle=True)
    else:
        dataloader=DataLoader(dataset2, batch_size=30, shuffle=True)
    for data,label in tqdm_notebook(dataloader):
        #print(label.shape)
        data=data.to(device)
        # Commented-out RandAugment experiment left by the author.
        # data=255*data
        # data=torch.tensor(data,dtype=torch.uint8)
        # data=augmen(data)
        # data=data.float()
        # data=data/255
        # Labels arrive one-hot (argmax over dim=1 recovers the class index) —
        # presumably from US8K; verify against the dataset implementation.
        label=torch.argmax(label,dim=1).to(device).long()
        optimizer.zero_grad()
        # DDPM-style objective: noise the batch at uniformly sampled timesteps
        # and train the model to predict that noise (MSE).
        # NOTE(review): newer diffusers versions expose this value as
        # noise_scheduler.config.num_train_timesteps.
        noise=torch.randn_like(data)
        timesteps=torch.randint(0,noise_scheduler.num_train_timesteps,(data.shape[0],)).to(device)
        noisy_image=noise_scheduler.add_noise(data,noise,timesteps)
        # NOTE(review): keyword here is `time=` while forward() above uses
        # `t=` — confirm Unet_Conditional's signature; one of the two is
        # likely wrong.
        noise_pred=model(noisy_image,time=timesteps,label=label.long())
        loss_val=loss(noise_pred,noise)
        loss_val.backward()
        optimizer.step()
    # Every 5 epochs: log the last batch's loss and checkpoint the full
    # pickled model (filename pattern matches the checkpoint loaded above).
    if(epoch%5==0):
        print("Epoch: ",epoch,"Loss: ",loss_val.item())
        torch.save(model,f'./Diffusion_us8k_dim64/Diffusion_us8k_dim64_90augmentation_{epoch}.pt')