
Is the complex conv2d implementation consistent with the mathematical theory? #26

Open
shenxiaozheng opened this issue Feb 27, 2024 · 2 comments

Comments


shenxiaozheng commented Feb 27, 2024

No description provided.

@huyanxin (Owner)

complex conv has two parts:

  1. weight: as far as I understand it, this part is fine;
  2. bias: this part does not match the math. The bias should be set to False in nn.Conv2d, and the bias should instead be handled outside the conv, wrapped separately with nn.Parameter as bias_r and bias_i (a minimal sketch follows below).

If you are interested, feel free to open a PR to fix the bias bug.
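For reference, a minimal sketch of what that fix could look like (the class and parameter names here are illustrative, not the repo's actual code):

import torch
import torch.nn as nn

class ComplexConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # real/imaginary weight convolutions, both with the built-in bias disabled
        self.conv_r = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)
        self.conv_i = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)
        # the complex bias lives outside nn.Conv2d as two separate Parameters
        self.bias_r = nn.Parameter(torch.zeros(out_channels))
        self.bias_i = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x_r, x_i):
        # (x_r + i x_i) * (w_r + i w_i) + (b_r + i b_i)
        out_r = self.conv_r(x_r) - self.conv_i(x_i) + self.bias_r[None, :, None, None]
        out_i = self.conv_i(x_r) + self.conv_r(x_i) + self.bias_i[None, :, None, None]
        return out_r, out_i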

@huyanxin (Owner)

Here is some complex linear forward and backward code I wrote on my end; you can compare it against the formulas you derived and check whether there is any discrepancy.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ************************************************************************
# *
# * @file:run_clinear.py
# * @author: huyanxin ([email protected])
# * @date:2024-03-18 08:40
# * @description: Python Script
# * @Copyright (c)  all right reserved
# *
#*************************************************************************

from typing import Tuple


import torch
from torch import Tensor as Tensor
class ComplexLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_r: Tensor, input_i: Tensor, weight_r: Tensor, weight_i: Tensor, bias_r: Tensor, bias_i: Tensor):
        # (x_r + i x_i) @ (W_r + i W_i)^T + (b_r + i b_i)
        output_r = input_r @ weight_r.T - input_i @ weight_i.T + bias_r[None, ...]
        output_i = input_r @ weight_i.T + input_i @ weight_r.T + bias_i[None, ...]
        ctx.save_for_backward(input_r, input_i, weight_r, weight_i)
        return output_r, output_i

    @staticmethod
    def backward(ctx, grad_output_r: Tensor, grad_output_i: Tensor):
        input_r, input_i, weight_r, weight_i = ctx.saved_tensors

        # gradients w.r.t. the real/imaginary parts of the input
        grad_input_r = grad_output_r @ weight_r + grad_output_i @ weight_i
        grad_input_i = -grad_output_r @ weight_i + grad_output_i @ weight_r
        # gradients w.r.t. the real/imaginary parts of the weight
        grad_weight_r = grad_output_r.T @ input_r + grad_output_i.T @ input_i
        grad_weight_i = -grad_output_r.T @ input_i + grad_output_i.T @ input_r
        # bias gradients: sum over the batch dimension
        grad_bias_r = grad_output_r.sum(0)
        grad_bias_i = grad_output_i.sum(0)
        return grad_input_r, grad_input_i, grad_weight_r, grad_weight_i, grad_bias_r, grad_bias_i

class ComplexLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, dtype: torch.dtype) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_r = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype))
        self.weight_i = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype))
        self.bias_r = torch.nn.Parameter(torch.empty((out_features,), dtype=dtype))
        self.bias_i = torch.nn.Parameter(torch.empty((out_features,), dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight_r)
        torch.nn.init.uniform_(self.weight_i)
        torch.nn.init.uniform_(self.bias_r)
        torch.nn.init.uniform_(self.bias_i)

    def forward(self, input_r: Tensor, input_i: Tensor):
        return ComplexLinearFunction.apply(input_r, input_i, self.weight_r, self.weight_i, self.bias_r, self.bias_i)


torch.manual_seed(20)
in_features = 128
out_features = 256
batch_size = 1
dtype = torch.float64

linear = ComplexLinear(in_features, out_features, dtype=dtype)
input_r = torch.randn(batch_size, in_features, requires_grad=True, dtype=dtype)
input_i = torch.randn(batch_size, in_features, requires_grad=True, dtype=dtype)

label_r = torch.randn(batch_size, out_features, requires_grad=False, dtype=dtype)
label_i = torch.randn(batch_size, out_features, requires_grad=False, dtype=dtype)

out_r, out_i = linear(input_r, input_i)
#loss = torch.sum(out_r-label_r) + torch.sum(out_i-label_i)
#loss.backward()
print(torch.autograd.gradcheck(linear, (input_r, input_i)))
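One quick way to sanity-check that the implementation matches ordinary complex arithmetic is to compare the custom forward against PyTorch's native complex tensors (a small sketch, reusing the linear, input_r, input_i, out_r, out_i objects defined above):

# reference result via native complex matmul: (x_r + i x_i) @ (W_r + i W_i)^T + (b_r + i b_i)
x = torch.complex(input_r, input_i)
w = torch.complex(linear.weight_r, linear.weight_i)
b = torch.complex(linear.bias_r, linear.bias_i)
ref = x @ w.T + b
print(torch.allclose(torch.complex(out_r, out_i), ref))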
