
fkodom / fft-conv-pytorch

Stars: 440 · Watchers: 8 · Forks: 53 · Size: 149 KB

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.

License: MIT License

Python 100.00%
Topics: pytorch, python3, convolution, image-processing, neural-networks
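For reference, a minimal usage sketch (it assumes the PyPI package fft-conv-pytorch exposes fft_conv and FFTConv1d as described in its README):

import torch
from fft_conv_pytorch import fft_conv, FFTConv1d

# Functional form: convolve a batch of 1D signals with a large kernel.
signal = torch.randn(3, 3, 1024 * 100)  # (batch, in_channels, length)
kernel = torch.randn(2, 3, 1025)        # (out_channels, in_channels, kernel_size)
bias = torch.randn(2)
out = fft_conv(signal, kernel, bias=bias)

# Module form: a drop-in replacement for nn.Conv1d.
layer = FFTConv1d(3, 2, kernel_size=1025, padding=512, bias=True)
out = layer(signal)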

fft-conv-pytorch's Introduction


If you find my projects useful, please consider becoming a sponsor. Everything here comes from my free time, and is released under permissive licenses (e.g. MIT). Your contribution helps fund open-source AI.


fft-conv-pytorch's People

Contributors

alexhagen, aretor, fkodom, papkov, yoyololicon


fft-conv-pytorch's Issues

FFTConvTranspose

  1. Is it possible to perform a transposed convolution in the spectral domain?
  2. Does FFTConv2d(in, out, ...) initialize its weights just like nn.Conv2d? (See the sketch below.)
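On question 2, a quick check (a sketch only; it assumes the fft_conv_pytorch import path and that FFTConv2d exposes a .weight parameter like nn.Conv2d):

from torch import nn
from fft_conv_pytorch import FFTConv2d

fft_layer = FFTConv2d(3, 8, kernel_size=5)
ref_layer = nn.Conv2d(3, 8, kernel_size=5)

# Both should hold weights of shape (out_channels, in_channels, kH, kW).
print(fft_layer.weight.shape, ref_layer.weight.shape)
# If the initializations match, the empirical standard deviations should be close.
print(fft_layer.weight.std().item(), ref_layer.weight.std().item())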

Using fft-conv hurts convergence

Thank you for open-sourcing your implementation.

I tested FFTConv2d against PyTorch's nn.Conv2d in a simple LeNet5 architecture on MNIST, using the Adam optimizer with lr=1e-3. With nn.Conv2d I reach ~98% accuracy after 1 epoch and ~99% after about 10 epochs. With FFTConv2d and the exact same architecture and hyperparameters, I get ~92% after 1 epoch and only reach ~97% after 10 epochs.

I thought this might be due to aliasing effects, so I padded the kernel and the input using the "s" argument of rfftn, but I still get the same subpar performance.

The exact architecture is below:

from torch import nn
from fft_conv_pytorch import FFTConv2d  # assumed import path

class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        pad = 0
        # Note: with pad=0, this classic LeNet5 layout assumes 32x32 inputs
        # (e.g., MNIST padded or resized from 28x28).
        self.conv1 = nn.Sequential(
            FFTConv2d(1, 6, kernel_size=5, stride=1, padding=pad),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv2 = nn.Sequential(
            FFTConv2d(6, 16, kernel_size=5, stride=1, padding=pad),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv3 = nn.Sequential(
            FFTConv2d(16, 120, kernel_size=5, stride=1, padding=pad),
            nn.ReLU()
        )
        self.transform_output = nn.Flatten(start_dim=1)
        self.fc = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.transform_output(out)
        out = self.fc(out)
        return out

Any idea why this is the case?
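One way to rule out a numerical bug before blaming convergence (a sketch; it assumes FFTConv2d exposes .weight and .bias like nn.Conv2d): copy the parameters of an nn.Conv2d into an FFTConv2d and compare outputs directly.

import torch
from torch import nn
from fft_conv_pytorch import FFTConv2d

torch.manual_seed(0)
x = torch.randn(8, 1, 28, 28)
ref = nn.Conv2d(1, 6, kernel_size=5)
fft = FFTConv2d(1, 6, kernel_size=5)
fft.weight = ref.weight  # share parameters so both layers compute the same map
fft.bias = ref.bias
err = (ref(x) - fft(x)).abs().max().item()
print(err)  # should be near float32 roundoff (~1e-5); anything larger hints at a real discrepancy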

Propagation of error becomes large very fast

Hi fkodom,
I was experimenting with your fft-conv implementation, but I saw that after just 2 layers the error becomes quite large. Is this due to a misinterpretation of your code on my part? Or might the FFT + IFFT simply introduce too much error when layers are chained?

Thank you in advance for your help! :)
David


SOLUTION:
The problem is that the convolution outputs themselves grow very large, so an absolute difference of 1e-3 is not actually big in relative terms. In short, the method works fantastically well :)
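The takeaway generalizes: when outputs have large magnitude, compare relative rather than absolute error. A sketch (out_direct and out_fft are hypothetical outputs from the direct and FFT implementations):

import torch

def relative_error(out_direct: torch.Tensor, out_fft: torch.Tensor) -> float:
    # Normalize the worst-case absolute gap by the scale of the reference output.
    abs_err = (out_direct - out_fft).abs().max()
    return (abs_err / out_direct.abs().max()).item()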

Complex value support?

Hi,

Thanks for putting this together. Complex-number support is getting fairly advanced in PyTorch; is there any chance you'd consider changing things here to allow for it?

Cheers,

Stephen

Can't work on GPU?

Dear authors,

I am interested in this work; thank you for sharing it. I find that this function doesn't work on the GPU. Is that true?
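Since FFTConv2d is a standard nn.Module, moving it to the GPU should follow the usual PyTorch pattern (a sketch, assuming a CUDA device is available):

import torch
from fft_conv_pytorch import FFTConv2d

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = FFTConv2d(3, 8, kernel_size=5).to(device)  # move parameters to the GPU
x = torch.randn(1, 3, 64, 64, device=device)       # input must live on the same device
y = layer(x)
print(y.device)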

CUDA out of memory with complex_matmul

Hello,

Thank you for the effort you put into making this work. However, I am confused: when I apply the FFTConv2d layer in a network (a ResNet, for example) on the GPU, I always get a "CUDA out of memory" error. It seems to come from the complex_matmul function, which needs a lot of memory. How can I solve this problem, please?
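A first step is to measure where the peak memory actually occurs, using PyTorch's built-in CUDA memory statistics (a sketch with arbitrary layer sizes):

import torch
from fft_conv_pytorch import FFTConv2d

layer = FFTConv2d(64, 64, kernel_size=7).cuda()
x = torch.randn(8, 64, 224, 224, device="cuda")

torch.cuda.reset_peak_memory_stats()
y = layer(x)
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated: {peak_gib:.2f} GiB")  # the large complex spectra dominate here

FFT convolution materializes full complex spectra of the padded input and kernel, so its working set can be far larger than direct convolution's; reducing the batch size per forward pass (or the spatial resolution) is the simplest mitigation.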

License

I'm trying to package this for conda-forge; is there any information on the license?

Thank you

Best,

Mark

xref: yoyololicon#9

Bug

The outputs of nn.Conv2d and FFTConv2d disagree for a strided, padded 1x1 kernel:

import torch
from torch import nn
from fft_conv_pytorch import FFTConv2d

x = torch.randn(1, 3, 32, 128)
net1 = nn.Conv2d(3, 64, kernel_size=1, stride=(2, 1), padding=1)
net2 = FFTConv2d(3, 64, kernel_size=1, stride=(2, 1), padding=1)
out1 = net1(x)
out2 = net2(x)
print(out1.shape, out2.shape)
# out1 != out2

Autograd for complex matrix multiplication in PyTorch?

Thanks a lot for sharing your code; it helps me a lot! I have a question about complex matrix multiplication. I noticed a comment on this part saying: "This is fairly hacky but necessary for PyTorch 1.7.0, because Autograd is not enabled for complex matrix operations yet." But when I use PyTorch 1.7.0, I just write c = a * b directly, and autograd seems to work quite well. I'm not sure now. Would you run some experiments to double-check? Thanks a lot!
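A quick way to check (a sketch; note that complex autograd was still experimental in PyTorch 1.7 and only stabilized in later releases, so results may differ by version):

import torch

a = torch.randn(4, 4, dtype=torch.complex64, requires_grad=True)
b = torch.randn(4, 4, dtype=torch.complex64)
loss = (a * b).abs().sum()  # reduce to a real scalar so backward() is well-defined
loss.backward()
print(a.grad is not None, a.grad.dtype)  # expect: True torch.complex64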

Stride

Thank you for this code. How can I specify the stride, just as in a normal conv layer? Thanks.
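Recent versions of the package accept a stride argument directly (a sketch; older releases may not support it, in which case slicing the dense output manually is the usual workaround):

import torch
from fft_conv_pytorch import FFTConv2d

layer = FFTConv2d(3, 8, kernel_size=5, stride=2, padding=2)
x = torch.randn(1, 3, 64, 64)
print(layer(x).shape)  # expect torch.Size([1, 8, 32, 32]), matching nn.Conv2d semantics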
