Comments (10)
Hi @HieuPhan33,
Thanks for the feedback and question! Are you able to try PyTorch 1.6 and let me know if things work properly? I suspect this is a PyTorch version issue (the backward hook seems problematic with PyTorch 1.7 and above) and will push a fix to the current implementation soon (based on PyTorch's custom backward support). I should be able to do this in the next few days and will let you know!
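For readers wondering what a fix "based on PyTorch's custom backward support" might look like, here is a rough sketch (placeholder names, not the actual patch): a `torch.autograd.Function` that is the identity in the forward pass and, in the backward pass, runs the same fixed-point solve that the current hook runs before the gradient flows on into `f`'s graph.

```python
import torch
from torch import autograd

class ImplicitGrad(torch.autograd.Function):
    """Sketch only: identity in forward; in backward, replaces the incoming
    gradient u with the solution y of y = y J_f(z*) + u before it continues
    into f's graph. `solver(g, y0)` is a placeholder fixed-point solver."""

    @staticmethod
    def forward(ctx, z_new, z_cp, f_z_cp, solver):
        # z_cp / f_z_cp are a detached re-evaluation of f at z*, kept only
        # for vector-Jacobian products in backward (so no hook re-entry)
        ctx.save_for_backward(z_cp, f_z_cp)
        ctx.solver = solver
        return z_new

    @staticmethod
    def backward(ctx, grad):
        z_cp, f_z_cp = ctx.saved_tensors
        y = ctx.solver(
            lambda y: autograd.grad(f_z_cp, z_cp, y, retain_graph=True)[0] + grad,
            torch.zeros_like(grad))
        return y, None, None, None
```

Usage would be roughly `z_out = ImplicitGrad.apply(func(z_star.requires_grad_()), z_cp, func(z_cp), solver)` with `z_cp = z_star.clone().detach().requires_grad_()`; again, an illustration of the idea rather than the repo's implementation.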
Hi @jerrybai1995,
Thanks for your reply. Unfortunately, the RTX 3090 is incompatible with PyTorch 1.6.
Please let me know when you fix the backward hook problem with torch 1.7.
I appreciate your work!
Hi, I can now train with torch 1.7.0 (instead of 1.7.1)!
Heyo, just chiming in here to say that I'm experiencing a similar segmentation fault. I'm on 1.9, and it occurs when removing a hook during the `backward_hook` function.
Hi @tesfaldet,
Could you try the following implementation?
Replace lines 453-460 of `deq/MDEQ-Vision/lib/models/mdeq_core.py` (at commit `c161644`) with:
```python
z1_cp = z1.clone().detach().requires_grad_()
new_z1_cp = func(z1_cp)

def backward_hook(grad):
    result = self.b_solver(lambda y: autograd.grad(new_z1_cp, z1_cp, y, retain_graph=True)[0] + grad,
                           torch.zeros_like(grad), threshold=b_thres,
                           stop_mode=self.stop_mode, name="backward")
    return result['result']

new_z1.register_hook(backward_hook)  # Notice that it's new_z1 here, not new_z1_cp!
```
This should probably resolve the issue with PyTorch 1.9, but on the other hand pays the cost of an additional layer evaluation (in order to produce `new_z1_cp`). Please let me know if this resolves the issue.
(@HieuPhan33, cc'ing you in case you run into this issue in the future when using PyTorch >1.7.0 😄).
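For reference, the linear system this `backward_hook` solves, with notation added here: write $u$ for the incoming `grad` and $J_f(z^*)$ for the Jacobian of `func` at the equilibrium $z^* = f(z^*)$. The solver finds the fixed point of $y \mapsto y\,J_f(z^*) + u$, i.e.

```math
y^\top = y^\top J_f(z^*) + u^\top
\quad\Longrightarrow\quad
y^\top = u^\top \left(I - J_f(z^*)\right)^{-1},
```

which is exactly the implicit-function-theorem gradient; it then flows through the single attached `func` evaluation to reach the parameters.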
I tried this with PyTorch 1.6, 1.7, and 1.9, and would experience an out-of-memory error during the backward pass:
```python
def forward(self, x):
    # setup
    _, c, h, w = x.shape
    x0 = tensor2vec(x)
    func = lambda y: tensor2vec(self.f(vec2tensor(y, (c, h, w))))

    # Forward pass
    with torch.no_grad():
        x_star = self.solver(func, x0, threshold=30)['result']

    if self.training:
        # re-engage autograd tape
        x_star_new = func(x_star.requires_grad_())

        # set up Jacobian-vector product for backward pass
        def backward_hook(grad):
            # Compute the fixed point of yJ + grad, where J = J_f is the Jacobian of f at x_star
            grad_func = lambda y: autograd.grad(x_star_new, x_star, y, retain_graph=True)[0] + grad
            new_grad = self.solver(grad_func, torch.zeros_like(grad), threshold=40)['result']
            print('grad, new_grad', grad, new_grad)
            return new_grad

        if self.hook is not None:
            self.hook.remove()
            torch.cuda.synchronize()
        self.hook = x_star_new.register_hook(backward_hook)

    return vec2tensor(x_star_new, (c, h, w))
```
I'll try what you recommended.
@tesfaldet Oh, what you posted here will definitely have an out-of-memory (OOM) error, because whenever the backward pass goes through `x_star_new`, PyTorch will run the `backward_hook` function (as that's what a hook does), which will call the backward pass on `x_star_new` again (via the `autograd.grad` call), which will again call the `backward_hook` function. Therefore there'll be an infinite recursion loop. This is the fundamental reason why I put `self.hook.remove()` and `torch.cuda.synchronize()` within the `backward_hook` function originally.
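To make the re-entry concrete, here is a minimal standalone toy (not the repo code; a single vector-Jacobian product stands in for the fixed-point solve) showing the original pattern, where the hook removes itself before calling `autograd.grad` through the same tensor:

```python
import torch
from torch import autograd

# Toy stand-ins for x_star and x_star_new = func(x_star)
x_star = torch.randn(5, requires_grad=True)
x_star_new = torch.tanh(0.5 * x_star)

hook = None

def backward_hook(grad):
    # Remove the hook *before* the autograd.grad call below: that call
    # backprops through x_star_new again, and with the hook still attached
    # it would re-enter backward_hook and recurse forever.
    hook.remove()
    # One vector-Jacobian product, standing in for the fixed-point solve
    jvp = autograd.grad(x_star_new, x_star, grad, retain_graph=True)[0]
    return jvp + grad

hook = x_star_new.register_hook(backward_hook)
x_star_new.sum().backward()
print(x_star.grad)
```

With the `hook.remove()` line deleted, the inner `autograd.grad` fires the hook again and the script recurses until it crashes, which is the OOM behavior described above (though, as noted below, self-removal can segfault on some PyTorch versions).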
Ohhhhhhh, I see. So the `autograd.grad` call calls the `backward_hook` function again. I completely didn't realize! The reason I moved the `self.hook.remove()` block out of the `backward_hook` function to begin with was that it was causing a segmentation fault otherwise :( This happened on all the versions of PyTorch I listed above. Your suggested change seems to forego removing hooks, but I worry that it would slowly eat up memory over time.
I don't think it would lead to a memory leak, because Python's GC will clean up the old hook once its reference count goes to zero and a new training iteration comes in. The major drawback of the suggested change is that we have to pay some additional speed and memory cost to compute `z1_cp`.
Gotcha. Your suggested change seems to work! Though I think I'm using it incorrectly in my own model (the ResNet example here works fine, with some slight modifications).
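For anyone landing here later, a minimal sketch of the earlier `forward()` with the suggested change folded in (assuming the same helpers `tensor2vec`/`vec2tensor` and `self.solver` from the snippet above; a sketch, not the repo's exact code):

```python
# A sketch, assuming: import torch; from torch import autograd
def forward(self, x):
    _, c, h, w = x.shape
    x0 = tensor2vec(x)
    func = lambda y: tensor2vec(self.f(vec2tensor(y, (c, h, w))))

    # Forward pass: find the equilibrium x* = f(x*) without building a graph
    with torch.no_grad():
        x_star = self.solver(func, x0, threshold=30)['result']

    if self.training:
        # One attached evaluation so parameter gradients can flow
        x_star_new = func(x_star.requires_grad_())

        # Detached copy of the f-graph, used only inside the hook: the inner
        # autograd.grad differentiates through this copy, so the hook on
        # x_star_new never re-triggers itself and nothing needs removing
        x_star_cp = x_star.clone().detach().requires_grad_()
        x_star_new_cp = func(x_star_cp)

        def backward_hook(grad):
            grad_func = lambda y: autograd.grad(x_star_new_cp, x_star_cp, y,
                                                retain_graph=True)[0] + grad
            return self.solver(grad_func, torch.zeros_like(grad),
                               threshold=40)['result']

        x_star_new.register_hook(backward_hook)
        out = x_star_new
    else:
        out = x_star
    return vec2tensor(out, (c, h, w))
```

The extra `func(x_star_cp)` evaluation is the speed/memory cost mentioned above; in exchange, no hook ever has to be removed during the backward pass.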