Comments (3)
Sorry for the delay here..
Did you do matrix multiplication (c = a.matmul(b) or c = a @ b) or direct multiplication (c = a * b)? From your comment, it looks like direct multiplication. Direct multiplication may already have Autograd support in v1.7.0.
from fft-conv-pytorch.
Thanks a lot for your reply! I'm not quite familiar with fft-conv, so I'd appreciate any advice. Here's my setup: I have two signal images of shape [1, H0, W0] and two kernel images of shape [1, H1, W1], with H1 < H0 and W1 < W0. I want to do convolution independently within a batch: signal 1 convolved with kernel 1, and signal 2 with kernel 2. My question is: after performing the FFT, can I use c = a * b directly, or should I use c = a @ b instead? Actually, I don't understand why we would need matrix multiplication at all. According to the convolution theorem, I think we should use elementwise multiplication. Thanks a lot again for sharing!
Ah, sorry, I think I misunderstood your original question. You're exactly right -- you should use direct (elementwise) multiplication after performing the FFT. But if you try to track gradients through your function, I believe you'll get an error in PyTorch 1.7.0.
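To illustrate why elementwise multiplication is the right operation, here's a minimal standalone sketch of the convolution theorem in numpy (this is a hypothetical demo, not the repo's fft_conv implementation): circular convolution of two 1-D signals equals the inverse FFT of the elementwise product of their FFTs.

```python
import numpy as np

# Two short 1-D signals of equal length.
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([0.5, -1.0, 0.0, 2.0])

# Elementwise (direct) multiplication in the frequency domain,
# then an inverse FFT back to the time domain.
fft_result = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b)).real

# Reference: circular convolution computed directly from the definition,
# c[i] = sum_j a[j] * b[(i - j) mod N].
direct = np.array([sum(a[j] * b[(i - j) % 4] for j in range(4)) for i in range(4)])

assert np.allclose(fft_result, direct)
```

(For linear rather than circular convolution, as in a conv layer, the signals are zero-padded before the FFT; the frequency-domain operation is still elementwise.)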
Example:
import torch
from fft_conv import FFTConv1d  # assuming fft_conv.py is on your path

layer = FFTConv1d(4, 4, 3)
x = torch.randn(2, 4, 128)
y = layer(x)
loss = y.mean()
loss.backward()
# RuntimeError: _unsafe_view does not support automatic differentiation for outputs with complex dtype.
^^ That's where my "hacky" comment came from:
# (fft_conv.py - lines 24:27)
# Compute the real and imaginary parts independently, then manually insert them
# into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
# because Autograd is not enabled for complex matrix operations yet. Not exactly
# idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
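The workaround described in that comment can be sketched like this (a simplified numpy demo of the idea, not the actual fft_conv.py code): multiply complex values using only real-valued arithmetic, so Autograd never has to differentiate through a complex op.

```python
import numpy as np

# Two complex frequency-domain tensors.
a = np.fft.fft(np.array([1.0, 2.0, 3.0]))
b = np.fft.fft(np.array([0.0, 1.0, -1.0]))

# Complex multiplication expanded into real operations:
# (a_r + i*a_i)(b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r)
real = a.real * b.real - a.imag * b.imag
imag = a.real * b.imag + a.imag * b.real
manual = real + 1j * imag

# Matches the native complex product.
assert np.allclose(manual, a * b)
```

In the real implementation the two real-valued parts are inserted into the output tensor, so every op Autograd sees is real-valued.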
You likely won't see this error with the fft_conv method, because gradients are not tracked by default! But nn.Module layers like FFTConv1d automatically track gradients, and an error gets thrown when loss.backward() is called.
Also, if the signal/kernel has more than 1 channel, you'll need to use the complex_matmul method in fft_conv.py. (PyTorch convolutions actually perform matrix multiplication over the channel dimension.) But from your question, it sounds like you're only using 1 channel.
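To make the channel-dimension point concrete, here's a hedged numpy sketch (hypothetical shapes, not the repo's complex_matmul): in the multi-channel case, each frequency bin requires contracting the signal's input channels against the kernel's, which is a per-frequency matrix multiplication rather than a plain elementwise product.

```python
import numpy as np

rng = np.random.default_rng(0)

# Frequency-domain tensors with assumed shapes:
# signal: (batch, in_channels, freq), kernel: (out_channels, in_channels, freq)
signal = rng.standard_normal((2, 3, 8)) + 1j * rng.standard_normal((2, 3, 8))
kernel = rng.standard_normal((4, 3, 8)) + 1j * rng.standard_normal((4, 3, 8))

# Contract over in_channels at every frequency bin -- equivalent to a
# batched matrix multiplication, one matmul per frequency.
out = np.einsum("bif,oif->bof", signal, kernel)

assert out.shape == (2, 4, 8)
```

With a single channel the contraction is over a length-1 axis, which is why plain elementwise multiplication suffices there.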