Skip to content

Commit

Permalink
Multiply the slice rate for dyn bmm by 4
Browse files Browse the repository at this point in the history
  • Loading branch information
Disty0 committed Jan 30, 2025
1 parent 115aac6 commit 6759560
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def install_ipex(torch_command):
if args.use_nightly:
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
if os.environ.get('TRITON_COMMAND', None) is None:
os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-rocm as a dependency instead
os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-xpu as a dependency instead
else:
if "linux" in sys.platform:
# default to US server. If The China server is needed, change .../release-whl/stable/xpu/us/ to .../release-whl/stable/xpu/cn/
Expand Down
4 changes: 2 additions & 2 deletions modules/sd_hijack_dynamic_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop


@cache
def find_bmm_slice_sizes(query_shape, query_element_size, slice_rate=4, trigger_rate=6):
def find_bmm_slice_sizes(query_shape, query_element_size, slice_rate=2, trigger_rate=4):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
Expand Down Expand Up @@ -197,7 +197,7 @@ def __call__(self, attn, hidden_states: torch.Tensor, encoder_hidden_states=None
# Slicing parts:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate, trigger_rate=shared.opts.dynamic_attention_trigger_rate)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(query.shape, query.element_size(), slice_rate=shared.opts.dynamic_attention_slice_rate*4, trigger_rate=shared.opts.dynamic_attention_trigger_rate*4)

if do_split:
for i in range(batch_size_attention // split_slice_size):
Expand Down

0 comments on commit 6759560

Please sign in to comment.