style: autoformat code #36

Open · wants to merge 1 commit into base: main
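The diff below is consistent with running an automatic formatter over the package; the PR does not name the tool, but the line wrapping and double-quote normalization match Black's defaults. A minimal sketch of reproducing this kind of pass, assuming Black is the formatter in use:

```python
# Hypothetical reproduction of the reformat; Black is an assumption, the PR does not name the tool.
import pathlib

import black

for path in pathlib.Path("garment_adapter").glob("*.py"):
    source = path.read_text()
    # format_str applies Black's default mode (88-character lines, double quotes).
    formatted = black.format_str(source, mode=black.Mode())
    if formatted != source:
        path.write_text(formatted)
```

On the command line, `black garment_adapter/` would produce the same result under that assumption.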
394 changes: 262 additions & 132 deletions garment_adapter/attention_processor.py

Large diffs are not rendered by default.

216 changes: 150 additions & 66 deletions garment_adapter/garment_diffusion.py
@@ -8,21 +8,31 @@
if is_torch2_available():
from .attention_processor import REFAttnProcessor2_0 as REFAttnProcessor
from .attention_processor import AttnProcessor2_0 as AttnProcessor
from .attention_processor import REFAnimateDiffAttnProcessor2_0 as REFAnimateDiffAttnProcessor
from .attention_processor import (
REFAnimateDiffAttnProcessor2_0 as REFAnimateDiffAttnProcessor,
)
else:
from .attention_processor import REFAttnProcessor, AttnProcessor


class ClothAdapter:
def __init__(self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_model=True):
def __init__(
self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_model=True
):
self.enable_cloth_guidance = enable_cloth_guidance
self.device = device
self.pipe = sd_pipe.to(self.device)
self.set_adapter(self.pipe.unet, "write")

ref_unet = copy.deepcopy(sd_pipe.unet)
if ref_unet.config.in_channels == 9:
ref_unet.conv_in = torch.nn.Conv2d(4, 320, ref_unet.conv_in.kernel_size, ref_unet.conv_in.stride, ref_unet.conv_in.padding)
ref_unet.conv_in = torch.nn.Conv2d(
4,
320,
ref_unet.conv_in.kernel_size,
ref_unet.conv_in.stride,
ref_unet.conv_in.padding,
)
ref_unet.register_to_config(in_channels=4)
state_dict = {}
with safe_open(ref_path, framework="pt", device="cpu") as f:
@@ -36,8 +46,10 @@ def __init__(self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_mod
self.set_seg_model()
self.attn_store = {}

def set_seg_model(self, ):
checkpoint_path = 'checkpoints/cloth_segm.pth'
def set_seg_model(
self,
):
checkpoint_path = "checkpoints/cloth_segm.pth"
self.seg_net = load_seg_model(checkpoint_path, device=self.device)

def set_adapter(self, unet, type):
@@ -50,23 +62,25 @@ def set_adapter(self, unet, type):
unet.set_attn_processor(attn_procs)

def generate(
self,
cloth_image,
cloth_mask_image=None,
prompt=None,
a_prompt="best quality, high quality",
num_images_per_prompt=4,
negative_prompt=None,
seed=-1,
guidance_scale=7.5,
cloth_guidance_scale=2.5,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
self,
cloth_image,
cloth_mask_image=None,
prompt=None,
a_prompt="best quality, high quality",
num_images_per_prompt=4,
negative_prompt=None,
seed=-1,
guidance_scale=7.5,
cloth_guidance_scale=2.5,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
):
if cloth_mask_image is None:
cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
cloth_mask_image = generate_mask(
cloth_image, net=self.seg_net, device=self.device
)

cloth = prepare_image(cloth_image, height, width)
cloth_mask = prepare_mask(cloth_mask_image, height, width)
@@ -76,7 +90,9 @@ def generate(
prompt = "a photography of a model"
prompt = prompt + ", " + a_prompt
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
negative_prompt = (
"monochrome, lowres, bad anatomy, worst quality, low quality"
)

with torch.inference_mode():
prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
@@ -86,11 +102,26 @@
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})
prompt_embeds_null = self.pipe.encode_prompt(
[""],
device=self.device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=False,
)[0]
cloth_embeds = (
self.pipe.vae.encode(cloth).latent_dist.mode()
* self.pipe.vae.config.scaling_factor
)
self.ref_unet(
torch.cat([cloth_embeds] * num_images_per_prompt),
0,
prompt_embeds_null,
cross_attention_kwargs={"attn_store": self.attn_store},
)

generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = (
torch.Generator(self.device).manual_seed(seed) if seed is not None else None
)
if self.enable_cloth_guidance:
images = self.pipe(
prompt_embeds=prompt_embeds,
@@ -101,7 +132,11 @@
generator=generator,
height=height,
width=width,
cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
cross_attention_kwargs={
"attn_store": self.attn_store,
"do_classifier_free_guidance": guidance_scale > 1.0,
"enable_cloth_guidance": self.enable_cloth_guidance,
},
**kwargs,
).images
else:
@@ -113,45 +148,70 @@ def generate(
generator=generator,
height=height,
width=width,
cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
cross_attention_kwargs={
"attn_store": self.attn_store,
"do_classifier_free_guidance": guidance_scale > 1.0,
"enable_cloth_guidance": self.enable_cloth_guidance,
},
**kwargs,
).images

return images, cloth_mask_image

def generate_inpainting(
self,
cloth_image,
cloth_mask_image=None,
num_images_per_prompt=4,
seed=-1,
cloth_guidance_scale=2.5,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
self,
cloth_image,
cloth_mask_image=None,
num_images_per_prompt=4,
seed=-1,
cloth_guidance_scale=2.5,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
):
if cloth_mask_image is None:
cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
cloth_mask_image = generate_mask(
cloth_image, net=self.seg_net, device=self.device
)

cloth = prepare_image(cloth_image, height, width)
cloth_mask = prepare_mask(cloth_mask_image, height, width)
cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

with torch.inference_mode():
prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})
prompt_embeds_null = self.pipe.encode_prompt(
[""],
device=self.device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=False,
)[0]
cloth_embeds = (
self.pipe.vae.encode(cloth).latent_dist.mode()
* self.pipe.vae.config.scaling_factor
)
self.ref_unet(
torch.cat([cloth_embeds] * num_images_per_prompt),
0,
prompt_embeds_null,
cross_attention_kwargs={"attn_store": self.attn_store},
)

generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = (
torch.Generator(self.device).manual_seed(seed) if seed is not None else None
)
images = self.pipe(
prompt_embeds=prompt_embeds_null,
cloth_guidance_scale=cloth_guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
height=height,
width=width,
cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": cloth_guidance_scale > 1.0, "enable_cloth_guidance": False},
cross_attention_kwargs={
"attn_store": self.attn_store,
"do_classifier_free_guidance": cloth_guidance_scale > 1.0,
"enable_cloth_guidance": False,
},
**kwargs,
).images

@@ -164,7 +224,9 @@ def __init__(self, sd_pipe, pipe_path, ref_path, device, set_seg_model=True):
self.pipe = sd_pipe.to(self.device)
self.set_adapter(self.pipe.unet, "write")

ref_unet = UNet2DConditionModel.from_pretrained(pipe_path, subfolder='unet', torch_dtype=sd_pipe.dtype)
ref_unet = UNet2DConditionModel.from_pretrained(
pipe_path, subfolder="unet", torch_dtype=sd_pipe.dtype
)
state_dict = {}
with safe_open(ref_path, framework="pt", device="cpu") as f:
for key in f.keys():
@@ -177,8 +239,10 @@ def __init__(self, sd_pipe, pipe_path, ref_path, device, set_seg_model=True):
self.set_seg_model()
self.attn_store = {}

def set_seg_model(self, ):
checkpoint_path = 'checkpoints/cloth_segm.pth'
def set_seg_model(
self,
):
checkpoint_path = "checkpoints/cloth_segm.pth"
self.seg_net = load_seg_model(checkpoint_path, device=self.device)

def set_adapter(self, unet, type):
@@ -191,23 +255,25 @@ def set_adapter(self, unet, type):
unet.set_attn_processor(attn_procs)

def generate(
self,
cloth_image,
cloth_mask_image=None,
prompt=None,
a_prompt="best quality, high quality",
num_images_per_prompt=4,
negative_prompt=None,
seed=-1,
guidance_scale=7.5,
cloth_guidance_scale=3.,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
self,
cloth_image,
cloth_mask_image=None,
prompt=None,
a_prompt="best quality, high quality",
num_images_per_prompt=4,
negative_prompt=None,
seed=-1,
guidance_scale=7.5,
cloth_guidance_scale=3.0,
num_inference_steps=20,
height=512,
width=384,
**kwargs,
):
if cloth_mask_image is None:
cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
cloth_mask_image = generate_mask(
cloth_image, net=self.seg_net, device=self.device
)

cloth = prepare_image(cloth_image, height, width)
cloth_mask = prepare_mask(cloth_mask_image, height, width)
@@ -227,11 +293,26 @@ def generate(
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})
prompt_embeds_null = self.pipe.encode_prompt(
[""],
device=self.device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=False,
)[0]
cloth_embeds = (
self.pipe.vae.encode(cloth).latent_dist.mode()
* self.pipe.vae.config.scaling_factor
)
self.ref_unet(
torch.cat([cloth_embeds] * num_images_per_prompt),
0,
prompt_embeds_null,
cross_attention_kwargs={"attn_store": self.attn_store},
)

generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
generator = (
torch.Generator(self.device).manual_seed(seed) if seed is not None else None
)
frames = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -241,7 +322,10 @@ def generate(
generator=generator,
height=height,
width=width,
cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0},
cross_attention_kwargs={
"attn_store": self.attn_store,
"do_classifier_free_guidance": guidance_scale > 1.0,
},
**kwargs,
).frames

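For reference, a usage sketch of the `ClothAdapter` interface touched by this diff, based only on the signatures visible above. The base pipeline class, model id, and checkpoint path are assumptions rather than anything stated in the PR; the repository may provide its own pipeline that accepts `cloth_guidance_scale`.

```python
# Usage sketch based on the signatures in this diff; the pipeline class, model id,
# and file paths below are assumptions, not taken from the PR.
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline  # assumed; the repo may ship a custom pipeline

from garment_adapter.garment_diffusion import ClothAdapter

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed base model
    torch_dtype=torch.float16,
)
adapter = ClothAdapter(
    sd_pipe=pipe,
    ref_path="checkpoints/oms_diffusion.safetensors",  # assumed checkpoint path
    device="cuda",
    enable_cloth_guidance=True,
)

cloth = Image.open("cloth.jpg").convert("RGB")
# cloth_mask_image is optional; when omitted, generate() segments the garment
# with the cloth_segm.pth model loaded in set_seg_model().
images, cloth_mask = adapter.generate(
    cloth,
    prompt="a photography of a model",
    num_images_per_prompt=4,
    guidance_scale=7.5,
    cloth_guidance_scale=2.5,
    height=512,
    width=384,
)
images[0].save("result.png")
```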