
Unable to port DISK UNET from Kornia to Inf2. Compilation taking hours with no signs of progress #1039

Open
kandakji opened this issue Nov 20, 2024 · 3 comments

Comments

@kandakji

Hi,

I'm trying to port some models from Kornia. I was able to port NetVlad and LightGlue.

When it comes to DISK, torch_neuronx.trace fails with "Input tensor is not an XLA tensor: LazyFloatType", even though I moved both the tensors and the model to the XLA device.
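For reference, this is roughly how I handle device placement; the to_device helper below is my own sketch (not a torch_neuronx API), and CPU stands in for the XLA device so the snippet runs anywhere:

```python
import torch

# Sketch of device handling (to_device is my own helper, not a torch_neuronx
# API): every tensor nested inside the input structure must be moved, not
# just the model. CPU stands in for xm.xla_device() here.
def to_device(obj, device):
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    return obj

device = torch.device("cpu")  # xm.xla_device() on the Neuron instance
disk_input = to_device({"image": torch.ones(1, 3, 8, 8).mul(0.5)}, device)
```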

So I started experimenting with torch.jit.trace instead. The compiler runs, but it is stuck at this debug entry:

2024-11-20T19:01:27Z INFO 454853 [job.Frontend.0]: Executing: <site-packages>/neuronxcc/starfish/bin/hlo2penguin --input /tmp/ubuntu/neuroncc_compile_workdir/ab6ffbbd-bb47-4739-9ef3-fef030126a68/model.MODULE_16216335577045190367+11b4a2df.hlo_module.pb --out-dir ./ --output penguin.py --layers-per-module=1 --partition --coalesce-all-gathers=false --coalesce-reduce-scatters=false --coalesce-all-reduces=false --emit-tensor-level-dropout-ops --emit-tensor-level-rng-ops --expand-batch-norm-training --enable-native-kernel --native-kernel-auto-cast=matmult-to-bf16

@fayyadd

fayyadd commented Nov 22, 2024

Thank you for reaching out! To help us investigate this issue, can you please share the Neuron versions in your environment (pip list | grep neuron) and the steps to reproduce it?

@kandakji
Author

kandakji commented Nov 27, 2024

Hello @fayyadd, here's the pip output:

aws-neuronx-runtime-discovery 2.9
libneuronxla                  2.0.4115.0
neuronx-cc                    2.15.141.0+d3cfc8ca
torch-neuronx                 2.1.2.2.3.1

Here's the code to reproduce:


import torch
from extractors.disk import DISK
import os
import torch_neuronx

import torch_xla.core.xla_model as xm

device = xm.xla_device()

disk_input = {"image": torch.ones(1, 3, 1024, 768).mul(0.5)}

# load disk model
conf = {
    "max_num_keypoints": 2000,
}

disk_model = DISK(conf).to(device)

os.environ['NEURON_CC_FLAGS'] = '--verbose DEBUG --target inf2 --model-type unet-inference --optlevel 1'

neuron_disk_model = torch.jit.trace(disk_model, disk_input, strict=False)

# Trace the model for Neuron
neuron_disk_model_unet = torch_neuronx.trace(disk_model, disk_input)

Here's extractors/disk.py:


import time
import torch
from kornia.feature import DISK as _DISK
from threading import Lock
from copy import copy
from types import SimpleNamespace
from collections import namedtuple

class DISK(torch.nn.Module):
    _lock = Lock()
    default_conf = {
        "weights": "depth",
        "max_num_keypoints": 2000,
        "desc_dim": 128,
        "nms_window_size": 5,
        "detection_threshold": 0.0,
        "pad_if_not_divisible": True,
    }
    required_inputs = ['image']

    def __init__(self, conf):
        self.conf = SimpleNamespace(**{**self.default_conf, **conf})
        super().__init__()
        print("DISK loading model")
        start_time = time.time()
        self.required_inputs = copy(self.required_inputs)
        self.model = _DISK.from_pretrained(self.conf.weights, device=torch.device('cpu'))
        print("DISK model loaded in %.2fs" % (time.time() - start_time))

    @torch.no_grad()
    def forward(self, data):
        """Validate the inputs and run the wrapped Kornia DISK model."""
        for key in self.required_inputs:
            assert key in data, 'Missing key {} in data'.format(key)

        image = data['image']
        features = self.model(image,
                              n=self.conf.max_num_keypoints,
                              window_size=self.conf.nms_window_size,
                              score_threshold=self.conf.detection_threshold,
                              pad_if_not_divisible=self.conf.pad_if_not_divisible)
        keypoints = [f.keypoints for f in features]
        scores = [f.detection_scores for f in features]
        descriptors = [f.descriptors for f in features]
        del features

        keypoints = torch.stack(keypoints, 0)
        scores = torch.stack(scores, 0)
        descriptors = torch.stack(descriptors, 0)

        # return [keypoints.cpu().contiguous(), descriptors.cpu().contiguous(), scores.cpu().contiguous()]
        Result = namedtuple('Result', ['keypoints', 'descriptors', 'detection_scores'])
        return Result(keypoints, descriptors, scores)
        # return {'keypoints': keypoints.cpu().contiguous(),
        #         'descriptors': descriptors.cpu().contiguous(),
        #         'detection_scores': scores.cpu().contiguous()}


    def warmup(self, device="cuda", iterations=5):

        print("DISK warming up with {} iterations".format(iterations))
        start_time = time.time()
        self.warming_up = True
        warming_input = {"image": torch.rand(1, 3, 768, 768, dtype=torch.float32, device=device)}
        for _ in range(iterations):
            with torch.no_grad():
                _ = self(warming_input)  # forward() takes only the data dict
        self.warming_up = False
        self.warmed_up = True
        print("DISK warmed up in %.2fs" % (time.time() - start_time))

@jluntamazon
Contributor

Currently this model is not well supported due to the use of bilinear upsampling. There is ongoing work to support this operation, but it will require an upcoming release.

Original Issue Resolution
While this cannot yet be fully resolved, the steps below produce a model that traces correctly (compilation still fails).

The original issue should have been resolvable by tracing only the UNet portion of the model for Neuron, rather than the full wrapper. One of the reasons this model does not work well with Neuron is the use of variable-shaped tensors in the post-processing stages of the model: https://github.com/kornia/kornia/blob/main/kornia/feature/disk/detector.py#L56
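To illustrate the problem, here is a toy stand-in for the score-threshold filtering in the DISK detector (my own example, not the actual kornia code): the output size depends on the input values, not just its shape, which a static-shape compiler like Neuron cannot express:

```python
import torch

# Stand-in for the score-threshold filtering in the DISK detector (not the
# actual kornia code): keypoints above the threshold are kept, so the number
# of outputs depends on the input *values*. Neuron compiles static-shape
# graphs, so this kind of data-dependent shape cannot be lowered.
def select_keypoints(scores: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    return scores[scores > threshold]  # data-dependent output shape

a = select_keypoints(torch.tensor([0.9, -0.1, 0.3]))
b = select_keypoints(torch.tensor([-0.5, -0.2, 0.1]))
print(a.numel(), b.numel())  # different sizes for same-shaped inputs
```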

To avoid variable-shaped tensors, we can usually trace the compute-heavy (static) portion of the model and let the post-processing portion execute on CPU.

If bilinear upsampling were supported, this is how I would have modified the script:

import torch
from disk import DISK
import torch_neuronx

# load disk model
conf = {
    "max_num_keypoints": 2000,
}
disk_model = DISK(conf)

# Trace Unet
unet_input = torch.ones(1, 3, 1024, 768).mul(0.5)
compiler_args = '--verbose DEBUG --target trn1 --model-type unet-inference --optlevel 1'
unet = torch_neuronx.trace(disk_model.model.unet, unet_input, compiler_args=compiler_args)

# Replace module with traced Neuron module
disk_model.model.unet = unet

# Create final artifact
disk_input = {"image": unet_input}
neuron_disk_model_unet = torch.jit.trace(disk_model, disk_input, strict=False)
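As a side note, strict=False is needed in that final torch.jit.trace because the wrapper returns a structured output rather than a plain tensor. A small self-contained illustration with a toy module (my own example, not from the model above):

```python
import torch

# Toy stand-in for the DISK wrapper: strict=False lets torch.jit.trace
# accept a module that returns a mutable container (here a dict) instead
# of a plain tensor.
class Wrapper(torch.nn.Module):
    def forward(self, data):
        return {"out": data["image"] * 2}

example = {"image": torch.ones(1, 3, 4, 4)}
traced = torch.jit.trace(Wrapper(), (example,), strict=False)
out = traced(example)
print(out["out"].sum().item())
```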
