From 41757e45643bcb5fb79b41e6deffa94c054d2a44 Mon Sep 17 00:00:00 2001 From: tanghengjian Date: Tue, 16 Apr 2024 14:27:15 +0800 Subject: [PATCH] Code optimization --- ip_adapter/attention_processor.py | 12 ++++++------ ip_adapter/ip_adapter.py | 22 +++++++--------------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index 6d30ffa..4907e56 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -91,14 +91,14 @@ class IPAttnProcessor(nn.Module): The context length of the image features. """ - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, selected=True): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens - self.skip = skip + self.selected = selected self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) @@ -155,7 +155,7 @@ def __call__( hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) - if not self.skip: + if self.selected: # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) @@ -289,7 +289,7 @@ class IPAttnProcessor2_0(torch.nn.Module): The context length of the image features. """ - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, selected=True): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): @@ -299,7 +299,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens= self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens - self.skip = skip + self.selected = selected self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) @@ -370,7 +370,7 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - if not self.skip: + if self.selected: # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index 329a8e6..fd47e5e 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -113,21 +113,13 @@ def set_ip_adapter(self): if block_name in name: selected = True break - if selected: - attn_procs[name] = IPAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=self.num_tokens, - ).to(self.device, dtype=torch.float16) - else: - attn_procs[name] = IPAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=self.num_tokens, - skip=True - ).to(self.device, dtype=torch.float16) + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + selected=selected + ).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) if hasattr(self.pipe, "controlnet"): if isinstance(self.pipe.controlnet, MultiControlNetModel):