Use this as an alternative for the flatten() function in PyTorch? #16
I quickly drafted something, making minor changes here and there, and it seems to work. Among the changes I made to the original algorithm, the biggest one was to collect the points into a single `result` list passed by reference through the recursive calls, to gain performance:
```python
import math

import torch


def sgn(x):
    """Sign of an integer: -1, 0 or 1."""
    return -1 if x < 0 else (1 if x > 0 else 0)


def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result: list):
    """Recursively append the Gilbert-curve coordinates of one sub-rectangle to `result`."""
    # Width and height of the grid to fill
    w = abs(ax + ay)
    h = abs(bx + by)

    # Unit direction vectors (computed once and reused)
    dax, day = sgn(ax), sgn(ay)  # major direction
    dbx, dby = sgn(bx), sgn(by)  # orthogonal direction

    # Trivial row fill
    if h == 1:
        for _ in range(w):
            result.append((x, y))
            x, y = x + dax, y + day  # inlined move_point
        return

    # Trivial column fill
    if w == 1:
        for _ in range(h):
            result.append((x, y))
            x, y = x + dbx, y + dby  # inlined move_point
        return

    # Halve the movement vectors
    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        # Long case: split into two parts (prefer even steps)
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        # Standard case: split into three parts (prefer even steps)
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)


def gilbert2d(width, height):
    """Return the (x, y) points of a Gilbert curve covering a width x height grid."""
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result


def reshape_via_gilbert(tensor, width=None, height=None, path=None):
    """Lay the elements of `tensor` out on a (height, width) grid along a Gilbert curve."""
    flattened_tensor = tensor.flatten()
    num_elements = flattened_tensor.numel()

    if width is None and height is None:
        # Pick a near-square grid
        height = max(1, math.isqrt(num_elements))
        width = (num_elements + height - 1) // height
    elif width is None:
        # Derive width from height (ceiling division)
        width = (num_elements + height - 1) // height
    elif height is None:
        # Derive height from width (ceiling division)
        height = (num_elements + width - 1) // width

    if width * height < num_elements:
        raise ValueError(f"a {height}x{width} grid cannot hold {num_elements} elements")

    # Cells not reached by the input stay zero
    reshaped_tensor = torch.zeros((height, width), dtype=tensor.dtype, device=tensor.device)

    if path is None:
        path = gilbert2d(width, height)

    # Batch update: path holds (x, y) pairs, i.e. (column, row)
    idx_list = torch.tensor(path[:num_elements], dtype=torch.long, device=tensor.device)
    reshaped_tensor[idx_list[:, 1], idx_list[:, 0]] = flattened_tensor
    return reshaped_tensor


# Example usage
tensor = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
    [17, 18, 19, 20],
    [21, 22, 23, 24],
    [25, 26, 27, 28],
    [29, 30, 31, 32],
    [33, 34, 35, 36],
])
print(reshape_via_gilbert(tensor))                     # 6x6 grid
print(reshape_via_gilbert(tensor, width=5))            # 8x5 grid, 4 cells left at 0
print(reshape_via_gilbert(tensor, height=4))           # 4x9 grid
print(reshape_via_gilbert(tensor, width=8, height=8))  # 8x8 grid, 28 cells left at 0
```
And the resulting tensor:
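Note that the snippet above still recurses; the shared `result` list only replaces returning the points. If the goal is to drop recursion entirely, one option is an explicit stack. The sketch below is hypothetical (`gilbert2d_iterative` is not part of this repository): it pushes the sub-rectangles in reverse so that they are popped, and therefore emitted, in the same order as the recursive `generate2d` calls.

```python
def sgn(x):
    """Sign of an integer: -1, 0 or 1."""
    return -1 if x < 0 else (1 if x > 0 else 0)


def gilbert2d_iterative(width, height):
    """Hypothetical non-recursive variant of gilbert2d using an explicit stack."""
    if width >= height:
        stack = [(0, 0, width, 0, 0, height)]
    else:
        stack = [(0, 0, 0, height, width, 0)]
    result = []
    while stack:
        x, y, ax, ay, bx, by = stack.pop()
        w, h = abs(ax + ay), abs(bx + by)
        dax, day = sgn(ax), sgn(ay)  # major direction
        dbx, dby = sgn(bx), sgn(by)  # orthogonal direction
        if h == 1:  # trivial row fill
            for _ in range(w):
                result.append((x, y))
                x, y = x + dax, y + day
            continue
        if w == 1:  # trivial column fill
            for _ in range(h):
                result.append((x, y))
                x, y = x + dbx, y + dby
            continue
        ax2, ay2 = ax // 2, ay // 2
        bx2, by2 = bx // 2, by // 2
        w2, h2 = abs(ax2 + ay2), abs(bx2 + by2)
        # Children are pushed last-first so that pop() returns them in curve order.
        if 2 * w > 3 * h:
            if w2 % 2 and w > 2:
                ax2, ay2 = ax2 + dax, ay2 + day
            stack.append((x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by))
            stack.append((x, y, ax2, ay2, bx, by))
        else:
            if h2 % 2 and h > 2:
                bx2, by2 = bx2 + dbx, by2 + dby
            stack.append((x + (ax - dax) + (bx2 - dbx), y + (ay - day) + (by2 - dby),
                          -bx2, -by2, -(ax - ax2), -(ay - ay2)))
            stack.append((x + bx2, y + by2, ax, ay, bx - bx2, by - by2))
            stack.append((x, y, bx2, by2, ax2, ay2))
    return result


print(gilbert2d_iterative(4, 4)[:4])  # [(0, 0), (1, 0), (1, 1), (0, 1)]
```

Because the stack holds at most a few entries per subdivision level, this also avoids Python's recursion limit on very large grids; whether it is actually faster than the recursive version would need to be measured.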
Hi, I agree that a better ordered …
In the code, however, why generate the Gilbert curve over a shape that is close to a square? I think it would be better to use the original tensor shape directly if the tensor is 2D or 3D.
Thanks for the links. The 3Blue1Brown video got me wondering whether the Gilbert curve is actually stable in the limit, like the Hilbert curve, which would be interesting to prove (or to fix the algorithm so that it is). A simple test would be to look at the curve over a [kn, km] grid, where k grows slowly (say in steps of 0.1), and watch for abrupt changes or discontinuities. What do you think, @abetusk?
I took a quick look at the linked paper. If point cloud processing is important to you, I would suggest using a Hilbert spatial sort (see e.g. https://doc.cgal.org/latest/Spatial_sorting/index.html) to "flatten" the point cloud directly, skipping voxelization altogether.
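A Hilbert spatial sort boils down to computing each point's index along a Hilbert curve over a quantized bounding box and sorting by that index. The sketch below is a minimal 2D version in plain Python, not CGAL's API: `hilbert_xy2d` follows the classic bit-manipulation conversion for an n x n grid with n a power of two, while `hilbert_sort_2d` and its `bits` parameter are hypothetical names chosen for illustration.

```python
def _rot(n, x, y, rx, ry):
    """Rotate/flip a quadrant so the sub-curve has the canonical orientation."""
    if ry == 0:
        if rx == 1:
            x, y = n - 1 - x, n - 1 - y
        x, y = y, x
    return x, y


def hilbert_xy2d(n, x, y):
    """Distance along the Hilbert curve of cell (x, y) in an n x n grid (n a power of two)."""
    d = 0
    s = n // 2
    while s > 0:
        rx = 1 if x & s else 0
        ry = 1 if y & s else 0
        d += s * s * ((3 * rx) ^ ry)
        x, y = _rot(n, x, y, rx, ry)
        s //= 2
    return d


def hilbert_sort_2d(points, bits=10):
    """Return the indices of `points` (x, y floats) ordered along a Hilbert curve."""
    n = 1 << bits
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    x0, y0 = min(xs), min(ys)
    # Scale factors mapping the bounding box onto integer cells 0..n-1
    sx = (n - 1) / (max(xs) - x0) if max(xs) > x0 else 0.0
    sy = (n - 1) / (max(ys) - y0) if max(ys) > y0 else 0.0

    def key(i):
        qx = min(n - 1, int((points[i][0] - x0) * sx))
        qy = min(n - 1, int((points[i][1] - y0) * sy))
        return hilbert_xy2d(n, qx, qy)

    # sorted() is stable, so points sharing a cell keep their input order
    return sorted(range(len(points)), key=key)


pts = [(0.9, 0.1), (0.1, 0.1), (0.1, 0.9), (0.9, 0.9)]
print(hilbert_sort_2d(pts, bits=1))  # [1, 2, 3, 0]
```

For 3D point clouds the same idea needs a 3D Hilbert index, which the linked CGAL package also provides.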
Thanks for your answer!
I'm still learning and trying to figure out the best way to do this. Don't expect production-ready code here; I'm just experimenting. That said, I've updated the code: we can now optionally add the …
Thanks for this, I'm going to check it out. In the meantime, I posted a message on the PyTorch forums to get some insights on how I could optimize the algorithm using GPUs. Link: https://discuss.pytorch.org/t/custom-flatten-function-using-gpu-acceleration/211830
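One common way to make this kind of reordering GPU-friendly, sketched here under the assumption that the grid size is fixed: run the curve generation once on the CPU, store the resulting permutation as a LongTensor buffer, and apply it to whole batches with a single indexing (gather) operation on the device. The module also flattens over the tensor's own H x W, as suggested in the reply above. `CurveFlatten` and `snake_path` are hypothetical names; the snake (boustrophedon) path is only a stand-in to keep the block self-contained, and passing `gilbert2d` from the code above as `path_fn` would give the Gilbert ordering instead.

```python
import torch


def snake_path(width, height):
    """Hypothetical stand-in path: left-to-right, then right-to-left, row by row."""
    path = []
    for y in range(height):
        xs = range(width) if y % 2 == 0 else range(width - 1, -1, -1)
        path.extend((x, y) for x in xs)
    return path


class CurveFlatten(torch.nn.Module):
    """Flatten the last two dims of a (..., H, W) tensor along a precomputed (x, y) path."""

    def __init__(self, height, width, path_fn=snake_path):
        super().__init__()
        self.height, self.width = height, width
        path = path_fn(width, height)  # computed once, on the CPU
        # Row-major position of each (x, y) point, i.e. its index in x.flatten(-2)
        perm = torch.tensor([y * width + x for x, y in path], dtype=torch.long)
        # Buffers follow the module across .to(device) calls
        self.register_buffer("perm", perm)
        self.register_buffer("inv_perm", torch.argsort(perm))

    def forward(self, x):
        # One gather over the last dim; works for any leading batch/channel dims
        return x.flatten(start_dim=-2)[..., self.perm]

    def restore(self, v):
        """Inverse of forward: curve order back to a (..., H, W) grid."""
        return v[..., self.inv_perm].reshape(*v.shape[:-1], self.height, self.width)


cf = CurveFlatten(2, 3)
x = torch.arange(6).reshape(1, 2, 3)  # [[[0, 1, 2], [3, 4, 5]]]
v = cf(x)
print(v)                           # tensor([[0, 1, 2, 5, 4, 3]])
print(torch.equal(cf.restore(v), x))  # True
```

On a GPU, calling `cf.to("cuda")` moves both buffers, so the per-batch cost is only the indexing; the Python curve generation is paid once per grid shape.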
Hello,
I've recently started new AI classes, and I learned about PyTorch's flatten() function, exploring its pros and cons. One notable issue is that the function doesn't preserve spatial locality, which can be important in AI and image processing. This reminded me of the incredible @3Blue1Brown video on Hilbert curves (https://www.youtube.com/watch?v=3s7h2MHQtxc). I began wondering whether there was an algorithm or method to flatten an image of any size with a Hilbert curve, preserving locality. And that's what led me here.
Are you aware of an existing implementation of this for PyTorch? Do you think it would be a worthwhile feature to explore? The idea would be to use this algorithm to reshape a matrix into a simple vector.
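For the case the classic Hilbert curve handles directly, a square image whose side is a power of two, the reshape into a vector can be sketched by walking the curve index d from 0 to n*n - 1 and reading the pixel at each position. `hilbert_d2xy` follows the classic bit-manipulation conversion; the function names are hypothetical. Images of other sizes are exactly what the Gilbert curve in this repository generalizes to.

```python
def _rot(n, x, y, rx, ry):
    """Rotate/flip a quadrant so the sub-curve has the canonical orientation."""
    if ry == 0:
        if rx == 1:
            x, y = n - 1 - x, n - 1 - y
        x, y = y, x
    return x, y


def hilbert_d2xy(n, d):
    """Cell (x, y) at distance d along the Hilbert curve of an n x n grid (n a power of two)."""
    x = y = 0
    s, t = 1, d
    while s < n:
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        x, y = _rot(s, x, y, rx, ry)
        x, y = x + s * rx, y + s * ry
        t //= 4
        s *= 2
    return x, y


def hilbert_flatten(image):
    """Flatten a square, power-of-two-sized 2D list in Hilbert order (row = y, column = x)."""
    n = len(image)
    if n == 0 or n & (n - 1) or any(len(row) != n for row in image):
        raise ValueError("expected an n x n image with n a power of two")
    return [image[y][x] for x, y in (hilbert_d2xy(n, d) for d in range(n * n))]


print(hilbert_flatten([[1, 2], [3, 4]]))  # [1, 3, 4, 2]
```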
Thank you!
Interesting reads: