
Add a PixelShuffle implementation like in pytorch #2464

Open
Rick-29 opened this issue Nov 7, 2024 · 3 comments

Comments

@Rick-29

Rick-29 commented Nov 7, 2024

In PyTorch there is a module called PixelShuffle, which rearranges a tensor of shape (B, C·r², H, W) into (B, C, H·r, W·r), where r is the upscale factor.
I created a small implementation of it that currently only supports 4D tensors (I tried to follow the implementation format of the library) and wanted to share it to check whether it could someday be added to the crate.
This is the full code:

use burn::{config::Config, module::Module, prelude::Backend, tensor::Tensor};

/// Configuration for the [PixelShuffle] module.
#[derive(Config, Debug)]
pub struct PixelShuffleConfig {
    /// Factor by which to increase the spatial resolution.
    #[config(default = "2")]
    upscale_factor: usize,
}

/// Rearranges a tensor of shape `[batch, channels * r^2, height, width]`
/// into `[batch, channels, height * r, width * r]`, where `r` is the upscale factor.
#[derive(Module, Debug, Clone)]
pub struct PixelShuffle {
    upscale_factor: usize,
}

impl PixelShuffleConfig {
    /// Initialize a new [PixelShuffle] module.
    pub fn init(&self) -> PixelShuffle {
        PixelShuffle {
            upscale_factor: self.upscale_factor,
        }
    }
}

impl PixelShuffle {
    /// Applies the pixel shuffle to a 4D tensor `[batch, channels, height, width]`.
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let [batch, c, h, w] = input.dims();
        let r = self.upscale_factor;

        if c % (r * r) != 0 {
            panic!("pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of upscale_factor, but input.size(-3)={c} is not divisible by {}", r * r)
        }
        let oc = c / (r * r);
        let oh = h * r;
        let ow = w * r;

        // [batch, oc * r * r, h, w] -> [batch, oc, r, r, h, w]
        let x = input.reshape([batch, oc, r, r, h, w]);
        // [batch, oc, r, r, h, w] -> [batch, oc, h, r, w, r]
        let x = x.permute([0, 1, 4, 2, 5, 3]);
        // [batch, oc, h, r, w, r] -> [batch, oc, h * r, w * r]
        x.reshape([batch, oc, oh, ow])
    }
}
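
For reference, here is a minimal usage sketch that builds the module through its config, the way other modules in the library are constructed. This is just an illustration on my end; the Wgpu backend and the shapes are arbitrary, and it assumes the snippet above is in scope:

use burn::backend::Wgpu;

fn example() {
    let device = Default::default();
    // Assuming the Config derive provides `new()` and a `with_upscale_factor`
    // setter, as it does for other burn configs with defaulted fields.
    let shuffle = PixelShuffleConfig::new().with_upscale_factor(3).init();

    // [1, 9, 4, 4] -> [1, 1, 12, 12] with an upscale factor of 3.
    let input = Tensor::<Wgpu, 4>::random(
        [1, 9, 4, 4],
        burn::tensor::Distribution::Default,
        &device,
    );
    let output = shuffle.forward(input);
    assert_eq!([1, 1, 12, 12], output.dims());
}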

And here are some tests I made comparing the outputs with those of the PyTorch implementation:

#[cfg(test)]
mod tests {
    use burn::backend::Wgpu;

    use super::*;

    #[test]
    fn test_pixel_shuffle() {
        let shuffle = PixelShuffle { upscale_factor: 3 };
        let tensor1 = Tensor::<Wgpu, 4>::random([1, 9, 4, 4], burn::tensor::Distribution::Default, &Default::default());
        let tensor1_shuffle = shuffle.forward(tensor1);
        dbg!(tensor1_shuffle.dims());
        assert_eq!([1, 1, 12, 12], tensor1_shuffle.dims());

        let tensor2 = Tensor::<Wgpu, 4>::random([1, 18, 7, 5], burn::tensor::Distribution::Default, &Default::default());
        let tensor2_shuffle = shuffle.forward(tensor2);
        dbg!(tensor2_shuffle.dims());
        assert_eq!([1, 2, 21, 15], tensor2_shuffle.dims());

        let tensor3 = Tensor::<Wgpu, 4>::random([128, 45, 33, 7], burn::tensor::Distribution::Default, &Default::default());
        let tensor3_shuffle = shuffle.forward(tensor3);
        dbg!(tensor3_shuffle.dims());
        assert_eq!([128, 5, 99, 21], tensor3_shuffle.dims());
    }

    #[test]
    #[should_panic]
    fn test_pixel_shuffle_panic() {
        let shuffle = PixelShuffle { upscale_factor: 3 };
        let tensor1 = Tensor::<Wgpu, 4>::random([128, 46, 33, 7], burn::tensor::Distribution::Default, &Default::default());
        let tensor1_shuffle = shuffle.forward(tensor1);
        dbg!(tensor1_shuffle.dims());
    }
}

The original C++ implementation is here.

@laggui
Member

laggui commented Nov 8, 2024

We're always open to PRs! Feel free to open one to add this to the modules in burn-core 🙂

@Rick-29
Author

Rick-29 commented Nov 8, 2024

How do I do that?
Sorry, I'm really new to creating pull requests.

@laggui
Member

laggui commented Nov 8, 2024
