[Coverage] RuntimeError (torch.ops.aten.expand.default) during matrix multiplication in SwinTransformer block #3257

Open
Tracked by #3179
chohk88 opened this issue Oct 22, 2024 · 0 comments

When compiling a SwinTransformer block with torch-tensorrt, a RuntimeError is raised for torch.ops.aten.expand.default during the matrix multiplication in the attention mechanism. Specifically, the failure occurs while computing the attention scores with q @ k.transpose(-2, -1) in the forward pass.
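
For context, the aten.expand.default node appears to come from the standard decomposition of this matmul into expand/view/bmm during tracing, and the same computation runs without error in eager mode. A minimal sketch of the shapes involved (assuming dim=64, num_heads=4, and a 49-token window, as in the reproduction below):

import torch

# Shapes at the failing attention-score matmul: B=1, heads=4, N=49, head_dim=16
q = torch.randn(1, 4, 49, 16)
k = torch.randn(1, 4, 49, 16)
attn = q @ k.transpose(-2, -1)  # -> (1, 4, 49, 49); runs fine in eager mode
print(attn.shape)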

The following error message is produced:

Traceback (most recent call last):
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 114, in _pretraced_backend
    trt_compiled = compile_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 487, in compile_module
    trt_module = convert_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 617, in run
    self._construct_trt_network_def()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 348, in _construct_trt_network_def
    super().run()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 683, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 792, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 539, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1170, in aten_ops_expand
    return impl.slice.expand(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 240, in expand
    raise RuntimeError(
RuntimeError: expand called with 4-dimensional shape on Tensor with 4 dimensions. Cannot expand to shape with rank smaller than original tensor.

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul, [1, 4, 49, 16]), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7ff359701cf0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215430>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582155b0>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207db0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207d30>: ((192, 64), torch.float32, True, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358204ef0>: ((64, 192), torch.float32, False, (1, 64), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358206c70>: ((49, 64), torch.float32, False, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582059f0>: ((49, 192), torch.float32, False, (192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358205af0>: ((1, 49, 192), torch.float32, False, (9408, 192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215cf0>: ((1, 49, 3, 4, 16), torch.float32, False, (9408, 192, 64, 16, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215df0>: ((3, 1, 4, 49, 16), torch.float32, False, (64, 9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216130>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582161f0>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216270>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216630>: ((1, 4, 49, 16), torch.float32, False, (3136, 16, 64, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582168b0>: ((1, 4, 16, 49), torch.float32, False, (9408, 16, 1, 192), None, False, {})}})
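
Note that the rejected expand itself looks benign: the target shape [1, 4, 49, 16] has the same rank as the %mul input (q scaled by self.scale, shape (1, 4, 49, 16)), and the error text itself reports a 4-dimensional shape on a 4-dimensional tensor, so the "rank smaller than original tensor" condition does not obviously apply. A minimal check of the same call in eager PyTorch (shapes assumed from the graph metadata above) succeeds:

import torch

# The rejected node expands a (1, 4, 49, 16) tensor to [1, 4, 49, 16] (same rank)
t = torch.randn(1, 4, 49, 16)
out = torch.ops.aten.expand.default(t, [1, 4, 49, 16])
print(out.shape)  # torch.Size([1, 4, 49, 16])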

Reproduction Code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt

# Set device and backend
backend = "torch_tensorrt"
device = torch.device("cuda:0")

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        # (B_, N, 3*C) -> (B_, N, 3, num_heads, head_dim) -> (3, B_, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B_, num_heads, N, head_dim)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # Error happens here during Torch-TensorRT conversion
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return self.proj_drop(x)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x

# Example input and usage
dim = 64
num_heads = 4
window_size = 7

x = torch.randn(1, 49, dim).to(device)  # Example input (B, N, C)
model = SwinTransformerBlock(dim, num_heads, window_size).to(device)
model.eval()

# Compile the block with the Torch-TensorRT backend
compiled_model = torch.compile(
    model,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16, torch.float32},
        "device": device,
        "min_block_size": 5,
        "require_full_compilation": True,
    },
    dynamic=False,
)

outputs_after = compiled_model(x)  # Error occurs here, on the first compiled call
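
For comparison, running the uncompiled module in eager mode on the same input completes without error, which suggests the failure is specific to the Torch-TensorRT conversion path (a sketch reusing the model and x defined above):

# Eager sanity check: the same input runs through the uncompiled block
with torch.no_grad():
    eager_out = model(x)
print(eager_out.shape)  # torch.Size([1, 49, 64])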