From 407c012d08c22f6485e47083a33dbd2ad24c486e Mon Sep 17 00:00:00 2001 From: manu12121999 Date: Fri, 20 Dec 2024 13:06:02 +0100 Subject: [PATCH] Ignoring NN.modules arguments --- ctrl_c_nn.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/ctrl_c_nn.py b/ctrl_c_nn.py index 13ef834..f41568f 100644 --- a/ctrl_c_nn.py +++ b/ctrl_c_nn.py @@ -642,6 +642,8 @@ class nn: class Module: def __init__(self, *args, **kwargs): + if args != () or kwargs != {}: + print('Warning, ignoring', args, kwargs) self.cache = None def __call__(self, x: Tensor): @@ -694,6 +696,7 @@ def forward(self, x: Tensor): start = time.time() self.cache = x res = x.matmul_T_2d(self.weight) + self.bias + print("Linear took ", time.time() - start) return res def backward(self, dout: Tensor): @@ -766,14 +769,16 @@ def backward(self, dout: Tensor): return dout class Conv2d(Module): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True, *args, **kwargs): + if args != () or kwargs != {}: + print('Warning, Conv2dTranspose ignoring', args, kwargs) super().__init__() self.stride = stride self.padding = padding self.kernel_size = kernel_size self.out_channels = out_channels self.weight = Tensor.fill(shape=(out_channels, in_channels, kernel_size, kernel_size), number=0.0) - self.bias = Tensor.fill(shape=(out_channels, ), number=0.0 if bias else 0) + self.bias = Tensor.fill(shape=(out_channels, ), number=0.0 if bias else 0.0) def forward(self, x: Tensor): return self.forward_gemm(x) @@ -832,7 +837,9 @@ def im2col(x_pad): return res class Conv2dTranspose(Module): - def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True, *args, **kwargs): + if args != () or kwargs != {}: + print('Warning, Conv2dTranspose ignoring', args, kwargs) super().__init__() self.stride = stride self.padding = padding @@ -846,14 +853,16 @@ def forward(self, x: Tensor): class BatchNorm2d(Module): def __init__(self, num_features, eps=1e-05, *args, **kwargs): - self.weight = Tensor.random_float((num_features,)) - self.bias = Tensor.random_float((num_features,)) - self.running_mean = Tensor.zeros((num_features,)) - self.running_var = Tensor.ones((num_features,)) - self.num_batches_tracked = Tensor([0]) + if args != () or kwargs != {}: + print('Warning, BatchNorm2d ignoring', args, kwargs) + self.weight = Tensor.fill((num_features,), 0.0) + self.bias = Tensor.fill((num_features,), 0.0) + self.running_mean = Tensor.fill((num_features,), 0.0) + self.running_var = Tensor.fill((num_features,), 1.0) + self.num_batches_tracked = Tensor([0.0]) self.C = num_features self.eps = eps - super().__init__(*args, **kwargs) + super().__init__() def forward(self, x: Tensor): start_time = time.time() @@ -869,7 +878,9 @@ def forward(self, x: Tensor): return y class MaxPool2d(Module): - def __init__(self, kernel_size=2, stride=2, padding=0): + def __init__(self, kernel_size=2, stride=2, padding=0, *args, **kwargs): + if args != () or kwargs != {}: + print('Warning, MaxPool2d ignoring', args, kwargs) self.kernel_size = kernel_size self.stride = stride self.padding = padding @@ -919,6 +930,15 @@ class Dropout(Module): def forward(self, x: Tensor): return x + class Upsample(Module): + def __init__(self, scale_factor, *args, **kwargs): + super().__init__() + print('Warning, ignoring', args, kwargs) + self.scale_factor = scale_factor + + def forward(self, x: Tensor): + return F.interpolate(x, scale_factor=self.scale_factor) + class AbstractLoss: def __init__(self): self.cache = None