From 675956010ffc30ceddf7bda9d071a8f6c1eee725 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 30 Jan 2025 18:44:26 +0300 Subject: [PATCH] Multiply the slice rate for dyn bmm by 4 --- installer.py | 2 +- modules/sd_hijack_dynamic_atten.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/installer.py b/installer.py index 6c850a877..9dacf4682 100644 --- a/installer.py +++ b/installer.py @@ -707,7 +707,7 @@ def install_ipex(torch_command): if args.use_nightly: torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu') if os.environ.get('TRITON_COMMAND', None) is None: - os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-rocm as a dependency instead + os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-xpu as a dependency instead else: if "linux" in sys.platform: # default to US server. If The China server is needed, change .../release-whl/stable/xpu/us/ to .../release-whl/stable/xpu/cn/ diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py index bba5cbc2a..6c3e69e3a 100644 --- a/modules/sd_hijack_dynamic_atten.py +++ b/modules/sd_hijack_dynamic_atten.py @@ -116,7 +116,7 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop @cache -def find_bmm_slice_sizes(query_shape, query_element_size, slice_rate=4, trigger_rate=6): +def find_bmm_slice_sizes(query_shape, query_element_size, slice_rate=2, trigger_rate=4): if len(query_shape) == 3: batch_size_attention, query_tokens, shape_three = query_shape shape_four = 1 @@ -197,7 +197,7 @@ def __call__(self, attn, hidden_states: torch.Tensor, encoder_hidden_states=None # Slicing parts: batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate, trigger_rate=shared.opts.dynamic_attention_trigger_rate) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate*4, trigger_rate=shared.opts.dynamic_attention_trigger_rate*4) if do_split: for i in range(batch_size_attention // split_slice_size):