diff --git a/docsrc/tutorials/serving_torch_tensorrt_with_triton.rst b/docsrc/tutorials/serving_torch_tensorrt_with_triton.rst
index 31cb6733cf..465a6435af 100644
--- a/docsrc/tutorials/serving_torch_tensorrt_with_triton.rst
+++ b/docsrc/tutorials/serving_torch_tensorrt_with_triton.rst
@@ -22,42 +22,55 @@ Step 1: Optimize your model with Torch-TensorRT
Most Torch-TensorRT users will be familiar with this step. For the purpose of
this demonstration, we will be using a ResNet50 model from Torchhub.

-Let’s first pull the `NGC PyTorch Docker container `__. You may need to create
+We will be working in the ``//examples/triton`` directory, which contains the scripts used in this tutorial.
+
+First pull the `NGC PyTorch Docker container `__. You may need to create
an account and get the API key from `here `__.
Sign up and login with your key (follow the instructions
`here `__ after signing up).

::

-    # is the yy:mm for the publishing tag for NVIDIA's Pytorch
-    # container; eg. 22.04
+    # YY.MM is the yy.mm publishing tag for NVIDIA's PyTorch
+    # container; e.g. 24.08
+    # NOTE: Use the same publishing tag for both the PyTorch and Triton containers

-    docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:-py3
+    docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:YY.MM-py3
    cd /scratch_space

-Once inside the container, we can proceed to download a ResNet model from
-Torchhub and optimize it with Torch-TensorRT.
+Inside the container, we can export the model into the correct directory in our Triton model repository. The export script uses the **Dynamo** frontend of Torch-TensorRT to compile the PyTorch model to TensorRT, and then saves the compiled module in **TorchScript**, a serialization format supported by Triton.

::

-    import torch
-    import torch_tensorrt
-    torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
+    import torch
+    import torch_tensorrt
+
+    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+
+    # load model
+    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
+
+    # Compile with Torch-TensorRT, allowing TensorRT to use FP16 kernels
+    trt_model = torch_tensorrt.compile(model,
+        inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
+        enabled_precisions={torch_tensorrt.dtype.f16}
+    )
+
+    # Trace the compiled module so it can be serialized as TorchScript
+    ts_trt_model = torch.jit.trace(trt_model, torch.rand(1, 3, 224, 224).to("cuda"))
+
+    # Save the model
+    torch.jit.save(ts_trt_model, "/triton_example/model_repository/resnet50/1/model.pt")

-    # load model
-    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
+You can run the script with the following command (from ``//examples/triton``):

-    # Compile with Torch TensorRT;
-    trt_model = torch_tensorrt.compile(model,
-        inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
-        enabled_precisions= { torch.half} # Run with FP32
-    )
+::

-    # Save the model
-    torch.jit.save(trt_model, "model.pt")
+    docker run --gpus all -it --rm -v ${PWD}:/triton_example nvcr.io/nvidia/pytorch:YY.MM-py3 python /triton_example/export.py

-After copying the model, exit the container. The next step in the process
-is to set up a Triton Inference Server.
+This will save the serialized TorchScript version of the ResNet model in the right directory in the model repository.
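+
+Optionally, you can sanity-check the export before moving on to Triton. The snippet below is a minimal sketch (not part of the example scripts) that assumes the same container image and ``/triton_example`` mount used for the export command above; it simply reloads the serialized module and runs a dummy batch through it.
+
+::
+
+    import torch
+    import torch_tensorrt  # importing torch_tensorrt loads the runtime ops needed to deserialize the module
+
+    # Load the TorchScript module saved by export.py and run a random batch through it
+    model = torch.jit.load("/triton_example/model_repository/resnet50/1/model.pt")
+    out = model(torch.rand(1, 3, 224, 224).to("cuda"))
+    print(out.shape)  # expected: torch.Size([1, 1000]), matching the output dims in config.pbtxt
+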
Step 2: Set Up Triton Inference Server -------------------------------------- @@ -90,25 +103,23 @@ For the model we prepared in step 1, the following configuration can be used: :: - name: "resnet50" - platform: "pytorch_libtorch" - max_batch_size : 0 - input [ - { - name: "input__0" - data_type: TYPE_FP32 - dims: [ 3, 224, 224 ] - reshape { shape: [ 1, 3, 224, 224 ] } - } - ] - output [ - { - name: "output__0" - data_type: TYPE_FP32 - dims: [ 1, 1000 ,1, 1] - reshape { shape: [ 1, 1000 ] } - } - ] + name: "resnet50" + backend: "pytorch" + max_batch_size : 0 + input [ + { + name: "x" + data_type: TYPE_FP32 + dims: [ 1, 3, 224, 224 ] + } + ] + output [ + { + name: "output0" + data_type: TYPE_FP32 + dims: [1, 1000] + } + ] The ``config.pbtxt`` file is used to describe the exact model configuration with details like the names and shapes of the input and output layer(s), @@ -124,9 +135,9 @@ with the docker command below. Refer `this page -py3 tritonserver --model-repository=/models + docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/triton_example/model_repository This should spin up a Triton Inference server. Next step, building a simple http client to query the server. @@ -159,22 +170,24 @@ resize and normalize the query image. :: - import numpy as np - from torchvision import transforms - from PIL import Image - import tritonclient.http as httpclient - from tritonclient.utils import triton_to_np_dtype - - # preprocessing function - def rn50_preprocess(img_path="img1.jpg"): - img = Image.open(img_path) - preprocess = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) - return preprocess(img).numpy() + import numpy as np + from torchvision import transforms + from PIL import Image + import tritonclient.http as httpclient + from tritonclient.utils import triton_to_np_dtype + + # preprocessing function + def rn50_preprocess(img_path="/triton_example/img1.jpg"): + img = Image.open(img_path) + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + return preprocess(img).unsqueeze(0).numpy() transformed_img = rn50_preprocess() @@ -186,14 +199,14 @@ with the Triton Inference Server. # Setting up client client = httpclient.InferenceServerClient(url="localhost:8000") -Secondly, we specify the names of the input and output layer(s) of our model. +Secondly, we specify the names of the input and output layer(s) of our model. This can be obtained during export and should already be specified in your ``config.pbtxt`` :: - inputs = httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32") + inputs = httpclient.InferInput("x", transformed_img.shape, datatype="FP32") inputs.set_data_from_numpy(transformed_img, binary_data=True) - outputs = httpclient.InferRequestedOutput("output__0", binary_data=True, class_count=1000) + outputs = httpclient.InferRequestedOutput("output0", binary_data=True, class_count=1000) Lastly, we send an inference request to the Triton Inference Server. @@ -201,7 +214,7 @@ Lastly, we send an inference request to the Triton Inference Server. 
    # Querying the server
    results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
-    inference_output = results.as_numpy('output__0')
+    inference_output = results.as_numpy('output0')
    print(inference_output[:5])

The output should look like below:

@@ -214,3 +227,11 @@ The output format here is ``:``.
To learn how to map these to the label names and more, refer to Triton Inference Server's
`documentation `__.
+
+You can try out this client quickly using:
+
+::
+
+    # Remember to use the same publishing tag for all steps (e.g. 24.08)
+
+    docker run -it --net=host -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:YY.MM-py3-sdk bash -c "pip install torchvision && python /triton_example/client.py"
diff --git a/examples/triton/README.rst b/examples/triton/README.rst
new file mode 100644
index 0000000000..e00d79cfce
--- /dev/null
+++ b/examples/triton/README.rst
@@ -0,0 +1,237 @@
+.. _serving_torch_tensorrt_with_triton:
+
+Serving a Torch-TensorRT model with Triton
+==========================================
+
+Optimization and deployment go hand in hand in a discussion about Machine
+Learning infrastructure. Once network-level optimizations are done
+to get the maximum performance, the next step would be to deploy it.
+
+However, serving this optimized model comes with its own set of considerations
+and challenges like building an infrastructure to support concurrent model
+executions, supporting clients over HTTP or gRPC, and more.
+
+The `Triton Inference Server `__
+solves the aforementioned and more. Let's discuss, step by step, the process of
+optimizing a model with Torch-TensorRT, deploying it on Triton Inference
+Server, and building a client to query the model.
+
+Step 1: Optimize your model with Torch-TensorRT
+-----------------------------------------------
+
+Most Torch-TensorRT users will be familiar with this step. For the purpose of
+this demonstration, we will be using a ResNet50 model from Torchhub.
+
+We will be working in the ``//examples/triton`` directory, which contains the scripts used in this tutorial.
+
+First pull the `NGC PyTorch Docker container `__. You may need to create
+an account and get the API key from `here `__.
+Sign up and login with your key (follow the instructions
+`here `__ after signing up).
+
+::
+
+    # YY.MM is the yy.mm publishing tag for NVIDIA's PyTorch
+    # container; e.g. 24.08
+    # NOTE: Use the same publishing tag for both the PyTorch and Triton containers
+
+    docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:YY.MM-py3
+    cd /scratch_space
+
+Inside the container, we can export the model into the correct directory in our Triton model repository. The export script uses the **Dynamo** frontend of Torch-TensorRT to compile the PyTorch model to TensorRT, and then saves the compiled module in **TorchScript**, a serialization format supported by Triton.
+
+::
+
+    import torch
+    import torch_tensorrt
+
+    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+
+    # load model
+    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
+
+    # Compile with Torch-TensorRT, allowing TensorRT to use FP16 kernels
+    trt_model = torch_tensorrt.compile(model,
+        inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
+        enabled_precisions={torch_tensorrt.dtype.f16}
+    )
+
+    # Trace the compiled module so it can be serialized as TorchScript
+    ts_trt_model = torch.jit.trace(trt_model, torch.rand(1, 3, 224, 224).to("cuda"))
+
+    # Save the model
+    torch.jit.save(ts_trt_model, "/triton_example/model_repository/resnet50/1/model.pt")
+
+You can run the script with the following command (from ``//examples/triton``):
+
+::
+
+    docker run --gpus all -it --rm -v ${PWD}:/triton_example nvcr.io/nvidia/pytorch:YY.MM-py3 python /triton_example/export.py
+
+This will save the serialized TorchScript version of the ResNet model in the right directory in the model repository.
+
+Step 2: Set Up Triton Inference Server
+--------------------------------------
+
+If you are new to the Triton Inference Server and want to learn more, we
+highly recommend checking out our `Github
+Repository `__.
+
+To use Triton, we need to make a model repository. A model repository, as the
+name suggests, is a repository of the models the Inference Server hosts. While
+Triton can serve models from multiple repositories, in this example, we will
+discuss the simplest possible form of the model repository.
+
+The structure of this repository should look something like this:
+
+::
+
+    model_repository
+    |
+    +-- resnet50
+        |
+        +-- config.pbtxt
+        +-- 1
+            |
+            +-- model.pt
+
+There are two files that Triton requires to serve the model: the model itself
+and a model configuration file, which is typically provided in ``config.pbtxt``.
+For the model we prepared in step 1, the following configuration can be used:
+
+::
+
+    name: "resnet50"
+    backend: "pytorch"
+    max_batch_size : 0
+    input [
+      {
+        name: "x"
+        data_type: TYPE_FP32
+        dims: [ 1, 3, 224, 224 ]
+      }
+    ]
+    output [
+      {
+        name: "output0"
+        data_type: TYPE_FP32
+        dims: [1, 1000]
+      }
+    ]
+
+The ``config.pbtxt`` file is used to describe the exact model configuration
+with details like the names and shapes of the input and output layer(s),
+datatypes, scheduling and batching details, and more. If you are new to Triton,
+we highly encourage you to check out this `section of our
+documentation `__
+for more details.
+
+With the model repository set up, we can proceed to launch the Triton server
+with the docker command below. Refer to `this page `__ for the pull tag for the container.
+
+::
+
+    # Make sure that the TensorRT version in the Triton container
+    # and the TensorRT version in the environment used to optimize the model
+    # are the same. Containers with matching publishing tags should ship the same TensorRT version.
+
+    docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/triton_example/model_repository
+
+This should spin up a Triton Inference Server. The next step is building a simple
+HTTP client to query the server.
+
+Step 3: Building a Triton Client to Query the Server
+----------------------------------------------------
+
+Before proceeding, make sure to have a sample image on hand. If you don't
+have one, download an example image to test inference. In this section, we
+will be going over a very basic client.
For a variety of more fleshed out +examples, refer to the `Triton Client Repository `__ + +:: + + wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg" + +We then need to install dependencies for building a python client. These will +change from client to client. For a full list of all languages supported by Triton, +please refer to `Triton's client repository `__. + +:: + + pip install torchvision + pip install attrdict + pip install nvidia-pyindex + pip install tritonclient[all] + +Let's jump into the client. Firstly, we write a small preprocessing function to +resize and normalize the query image. + +:: + + import numpy as np + from torchvision import transforms + from PIL import Image + import tritonclient.http as httpclient + from tritonclient.utils import triton_to_np_dtype + + # preprocessing function + def rn50_preprocess(img_path="/triton_example/img1.jpg"): + img = Image.open(img_path) + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + return preprocess(img).unsqueeze(0).numpy() + + transformed_img = rn50_preprocess() + +Building a client requires three basic points. Firstly, we setup a connection +with the Triton Inference Server. + +:: + + # Setting up client + client = httpclient.InferenceServerClient(url="localhost:8000") + +Secondly, we specify the names of the input and output layer(s) of our model. This can be obtained during export and should already be specified in your ``config.pbtxt`` + +:: + + inputs = httpclient.InferInput("x", transformed_img.shape, datatype="FP32") + inputs.set_data_from_numpy(transformed_img, binary_data=True) + + outputs = httpclient.InferRequestedOutput("output0", binary_data=True, class_count=1000) + +Lastly, we send an inference request to the Triton Inference Server. + +:: + + # Querying the server + results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs]) + inference_output = results.as_numpy('output0') + print(inference_output[:5]) + +The output should look like below: + +:: + + [b'12.468750:90' b'11.523438:92' b'9.664062:14' b'8.429688:136' + b'8.234375:11'] + +The output format here is ``:``. +To learn how to map these to the label names and more, refer to Triton Inference Server's +`documentation `__. + +You can try out this client quickly using + +:: + + # Remember to use the same publishing tag for all steps (e.g. 24.08) + + docker run -it --net=host -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:YY.MM-py3-sdk bash -c "pip install torchvision && python /triton_example/client.py" diff --git a/examples/triton/client.py b/examples/triton/client.py new file mode 100644 index 0000000000..eeb7564f25 --- /dev/null +++ b/examples/triton/client.py @@ -0,0 +1,61 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import tritonclient.http as httpclient
+from PIL import Image
+from torchvision import transforms
+from tritonclient.utils import triton_to_np_dtype
+
+
+# preprocessing function
+def rn50_preprocess(img_path="/triton_example/img1.jpg"):
+    img = Image.open(img_path)
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+    return preprocess(img).unsqueeze(0).numpy()
+
+
+transformed_img = rn50_preprocess()
+
+# Setting up client
+client = httpclient.InferenceServerClient(url="localhost:8000")
+
+inputs = httpclient.InferInput("x", transformed_img.shape, datatype="FP32")
+inputs.set_data_from_numpy(transformed_img, binary_data=True)
+
+outputs = httpclient.InferRequestedOutput("output0", binary_data=True, class_count=1000)
+
+# Querying the server
+results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
+inference_output = results.as_numpy("output0")
+print(inference_output[:5])
diff --git a/examples/triton/export.py b/examples/triton/export.py
new file mode 100644
index 0000000000..fc042227e3
--- /dev/null
+++ b/examples/triton/export.py
@@ -0,0 +1,25 @@
+import torch
+import torch_tensorrt
+
+torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+
+# load model
+model = (
+    torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
+    .eval()
+    .to("cuda")
+)
+
+# Compile with Torch-TensorRT, allowing TensorRT to use FP16 kernels
+trt_model = torch_tensorrt.compile(
+    model,
+    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
+    enabled_precisions={torch_tensorrt.dtype.f16},
+)
+
+# Trace the compiled module so it can be serialized as TorchScript
+ts_trt_model = torch.jit.trace(trt_model, torch.rand(1, 3, 224, 224).to("cuda"))
+
+# Save the model
+torch.jit.save(ts_trt_model, "/triton_example/model_repository/resnet50/1/model.pt")
diff --git a/examples/triton/img1.jpg b/examples/triton/img1.jpg
new file mode 100644
index 0000000000..e5aff25397
Binary files /dev/null and b/examples/triton/img1.jpg differ
diff --git a/examples/triton/model_repository/resnet50/config.pbtxt b/examples/triton/model_repository/resnet50/config.pbtxt
new file mode 100644
index 0000000000..89a7c9eeb3
--- /dev/null
+++ b/examples/triton/model_repository/resnet50/config.pbtxt
@@ -0,0 +1,17 @@
+name: "resnet50"
+backend: "pytorch"
+max_batch_size : 0
+input [
+  {
+    name: "x"
+    data_type: TYPE_FP32
+    dims: [ 1, 3, 224, 224 ]
+  }
+]
+output [
+  {
+    name: "output0"
+    data_type: TYPE_FP32
+    dims: [1, 1000]
+ } +]