Thank you for open-sourcing your implementation.
I tested FFTConv2d against PyTorch's nn.Conv2d using a simple LeNet5 architecture on MNIST, trained with the Adam optimizer at lr=1e-3. With nn.Conv2d I reach ~98% accuracy after 1 epoch and ~99% after around 10 epochs, whereas FFTConv2d with the exact same architecture and hyperparameters only reaches ~92% after 1 epoch and ~97% after 10 epochs.
I suspected this might be due to aliasing from the circular convolution, so I zero-padded both the kernel and the input to the full linear-convolution size using the `s` argument of rfftn, but I still get the same subpar performance.
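For reference, this is essentially the check I used to rule out aliasing: a minimal FFT cross-correlation sketch (the `fft_conv2d` helper and its cropping are mine, not this repository's implementation) that pads both operands via rfftn's `s` argument and matches F.conv2d on random tensors:

```python
import torch
import torch.nn.functional as F
from torch.fft import rfftn, irfftn

def fft_conv2d(x, weight):
    """Valid (padding=0, stride=1) cross-correlation via FFT, matching nn.Conv2d."""
    H, W = x.shape[-2:]
    kH, kW = weight.shape[-2:]
    # Zero-pad both operands to the full linear-convolution size so the
    # FFT's circular wrap-around cannot alias into the valid region.
    s = (H + kH - 1, W + kW - 1)
    X = rfftn(x, s=s, dim=(-2, -1))
    Wf = rfftn(weight, s=s, dim=(-2, -1))
    # Conjugating the kernel spectrum turns circular convolution into
    # circular cross-correlation, which is nn.Conv2d's convention.
    Y = torch.einsum('bihw,oihw->bohw', X, Wf.conj())
    y = irfftn(Y, s=s, dim=(-2, -1))
    return y[..., :H - kH + 1, :W - kW + 1]  # crop to the 'valid' output

x = torch.randn(2, 1, 32, 32)
w = torch.randn(6, 1, 5, 5)
assert torch.allclose(fft_conv2d(x, w), F.conv2d(x, w), atol=1e-4)
```

Since the forward pass agrees with the direct convolution to within float tolerance, the padding itself does not seem to be the problem.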
Model definition:

```python
import torch.nn as nn
# FFTConv2d is imported from this repository.

class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        pad = 0
        self.conv1 = nn.Sequential(
            FFTConv2d(1, 6, kernel_size=5, stride=1, padding=pad),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            FFTConv2d(6, 16, kernel_size=5, stride=1, padding=pad),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv3 = nn.Sequential(
            FFTConv2d(16, 120, kernel_size=5, stride=1, padding=pad),
            nn.ReLU(),
        )
        self.transform_output = nn.Flatten(start_dim=1)
        self.fc = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.transform_output(out)
        out = self.fc(out)
        return out
```
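For completeness, the training loop is a standard supervised setup along these lines (the batch size and the Pad(2) input transform are my choices; note that with padding=0 throughout, the 28x28 MNIST images must be padded to 32x32 for conv3 to receive a 5x5 feature map). Swapping FFTConv2d for nn.Conv2d in the model above is the only difference between the two runs:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pad MNIST from 28x28 to 32x32: with padding=0 the feature maps go
# 32 -> 28 -> 14 -> 10 -> 5 -> 1, so conv3's 5x5 kernel fits exactly.
transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
train_ds = datasets.MNIST('.', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

model = LeNet5().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
```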