
Comments (10)

jerrybai1995 commented on July 21, 2024

Hi @HieuPhan33 ,

Thanks for the feedback and question! Are you able to try PyTorch 1.6 and let me know if things work properly? I suspect this is a PyTorch version issue (the backward hook seems problematic with PyTorch 1.7 and above) and will push a fix to the current implementation soon (based on PyTorch's custom backward support). I should be able to do this in the next few days and will let you know!
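
For the curious, the rough idea of the fix is to move the backward solve into a torch.autograd.Function so that no tensor hook (and hence no hook removal) is needed at all. A sketch only, not the final code: the class name ImplicitGrad and the placeholder solver(g, x0) (which just returns the fixed point of g) are illustrative, not this repo's actual API.

import torch
from torch import autograd

class ImplicitGrad(autograd.Function):
    # Sketch only: identity in the forward pass; the backward pass replaces
    # the incoming gradient `grad` with the solution g* of
    #     g = J_f(z*)^T g + grad,
    # i.e. exactly what the backward hook computes, but with no hooks involved.

    @staticmethod
    def forward(ctx, new_z1, func, solver):
        ctx.func, ctx.solver = func, solver
        ctx.save_for_backward(new_z1)
        return new_z1.clone()

    @staticmethod
    def backward(ctx, grad):
        new_z1, = ctx.saved_tensors
        z = new_z1.detach().requires_grad_()
        with torch.enable_grad():
            f_z = ctx.func(z)        # fresh local graph, so this backward is never re-entered
        g_star = ctx.solver(
            lambda y: autograd.grad(f_z, z, y, retain_graph=True)[0] + grad,
            torch.zeros_like(grad))
        return g_star, None, None    # no gradients for func/solver

Usage would be out = ImplicitGrad.apply(func(z1.requires_grad_()), func, solver); g* then flows back through the one extra func evaluation to the model weights, exactly as with the hook.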

HieuPhan33 commented on July 21, 2024

Hi @jerrybai1995,
Thanks for your reply. Unfortunately, the RTX 3090 is incompatible with PyTorch 1.6.
Please let me know when you fix the backward-hook problem with torch 1.7.
Appreciate your work!

HieuPhan33 commented on July 21, 2024

Hi, I can now train with torch 1.7.0 (instead of 1.7.1)!

tesfaldet commented on July 21, 2024

Heyo, just chiming in here to say that I'm experiencing a similar segmentation fault. I'm on PyTorch 1.9, and it occurs when removing a hook inside the backward_hook function.

jerrybai1995 commented on July 21, 2024

Hi @tesfaldet,

Could you try the following implementation?

Replace L453-460, which currently read

def backward_hook(grad):
    if self.hook is not None:
        self.hook.remove()
        torch.cuda.synchronize()
    result = self.b_solver(lambda y: autograd.grad(new_z1, z1, y, retain_graph=True)[0] + grad,
                           torch.zeros_like(grad), threshold=b_thres,
                           stop_mode=self.stop_mode, name="backward")
    return result['result']
self.hook = new_z1.register_hook(backward_hook)

with

z1_cp = z1.clone().detach().requires_grad_()
new_z1_cp = func(z1_cp)
def backward_hook(grad):
    result = self.b_solver(lambda y: autograd.grad(new_z1_cp, z1_cp, y, retain_graph=True)[0] + grad,
                           torch.zeros_like(grad), threshold=b_thres,
                           stop_mode=self.stop_mode, name="backward")
    return result['result']
new_z1.register_hook(backward_hook)         # Notice that it's new_z1 here, not new_z1_cp!

This should probably resolve the issue with PyTorch 1.9, but on the other hand it pays the cost of an additional layer evaluation (in order to produce new_z1_cp). Please let me know if this resolves the issue.
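
If it helps to see the pattern end-to-end, here is a self-contained toy version (illustrative only, not this repo's code: a plain fixed-point loop stands in for self.b_solver, and the result is checked against differentiating through an explicitly unrolled solver).

import torch
from torch import autograd

torch.manual_seed(0)

# Toy equilibrium z* = tanh(W z* + x); small weights make func a contraction.
W = torch.randn(5, 5) * 0.1
x = torch.randn(5, requires_grad=True)
func = lambda z: torch.tanh(z @ W.T + x)

def fixed_point(g, z0, iters=100):
    z = z0
    for _ in range(iters):
        z = g(z)
    return z

with torch.no_grad():
    z1 = fixed_point(func, torch.zeros(5))

z1 = z1.requires_grad_()
new_z1 = func(z1)                            # re-engage autograd
z1_cp = z1.clone().detach().requires_grad_()
new_z1_cp = func(z1_cp)                      # detached copy, used only in the hook

def backward_hook(grad):
    # The inner solve runs on new_z1_cp's graph, so this hook never re-fires.
    return fixed_point(
        lambda y: autograd.grad(new_z1_cp, z1_cp, y, retain_graph=True)[0] + grad,
        torch.zeros_like(grad))

new_z1.register_hook(backward_hook)          # on new_z1, not new_z1_cp
new_z1.sum().backward()

# Check against differentiating through the unrolled solver directly.
x_ref = x.detach().requires_grad_()
z_ref = fixed_point(lambda z: torch.tanh(z @ W.T + x_ref), torch.zeros(5))
z_ref.sum().backward()
print(torch.allclose(x.grad, x_ref.grad, atol=1e-4))  # expect True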

(@HieuPhan33, cc you in case you run into this issue in the future when using PyTorch >1.7.0 😄 ).

tesfaldet commented on July 21, 2024

I tried this with PyTorch 1.6, 1.7, and 1.9 and would get an out-of-memory error during the backward pass:

def forward(self, x):
    # Setup
    _, c, h, w = x.shape
    x0 = tensor2vec(x)
    func = lambda y: tensor2vec(self.f(vec2tensor(y, (c, h, w))))

    # Forward pass: find the fixed point without tracking gradients
    with torch.no_grad():
        x_star = self.solver(func, x0, threshold=30)['result']

    if self.training:
        # Re-engage autograd tape
        x_star_new = func(x_star.requires_grad_())

        # Set up Jacobian-vector product for the backward pass
        def backward_hook(grad):
            # Compute the fixed point of y J + grad, where J = J_f is the Jacobian of f at x_star
            grad_func = lambda y: autograd.grad(x_star_new, x_star, y, retain_graph=True)[0] + grad
            new_grad = self.solver(grad_func, torch.zeros_like(grad),
                                   threshold=40)['result']
            print('grad, new_grad', grad, new_grad)
            return new_grad

        if self.hook is not None:
            self.hook.remove()
            torch.cuda.synchronize()

        self.hook = x_star_new.register_hook(backward_hook)
        return vec2tensor(x_star_new, (c, h, w))
    return vec2tensor(x_star, (c, h, w))    # eval mode: x_star_new is never defined

I'll try what you recommended.

jerrybai1995 commented on July 21, 2024

@tesfaldet Oh, what you posted here will definitely hit an out-of-memory (OOM) error. Whenever the backward pass goes through x_star_new, PyTorch runs the backward_hook function (that is what a hook does), which calls the backward pass on x_star_new again (via the autograd.grad call), which in turn runs backward_hook again, and so on. The result is infinite recursion. This is the fundamental reason why I originally put self.hook.remove() and torch.cuda.synchronize() inside the backward_hook function.
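
To make the call chain concrete:

loss.backward()
  -> gradient reaches x_star_new
    -> backward_hook(grad) fires
      -> autograd.grad(x_star_new, x_star, y)   # a backward pass through x_star_new
        -> backward_hook fires again
          -> autograd.grad(x_star_new, ...)     # and so on, until OOM

The change I suggested above breaks this chain because the inner autograd.grad runs on the copied graph (new_z1_cp), which has no hook registered on it.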

tesfaldet commented on July 21, 2024

Ohhhhhhh, I see. So the autograd.grad call triggers the backward_hook function again. I completely didn't realize! The reason I moved the self.hook.remove() block out of the backward_hook function in the first place was that it was causing a segmentation fault otherwise :( This was on all the PyTorch versions I listed above. Your suggested change seems to forgo removing hooks, but I worry that it would slowly eat up memory over time.

jerrybai1995 commented on July 21, 2024

I don't think it would lead to a memory leak: Python's garbage collector will clean up the old hook once its reference count drops to zero as a new training iteration comes in. The major drawback of the suggested change is that we have to pay some additional speed and memory cost to compute z1_cp.
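
A quick toy check of that claim (illustrative only, not repo code):

import weakref
import torch

z = torch.randn(3, requires_grad=True).tanh()   # a non-leaf tensor, like new_z1
z.register_hook(lambda g: g)                    # registered, never removed
ref = weakref.ref(z)
del z                                           # the next iteration drops the old tensor
print(ref() is None)                            # True: the tensor and its hook are freed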

tesfaldet commented on July 21, 2024

Gotcha. Your suggested change seems to work! Though I think I'm using it incorrectly in my own model (the ResNet example here works fine, with some slight modifications).
