[opt code] Stop moving gpu to cpu and back #476

Open
YacratesWyh opened this issue Jan 13, 2025 · 1 comment

YacratesWyh commented Jan 13, 2025

Example: a torch-native image_blend_advance_v2_gpu to replace image_blend_advance_v2.
This change took the runtime from 10+ s down to 0.3 s in my case.
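The core change, as a minimal sketch (tensor2pil/pil2tensor are the conversion helpers from this repo's imagefunc module; the [B, H, W, C] float layout in 0..1 is ComfyUI's IMAGE convention):

import torch

layer = torch.rand(1, 512, 512, 3)  # stand-in for a ComfyUI IMAGE batch

# Before: every PIL call forces a device -> host copy and back
#   pil = tensor2pil(layer)
#   pil = pil.transpose(Image.FLIP_LEFT_RIGHT)
#   layer = pil2tensor(pil)

# After: the equivalent torch op runs on whatever device already holds the tensor
layer = torch.flip(layer, dims=[2])  # mirror along the width axis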

import math

import torch

from .imagefunc import *

NODE_NAME = 'ImageBlendAdvanceV2'

class ImageBlendAdvanceV2:

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):

        mirror_mode = ['None', 'horizontal', 'vertical']
        method_mode = ['lanczos', 'bicubic', 'hamming', 'bilinear', 'box', 'nearest']
        return {
            "required": {
                "background_image": ("IMAGE", ),  #
                "layer_image": ("IMAGE",),  #
                "invert_mask": ("BOOLEAN", {"default": True}),  # 反转mask
                "blend_mode": (chop_mode_v2,),  # 混合模式
                "opacity": ("INT", {"default": 100, "min": 0, "max": 100, "step": 1}),  # 透明度
                "x_percent": ("FLOAT", {"default": 50, "min": -999, "max": 999, "step": 0.01}),
                "y_percent": ("FLOAT", {"default": 50, "min": -999, "max": 999, "step": 0.01}),
                "mirror": (mirror_mode,),  # 镜像翻转
                "scale": ("FLOAT", {"default": 1, "min": 0.01, "max": 100, "step": 0.01}),
                "aspect_ratio": ("FLOAT", {"default": 1, "min": 0.01, "max": 100, "step": 0.01}),
                "rotate": ("FLOAT", {"default": 0, "min": -9999, "max": 9999, "step": 0.01}),
                "transform_method": (method_mode,),
                "anti_aliasing": ("INT", {"default": 0, "min": 0, "max": 16, "step": 1}),
            },
            "optional": {
                "layer_mask": ("MASK",),  #
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image", "mask")
    FUNCTION = 'image_blend_advance_v2'
    CATEGORY = '😺dzNodes/LayerUtility'

    def image_blend_advance_v2(self, background_image, layer_image,
                            invert_mask, blend_mode, opacity,
                            x_percent, y_percent,
                            mirror, scale, aspect_ratio, rotate,
                            transform_method, anti_aliasing,
                            layer_mask=None
                            ):
        
        # from viztracer import VizTracer
        # with VizTracer(output_file="layerstyle250113-1.json") as tracer:
        # I advise you to try viztracer to find where the bottleneck is.
            b_images = []
            l_images = []
            l_masks = []
            ret_images = []
            ret_masks = []
            for b in background_image:
                b_images.append(torch.unsqueeze(b, 0))
            for l in layer_image:
                l_images.append(torch.unsqueeze(l, 0))
                # m = tensor2pil(l)
                # if m.mode == 'RGBA':
                if l.shape[-1] == 4:
                    l_masks.append(l[..., 3])  # alpha is the last channel of an [H, W, C] image
                else:
                    l_masks.append(torch.ones((l.shape[0], l.shape[1]), device=l.device))  # fully opaque fallback
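            # Note (added): ComfyUI IMAGE tensors are [B, H, W, C] floats in 0..1 and
            # MASK tensors are [B, H, W]; all of the indexing below assumes that layout.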
            # b_images = background_image #.unsqueeze(1) # Add batch dimension
            # l_images = layer_image #.unsqueeze(1)
            
            # # Handle alpha channel for all images at once
            # if layer_image.shape[-1] == 4:
            #     l_masks = layer_image[..., 3] # Get alpha channel from all tensors
            # else:
            #     l_masks = torch.ones((layer_image.shape[0], layer_image.shape[1], layer_image.shape[2])) # Create alpha for all
            
            if layer_mask is not None:
                if layer_mask.dim() == 2:
                    layer_mask = torch.unsqueeze(layer_mask, 0)
                l_masks = []
                for m in layer_mask:
                    if invert_mask:
                        m = 1 - m
                    l_masks.append(m)  # MASK tensors are already single-channel [H, W]
                    
            max_batch = max(len(b_images), len(l_images), len(l_masks))
            for i in range(max_batch):
                background_image = b_images[i] if i < len(b_images) else b_images[-1]
                layer_image = l_images[i] if i < len(l_images) else l_images[-1]
                _mask = l_masks[i] if i < len(l_masks) else l_masks[-1] #4096*4096
                # preprocess
                _canvas = background_image[..., :3] # 1,4096,4096,3
                _layer = layer_image # Keep as tensor

                if _mask.shape != _layer.shape[1:3]:  # layer is [1, H, W, C]
                    _mask = torch.ones((_layer.shape[1], _layer.shape[2]), device=_layer.device)
                    log(f"Warning: {NODE_NAME} mask mismatch, dropped!", message_type='warning')

                orig_layer_height = _layer.shape[1]  # [1, H, W, C]: dim 1 is height
                orig_layer_width = _layer.shape[2]   # dim 2 is width
                # broadcast the [H, W] mask to [1, H, W, 3] so it matches the RGB layer
                _mask = _mask.unsqueeze(0).unsqueeze(-1).repeat(1, 1, 1, 3)

                
                target_layer_width = int(orig_layer_width * scale)
                target_layer_height = int(orig_layer_height * scale * aspect_ratio)
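                # Added sketch: the original PIL code resized here, but the posted port
                # never applies scale/aspect_ratio. One hedged way to do it on-device
                # (bilinear stands in for the PIL resamplers listed in transform_method):
                if (target_layer_width, target_layer_height) != (orig_layer_width, orig_layer_height):
                    _layer = torch.nn.functional.interpolate(
                        _layer.permute(0, 3, 1, 2), size=(target_layer_height, target_layer_width),
                        mode='bilinear').permute(0, 2, 3, 1)
                    _mask = torch.nn.functional.interpolate(
                        _mask.permute(0, 3, 1, 2), size=(target_layer_height, target_layer_width),
                        mode='bilinear').permute(0, 2, 3, 1)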

                # mirror
                if mirror == 'horizontal':
                    _layer = torch.flip(_layer, dims=[2])  # Flip horizontally (width dimension)
                    _mask = torch.flip(_mask, dims=[2])  # Flip horizontally (width dimension)
                elif mirror == 'vertical':
                    _layer = torch.flip(_layer, dims=[1])  # Flip vertically (height dimension)
                    _mask = torch.flip(_mask, dims=[1])  # Flip vertically (height dimension)

          
                # rotate
                if rotate != 0:
                    # Convert angle to radians
                    angle_rad = torch.tensor(rotate * math.pi / 180)
                    cos_theta = torch.cos(angle_rad)
                    sin_theta = torch.sin(angle_rad)
                    # 2x2 rotation matrix, built with stack so it stays a proper tensor
                    rotation_matrix = torch.stack([
                        torch.stack([cos_theta, -sin_theta]),
                        torch.stack([sin_theta, cos_theta])])

                    # Expanded output size that fits the rotated layer
                    old_h, old_w = _layer.shape[1], _layer.shape[2]
                    new_h = int(abs(old_h * cos_theta) + abs(old_w * sin_theta))
                    new_w = int(abs(old_w * cos_theta) + abs(old_h * sin_theta))

                    # Build a normalized [1, new_h, new_w, 2] sampling grid (xy order)
                    # and rotate the sampling coordinates
                    ys = torch.linspace(-1, 1, new_h)
                    xs = torch.linspace(-1, 1, new_w)
                    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
                    grid = (torch.stack([grid_x, grid_y], dim=-1) @ rotation_matrix).unsqueeze(0)

                    # grid_sample expects NCHW input, so permute around the call
                    _layer = torch.nn.functional.grid_sample(
                        _layer.permute(0, 3, 1, 2), grid, mode='bilinear',
                        padding_mode='zeros', align_corners=False).permute(0, 2, 3, 1)
                    _mask = torch.nn.functional.grid_sample(
                        _mask.permute(0, 3, 1, 2), grid, mode='bilinear',
                        padding_mode='zeros', align_corners=False).permute(0, 2, 3, 1)

                # position: x_percent / y_percent place the layer's center on the canvas
                x = int(_canvas.shape[2] * x_percent / 100 - _layer.shape[2] / 2)  # width is dim 2
                y = int(_canvas.shape[1] * y_percent / 100 - _layer.shape[1] / 2)  # height is dim 1

                # composite layer
                # Start from a copy of the canvas; the paste mask starts fully transparent
                _comp = _canvas.clone()
                _compmask = torch.zeros_like(_canvas)

                # Layer and canvas dimensions ([1, H, W, C] layout)
                h, w = _layer.shape[1], _layer.shape[2]
                H, W = _canvas.shape[1], _canvas.shape[2]
                # Destination window on the canvas, clipped to its bounds
                x1, x2 = max(x, 0), min(x + w, W)
                y1, y2 = max(y, 0), min(y + h, H)
                # Matching source window inside the layer
                sx1, sx2 = max(-x, 0), min(W - x, w)
                sy1, sy2 = max(-y, 0), min(H - y, h)

                # Paste layer and mask into position
                if x1 < x2 and y1 < y2:
                    _comp[:, y1:y2, x1:x2, :] = _layer[:, sy1:sy2, sx1:sx2, :3]
                    _compmask[:, y1:y2, x1:x2, :] = _mask[:, sy1:sy2, sx1:sx2, :]
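                # Worked example of the clipping above (hypothetical numbers): a
                # 100-px-wide layer placed at x = -10 on a 4096-px canvas gives a
                # destination window x1..x2 = 0..90 and a source window
                # sx1..sx2 = 10..100, i.e. the 10 px hanging off the left edge
                # are simply not copied.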
                
                # Tensors are already float in 0..1; just split out the RGB views
                backdrop_norm = _canvas[..., :3]
                source_norm = _comp[..., :3]
                mask_norm = _compmask

                # Apply blend mode (only linear light is shown here as an example;
                # each of the other chop_mode_v2 modes needs its own formula)
                blend = backdrop_norm + (2 * source_norm) - 1

                # Apply opacity and mask: lerp between the backdrop and the blended result
                new_rgb = (1 - mask_norm * opacity / 100) * backdrop_norm + mask_norm * (opacity / 100) * blend

                # Ensure values are in valid range
                new_rgb = torch.clamp(new_rgb, 0, 1)
                # Update _comp with the new blended RGB values
                _comp = new_rgb
                # Handle alpha channel if present in background
                if background_image.shape[-1] == 4:
                    alpha = background_image[..., 3:]
                    _comp = torch.cat([new_rgb, alpha], dim=-1)

                # new_rgb already composites the layer over the backdrop via mask_norm,
                # so no further canvas blend is needed here
                ret_images.append(_comp)
                ret_masks.append(_compmask[..., 0])  # MASK output is [B, H, W]
                
            log(f"{NODE_NAME} Processed {len(ret_images)} image(s).", message_type='finish')
            return (torch.cat(ret_images, dim=0), torch.cat(ret_masks, dim=0),)

NODE_CLASS_MAPPINGS = {
    "LayerUtility: ImageBlendAdvance V2": ImageBlendAdvanceV2
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LayerUtility: ImageBlendAdvance V2": "LayerUtility: ImageBlendAdvance V2"
}
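For the other blend_mode values (the code above hard-wires linear light as an example), a torch-native dispatch could look like this hedged sketch; the mode names come from chop_mode_v2, the formulas are the standard compositing definitions, and blend_torch is a hypothetical helper, not part of this repo:

import torch

def blend_torch(backdrop: torch.Tensor, source: torch.Tensor, mode: str) -> torch.Tensor:
    # All inputs are float tensors in 0..1 with matching shapes.
    if mode == 'normal':
        return source
    if mode == 'multiply':
        return backdrop * source
    if mode == 'screen':
        return 1 - (1 - backdrop) * (1 - source)
    if mode == 'linear light':
        return torch.clamp(backdrop + 2 * source - 1, 0, 1)
    raise ValueError(f"unsupported blend mode: {mode}")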
chflame163 (Owner) commented:
Okay! This is something I have always wanted to do. I first learned how to use Pillow, and after writing these nodes I gradually realized that the best approach is to accelerate with torch matrix computation, but I haven't had time to do it yet. Thank you :)
