When compiling a SwinTransformer block with torch-tensorrt, a RuntimeError is raised from the converter for torch.ops.aten.expand.default during the matrix multiplication in the attention mechanism. Specifically, the failure occurs while computing the attention scores with q @ k.transpose(-2, -1) in the forward pass.
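For context, the matmul itself is shape-wise ordinary. A minimal eager sketch, using the shapes reported in the failing node's metadata below (B=1, num_heads=4, N=49, head_dim=16):

import torch

# Shapes taken from the failing node's metadata: (B, num_heads, N, head_dim)
q = torch.randn(1, 4, 49, 16)
k = torch.randn(1, 4, 49, 16)
attn = q @ k.transpose(-2, -1)  # torch.compile lowers this matmul via aten.expand + bmm, as seen in the trace below
print(attn.shape)  # torch.Size([1, 4, 49, 49])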
The following error message is produced:
Traceback (most recent call last):
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 114, in _pretraced_backend
trt_compiled = compile_module(
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 487, in compile_module
trt_module = convert_module(
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
interpreter_result = interpret_module_to_result(
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
interpreter_result = interpreter.run()
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 617, in run
self._construct_trt_network_def()
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 348, in _construct_trt_network_def
super().run()
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 683, in run_node
trt_node: torch.fx.Node = super().run_node(n)
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 792, in call_function
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 539, in convert_with_type_enforcement
return func(ctx, target, new_args, new_kwargs, name)
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1170, in aten_ops_expand
return impl.slice.expand(
File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 240, in expand
raise RuntimeError(
RuntimeError: expand called with 4-dimensional shape on Tensor with 4 dimensions. Cannot expand to shape with rank smaller than original tensor.
While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul, [1, 4, 49, 16]), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7ff359701cf0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215430>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582155b0>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207db0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207d30>: ((192, 64), torch.float32, True, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358204ef0>: ((64, 192), torch.float32, False, (1, 64), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358206c70>: ((49, 64), torch.float32, False, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582059f0>: ((49, 192), torch.float32, False, (192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358205af0>: ((1, 49, 192), torch.float32, False, (9408, 192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215cf0>: ((1, 49, 3, 4, 16), torch.float32, False, (9408, 192, 64, 16, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215df0>: ((3, 1, 4, 49, 16), torch.float32, False, (64, 9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216130>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582161f0>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216270>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216630>: ((1, 4, 49, 16), torch.float32, False, (3136, 16, 64, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582168b0>: ((1, 4, 16, 49), torch.float32, False, (9408, 16, 1, 192), None, False, {})}})
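Note that aten.expand itself accepts a same-rank target shape, so the request above — expanding a 4-D tensor to [1, 4, 49, 16] — is valid in eager PyTorch; the converter's rank check rejects it anyway, even though its own message says both the tensor and the target shape have 4 dimensions. A quick eager sanity check with the same shapes:

import torch

t = torch.randn(1, 4, 49, 16)
# Same-rank expand succeeds in eager mode; the torch-tensorrt converter raises instead.
out = torch.ops.aten.expand.default(t, [1, 4, 49, 16])
print(out.shape)  # torch.Size([1, 4, 49, 16])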
Reproduction Code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt

# Set device and backend
backend = "torch_tensorrt"
device = torch.device("cuda:0")


class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        # Project to qkv and split heads: (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # Error happens here
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return self.proj_drop(x)


class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x


# Example input and usage
dim = 64
num_heads = 4
window_size = 7
x = torch.randn(1, 49, dim).to(device)  # Example input (B, N, C)
block = SwinTransformerBlock(dim, num_heads, window_size)
block.eval()
block = block.to(device)

# Compile the block with the torch_tensorrt backend
block = torch.compile(
    block,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16, torch.float32},
        "device": device,
        "min_block_size": 5,
        "require_full_compilation": True,
    },
    dynamic=False,
)

outputs_after = block(x)  # Error occurs here
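As a possible workaround (an untested assumption, not a confirmed fix): forcing the failing op to stay in eager PyTorch via torch_executed_ops, which requires relaxing require_full_compilation so the graph can be partitioned:

# Hypothetical workaround (untested): run aten.expand in eager PyTorch
# instead of converting it; require_full_compilation is dropped so the
# partitioner is allowed to split the graph around the excluded op.
block = torch.compile(
    block,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16, torch.float32},
        "device": device,
        "min_block_size": 5,
        "torch_executed_ops": {"torch.ops.aten.expand.default"},
    },
    dynamic=False,
)
outputs_after = block(x)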