
TRT support for MAISI #701

Open · wants to merge 23 commits into dev
Conversation

@borisfom (Contributor) commented Oct 16, 2024

Description

TRT optimization support for MAISI.
Depends on Project-MONAI/MONAI#8153
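
For context, Project-MONAI/MONAI#8153 adds MONAI's `trt_compile` wrapper, which patches a network's `forward` so that the first call exports the module to ONNX, builds a TensorRT engine, and caches it on disk. A minimal sketch of that usage follows; the toy network and cache path are illustrative assumptions, not this bundle's actual wiring (the bundle configures this through configs/inference_trt.json):

```python
# Minimal sketch of MONAI's trt_compile from Project-MONAI/MONAI#8153.
# The toy network and cache path are assumptions for illustration; the
# MAISI bundle wires this up through configs/inference_trt.json instead.
import torch
from monai.networks.trt_compiler import trt_compile

net = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3)).cuda().eval()

# Patches net.forward: the first call exports to ONNX, builds a TRT
# engine, and caches it under the given base path; later calls reuse it.
net = trt_compile(net, "models/toy_net")

with torch.no_grad():
    out = net(torch.randn(1, 1, 64, 64, device="cuda"))
```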

Status

Work in progress

@Nic-Ma (Collaborator) commented Oct 22, 2024

Hi @yiheng-wang-nv,

Is the CI pipeline broken?

Thanks.

@yiheng-wang-nv (Collaborator) commented:

Hi @KumoLiu, just FYI: this MAISI TensorRT enhancement PR depends on the content of Project-MONAI/MONAI#8153.

We may need to merge that one first before merging the MAISI one.

(commit) Signed-off-by: Boris Fomitchev <[email protected]>
@binliunls (Contributor) commented:

Hi @borisfom,
I got the error shown below on an A100 40GB GPU. Is this expected?

Traceback (most recent call last):
  File "/opt/monai/monai/bundle/config_item.py", line 374, in evaluate
    return eval(value[len(self.prefix) :], globals_, locals)
  File "<string>", line 1, in <module>
  File "/home/liubin/data/bundles/maisi_ct_generative/scripts/sample.py", line 681, in sample_multiple_images
    synthetic_images, synthetic_labels = self.sample_one_pair(
  File "/home/liubin/data/bundles/maisi_ct_generative/scripts/sample.py", line 759, in sample_one_pair
    synthetic_images, synthetic_labels = ldm_conditional_sample_one_image(
  File "/home/liubin/data/bundles/maisi_ct_generative/scripts/sample.py", line 245, in ldm_conditional_sample_one_image
    down_block_res_samples, mid_block_res_sample = controlnet(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/monai/monai/networks/trt_compiler.py", line 609, in trt_forward
    return self._trt_compiler.forward(self, argv, kwargs)
  File "/opt/monai/monai/networks/trt_compiler.py", line 454, in forward
    raise e
  File "/opt/monai/monai/networks/trt_compiler.py", line 445, in forward
    self._build_and_save(model, build_args)
  File "/opt/monai/monai/networks/trt_compiler.py", line 590, in _build_and_save
    convert_to_onnx(
  File "/opt/monai/monai/networks/utils.py", line 699, in convert_to_onnx
    torch.onnx.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 377, in export
    export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1848, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_helper.py", line 281, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py", line 173, in scaled_dot_product_attention
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 92, in op
    return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 239, in _add_op
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 239, in <listcomp>
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 270, in _const_if_tensor
    return _add_op(graph_context, "onnx::Constant", value_z=arg)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 247, in _add_op
    node = _create_node(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 306, in _create_node
    _add_attribute(node, key, value, aten=aten)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 336, in _add_attribute
    return getattr(node, f"{kind}_")(name, value)
TypeError: z_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: torch.Tensor) -> torch._C.Node

Invoked with: %728 : Tensor = onnx::Constant(), scope: monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi::/monai.networks.nets.diffusion_model_unet.AttnDownBlock::down_blocks.2/monai.networks.blocks.spatialattention.SpatialAttentionBlock::attentions.0/monai.networks.blocks.selfattention.SABlock::attn
, 'value', 0.1767766952966369 
(Occurred when translating scaled_dot_product_attention).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/monai/monai/bundle/__main__.py", line 31, in <module>
    fire.Fire()
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/monai/monai/bundle/scripts.py", line 1010, in run
    workflow.run()
  File "/opt/monai/monai/bundle/workflows.py", line 363, in run
    return self._run_expr(id=self.run_id)
  File "/opt/monai/monai/bundle/workflows.py", line 397, in _run_expr
    return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
  File "/opt/monai/monai/bundle/config_parser.py", line 290, in get_parsed_content
    return self.ref_resolver.get_resolved_content(id=id, **kwargs)
  File "/opt/monai/monai/bundle/reference_resolver.py", line 193, in get_resolved_content
    return self._resolve_one_item(id=id, **kwargs)
  File "/opt/monai/monai/bundle/reference_resolver.py", line 163, in _resolve_one_item
    self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs)
  File "/opt/monai/monai/bundle/reference_resolver.py", line 175, in _resolve_one_item
    item.evaluate(globals={f"{self._vars}": self.resolved_content}) if run_eval else item
  File "/opt/monai/monai/bundle/config_item.py", line 376, in evaluate
    raise RuntimeError(f"Failed to evaluate {self}") from e
RuntimeError: Failed to evaluate ConfigExpression: 
"$__local_refs['ldm_sampler'].sample_multiple_images(__local_refs['num_output_samples'])"

Thanks,
Bin

@binliunls (Contributor) commented Oct 31, 2024

Should be fine for MAISI, as I tested:

|  | maisi bundle inference (ms) | trt_bundle inference (ms) |
| --- | --- | --- |
| end2end (latent feature generation) | 80237.45 | 40979.48 |
| end2end (image_decoding) | 2187.02 | 2777.27 |

Thanks,
Bin Liu

@borisfom (Contributor, Author) commented:

@binliunls: how come image_decoding is much slower with TRT? How do I run a test for that?

@binliunls (Contributor) commented Nov 1, 2024

> @binliunls: how come image_decoding is much slower with TRT? How do I run a test for that?

I was running a command like `python -m monai.bundle run --config_file="['configs/inference.json', 'configs/inference_trt.json']" --output_size_xy=256 --output_size_z=256` on an A100 40G GPU; the bundle then reports the latency for image decoding. This was a one-time run, since the colossus node shut down when I was about to run it several times, so there may be some bias. I will run it again once I get a new colossus node.
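
A sketch of how the repeated timing could be scripted once a node is available; the run count and working directory are assumptions, and note that the first run also pays the one-time TRT engine build:

```python
# Repeat the bundle inference several times and report the median
# wall-clock latency, to reduce the one-run bias mentioned above.
# Assumes it is run from the maisi_ct_generative bundle root; the run
# count of 3 is illustrative only.
import statistics
import subprocess
import time

cmd = [
    "python", "-m", "monai.bundle", "run",
    "--config_file", "['configs/inference.json', 'configs/inference_trt.json']",
    "--output_size_xy", "256",
    "--output_size_z", "256",
]
latencies = []
for _ in range(3):
    start = time.perf_counter()
    subprocess.run(cmd, check=True)
    latencies.append(time.perf_counter() - start)
print(f"median end-to-end latency: {statistics.median(latencies):.2f}s")
```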

@KumoLiu (Collaborator) commented Nov 15, 2024

Project-MONAI/MONAI#8153 has been merged.
Do we need to update the README for MAISI and also include the benchmark data there? @binliunls @yiheng-wang-nv

@binliunls (Contributor) commented:

Hi @borisfom,
Here are the benchmark details for 100 runs of MAISI with a 256x256x256 input shape on an A100 80GB. I am not sure why Image Decoding suffers a slowdown; it could be some overhead issue. I will try to figure it out later.

| Latency Type | TRT Mean Latency (s) | Bundle Mean Latency (s) |
| --- | --- | --- |
| Mask Preparation | 2.897087729 | 2.793987193 |
| Feature Generation | 35.12124193 | 76.54545327 |
| Image Decoding | 1.483238726 | 1.194563277 |

| Latency Type | TRT Median Latency (s) | Bundle Median Latency (s) |
| --- | --- | --- |
| Mask Preparation | 2.90212667 | 2.80731046 |
| Feature Generation | 35.12641037 | 76.54435086 |
| Image Decoding | 1.490729215 | 1.17928219 |

Thanks,
Bin

@borisfom (Contributor, Author) commented Nov 16, 2024 via email
