diff --git a/cyto_dl/nn/head/vic_reg.py b/cyto_dl/nn/head/vic_reg.py index 1a3eb966..8c993b73 100644 --- a/cyto_dl/nn/head/vic_reg.py +++ b/cyto_dl/nn/head/vic_reg.py @@ -14,9 +14,11 @@ def __init__(self, dimensions=[2048, 8192, 8192, 8192]): super().__init__() layers = [] for i in range(len(dimensions) - 1): - layers.append(nn.Linear(dimensions[i], dimensions[i + 1])) - layers.append(nn.BatchNorm1d(dimensions[i + 1])) - layers.append(nn.ReLU(True)) + layers += [ + nn.Linear(dimensions[i], dimensions[i + 1]), + nn.BatchNorm1d(dimensions[i + 1]), + nn.ReLU(True), + ] layers.append(nn.Linear(dimensions[-2], dimensions[-1], bias=False)) self.model = nn.Sequential(*layers)