Is there an unfold function in JAX? #5968
-
Dear all, thank you for your wonderful work! I have a question: is there a function like torch.nn.functional.unfold, which transforms input images into a matrix (like the im2col function in PyTorch's C++ code), already built into JAX? What I have found so far is a version built with NumPy here, but it seems that jnp.pad(X, ((0,0), (0,0), (pad, pad), (pad, pad))) doesn't support JIT compilation? And if I build this function with NumPy in Python, it is slower than torch.nn.functional.unfold. Thank you for any information! :) Best,
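For reference, an im2col can be written with jnp alone. This is a minimal sketch, assuming NCHW input, symmetric integer padding and one stride for both spatial axes; im2col is a hypothetical helper name, not a JAX API. The kernel size, padding and stride are marked static, so jnp.pad receives plain Python ints and the strided slices have fixed shapes, which is what lets the function trace under jax.jit. The output follows the (N, C*kh*kw, L) layout of torch.nn.functional.unfold.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("kernel_size", "pad", "stride"))
def im2col(x, kernel_size, pad=0, stride=1):
    """Hypothetical helper: unfold an (N, C, H, W) batch into (N, C*kh*kw, L)."""
    kh, kw = kernel_size
    n, c, h, w = x.shape
    # Padding widths are static Python ints, so jnp.pad traces fine under jit.
    xp = jnp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out_h = (h + 2 * pad - kh) // stride + 1
    out_w = (w + 2 * pad - kw) // stride + 1
    taps = []
    # Each (i, j) kernel offset picks that tap from every window at once.
    for i in range(kh):
        for j in range(kw):
            tap = xp[:, :, i : i + stride * out_h : stride, j : j + stride * out_w : stride]
            taps.append(tap)  # shape (N, C, out_h, out_w)
    # Stack taps right after C so features are ordered (C, kh, kw), as in torch.
    patches = jnp.stack(taps, axis=2)  # (N, C, kh*kw, out_h, out_w)
    return patches.reshape(n, c * kh * kw, out_h * out_w)
```

With kernel_size=(3, 3) and pad=1, a (2, 3, 5, 5) input gives a (2, 27, 25) result.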
-
@mattjj Hi! I was trying the solution and ran the following code:
but I got the following error:
The error is very cryptic and I cannot figure out what it means. Is it an XLA/JAX bug? Thanks!
-
Sorry for necroposting, but is there a backward operation for conv_general_dilated_patches, equivalent to fold?
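One possible approach, sketched under assumptions: unfolding is a linear map, and torch.nn.functional.fold sums overlapping entries back into their pixels, which is the transpose (adjoint) of that map. So a fold can be obtained by applying jax.vjp to an unfold wrapper built on lax.conv_general_dilated_patches. The names unfold and fold below are hypothetical, and unlike torch's output_size (spatial size only), output_shape here is the full (N, C, H, W) shape.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def unfold(x, kernel_size, padding=0, stride=1):
    """Hypothetical wrapper: (N, C, H, W) -> (N, C*kh*kw, L), torch-style layout."""
    patches = lax.conv_general_dilated_patches(
        x,
        filter_shape=kernel_size,
        window_strides=(stride, stride),
        padding=((padding, padding), (padding, padding)),
        precision=lax.Precision.HIGHEST,  # avoid reduced-precision rounding
    )
    n, f, out_h, out_w = patches.shape
    return patches.reshape(n, f, out_h * out_w)


@partial(jax.jit, static_argnames=("output_shape", "kernel_size", "padding", "stride"))
def fold(cols, output_shape, kernel_size, padding=0, stride=1):
    """Hypothetical fold: the transpose of the linear map `unfold`.

    Entries from overlapping patches are summed into their source pixels.
    """
    zeros = jnp.zeros(output_shape, cols.dtype)
    # unfold is linear, so its VJP evaluated on `cols` is exactly its transpose.
    _, vjp_fn = jax.vjp(lambda img: unfold(img, kernel_size, padding, stride), zeros)
    (image,) = vjp_fn(cols)
    return image
```

fold(unfold(x)) only gives back x when the windows do not overlap; with overlapping windows each pixel is scaled by the number of windows covering it, the same behaviour torch documents for its fold.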
Thanks for the question, and the kind words! I think conv_general_dilated_patches might do what we want here. It's a pretty general function, and there might be room for convenience wrappers. What do you think?
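The suggestion can be sketched as a thin convenience wrapper. This is a minimal sketch, assuming NCHW input and symmetric integer padding; unfold is a hypothetical name, not a JAX API. With the default dimension numbers, lax.conv_general_dilated_patches returns (N, C*kh*kw, H_out, W_out) with each patch flattened into the channel axis in (C, kh, kw) order, so merging the two output spatial axes gives the (N, C*kh*kw, L) layout of torch.nn.functional.unfold.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=("kernel_size", "padding", "stride", "dilation"))
def unfold(x, kernel_size, padding=0, stride=1, dilation=1):
    """Hypothetical wrapper in the spirit of torch.nn.functional.unfold.

    x: (N, C, H, W) array. Returns (N, C*kh*kw, L), features ordered (C, kh, kw).
    """
    patches = lax.conv_general_dilated_patches(
        x,
        filter_shape=kernel_size,
        window_strides=(stride, stride),
        padding=((padding, padding), (padding, padding)),
        rhs_dilation=(dilation, dilation),
        # Full precision so the internal identity-kernel convolution does not
        # round values on accelerators that default to reduced precision.
        precision=lax.Precision.HIGHEST,
    )  # default dimension numbers: NCHW in, (N, C*kh*kw, H_out, W_out) out
    n, f, out_h, out_w = patches.shape
    # Merge the output spatial dims into one "L" axis, as torch's unfold does.
    return patches.reshape(n, f, out_h * out_w)


if __name__ == "__main__":
    x = jnp.ones((2, 3, 5, 5))
    print(unfold(x, kernel_size=(3, 3), padding=1).shape)  # (2, 27, 25)
```

Marking the shape arguments static keeps filter_shape and padding as concrete Python values, which conv_general_dilated_patches needs when traced under jit.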