
Comments (3)

fkodom commented on May 25, 2024

Sorry for the delay here.

Did you do matrix multiplication (c = a.matmul(b) or c = a @ b) or direct multiplication (c = a * b)? From your comment, it looks like direct multiplication. Direct multiplication may already have Autograd support in v1.7.0.
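
For clarity, here's a quick illustration of the two operations being discussed (the shapes below are arbitrary examples, not taken from your code):

import torch

a = torch.randn(4, 8, dtype=torch.complex64)   # e.g. FFT output of a signal
b = torch.randn(4, 8, dtype=torch.complex64)   # e.g. FFT output of a kernel

direct = a * b                   # direct (elementwise) multiplication, shape (4, 8)
matmul = a @ b.transpose(0, 1)   # matrix multiplication, shape (4, 4)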


RobinhoodKi commented on May 25, 2024

Thanks a lot for your reply! I'm not very familiar with FFT convolution, so I'd appreciate any advice. Here's my code. Suppose I have two signal images with shape [1, H0, W0] and two kernel images with shape [1, H1, W1], where H1 < H0 and W1 < W0. I want to convolve each pair independently within a batch: signal 1 with kernel 1, and signal 2 with kernel 2. My question is: after taking the FFT, can I use c = a * b directly, or should I use c = a @ b instead? Actually, I don't understand why we would need matrix multiplication at all; according to the convolution theorem, I think elementwise multiplication is what's required. Thanks again for sharing your work!
[Screenshot of the code, 2021-01-22 10:43 AM]


fkodom commented on May 25, 2024

Ah, sorry, I think I misunderstood your original question. You're exactly right -- you should use direct multiplication after performing the FFT. But if you try to track gradients through your function, I believe you'll get an error in PyTorch 1.7.0.

Example:

import torch
from fft_conv import FFTConv1d   # FFTConv1d as defined in fft_conv.py

layer = FFTConv1d(4, 4, 3)    # in_channels=4, out_channels=4, kernel_size=3
x = torch.randn(2, 4, 128)    # (batch, channels, signal length)
y = layer(x)
loss = y.mean()
loss.backward()

# RuntimeError: _unsafe_view does not support automatic differentiation for outputs with complex dtype.

^^ That's where my "hacky" comment came from:

# (fft_conv.py - lines 24:27)
# Compute the real and imaginary parts independently, then manually insert them
# into the output Tensor.  This is fairly hacky but necessary for PyTorch 1.7.0,
# because Autograd is not enabled for complex matrix operations yet.  Not exactly
# idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
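
For reference, a rough sketch of that workaround idea -- multiplying two complex tensors by handling the real and imaginary parts separately, then reassembling the result -- might look like the following. This is illustrative only, not the exact code from fft_conv.py:

import torch

def complex_mul_via_real_imag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # (a.real + i*a.imag) * (b.real + i*b.imag), expanded into real-valued ops
    real = a.real * b.real - a.imag * b.imag
    imag = a.real * b.imag + a.imag * b.real
    # Reassemble a complex tensor from the two real-valued results.
    return torch.complex(real, imag)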

You likely won't see this error with the fft_conv function, because gradients are not tracked by default. But nn.Module layers like FFTConv1d track gradients for their weights automatically, so the error is raised when loss.backward() is called.

Also, if the signal/kernel has more than one channel, you'll need to use the complex_matmul method in fft_conv.py, because PyTorch convolutions actually perform a matrix multiplication over the channel dimension. But from your question, it sounds like you're only using one channel.
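
To make the single-channel case concrete, here's a minimal sketch of batched FFT convolution with direct multiplication. This is my own illustrative function, not part of fft_conv.py; it assumes a PyTorch version with the torch.fft module, and it computes true convolution rather than PyTorch-style cross-correlation (flip the kernel if you need the latter):

import torch

def fft_conv2d_per_batch(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    # signal: (B, H0, W0), kernel: (B, H1, W1), with H1 <= H0 and W1 <= W0.
    # Each batch item is convolved with its own kernel.
    B, H0, W0 = signal.shape
    _, H1, W1 = kernel.shape
    # Zero-pad both to the full linear-convolution size to avoid circular wrap-around.
    H, W = H0 + H1 - 1, W0 + W1 - 1
    sig_f = torch.fft.rfftn(signal, s=(H, W), dim=(-2, -1))
    ker_f = torch.fft.rfftn(kernel, s=(H, W), dim=(-2, -1))
    out_f = sig_f * ker_f                                    # direct (elementwise) multiplication
    out = torch.fft.irfftn(out_f, s=(H, W), dim=(-2, -1))
    # Crop to the "valid" region, where the kernel fully overlaps the signal.
    return out[:, H1 - 1:H0, W1 - 1:W0]

signal = torch.randn(2, 64, 64)   # two single-channel images, stacked along the batch dim
kernel = torch.randn(2, 5, 5)     # one kernel per image
result = fft_conv2d_per_batch(signal, kernel)
print(result.shape)               # torch.Size([2, 60, 60])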

