Comments (10)

SeanNaren commented on May 4, 2024

A lot has changed since this issue, and I'd like to summarize:

There are two ways to consider scaling architectures

  1. Split layers onto devices manually
  2. Split all layers equally onto devices

Option 1 is extremely difficult to get right and to keep efficient when architectures are large and complicated. Option 2, which has become more prominent in recent years via DeepSpeed and now FairScale, offers an elegant way to scale model architectures with minimal annotation.

Fully Sharded Data Parallel has been merged and offers the ability to leverage option 2 and, in most cases, solve the underlying scaling issue. I have a PR for FSDP documentation (#7791) which will hopefully explain more about how this works :) Once merged, we should be able to close this!

EDIT: code example:

import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from fairscale.nn import wrap

class MyModule(LightningModule):
    def configure_sharded_model(self):
        # Layers wrapped here will be sharded across all devices.
        # SomeModel stands in for any nn.Module.
        model_a = wrap(SomeModel())
        layer_1 = wrap(nn.Linear(...))
        layer_2 = wrap(nn.Linear(...))
        self.model = nn.Sequential(model_a, layer_1, layer_2)

    def forward(self, x):
        return self.model(x)

model = MyModule()
trainer = Trainer(gpus=4, plugins='fsdp')
trainer.fit(model)

sholalkere commented on May 4, 2024

Could use something similar to this to approximate memory usage per layer/module and then balance accordingly.
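
A minimal sketch of that idea, assuming parameter size is the balancing signal (activations, gradients, and optimizer state are ignored here):

import torch.nn as nn

def param_bytes(module: nn.Module) -> int:
    # Count parameter memory only.
    return sum(p.numel() * p.element_size() for p in module.parameters())

# Toy model: report each child's parameter footprint.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
for name, child in model.named_children():
    print(name, type(child).__name__, param_bytes(child), "bytes")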

williamFalcon commented on May 4, 2024

That's helpful. You also need to account for the size of the input and output, taking batch size into account. Sometimes the problem is that the layer output blows up the RAM, so we'd probably need to try/catch a few passes through each block and calculate its full memory requirement.

The memory requirement is weights + input + output, and GPU 0 has the added overhead of the optimizer, which in the case of Adam holds the grads.
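
A rough sketch of that measurement, with a hypothetical block_forward_mem helper (assumes a CUDA device and measures the forward pass only; backward and optimizer state would add to this):

import torch
import torch.nn as nn

def block_forward_mem(block: nn.Module, sample: torch.Tensor) -> int:
    # Peak memory for one forward pass: roughly weights + input + output.
    block = block.cuda()
    sample = sample.cuda()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = block(sample)
    return torch.cuda.max_memory_allocated()

blocks = [nn.Linear(1024, 4096), nn.Linear(4096, 1024)]
x = torch.randn(32, 1024)  # batch size matters, as noted above
for block in blocks:
    print(type(block).__name__, block_forward_mem(block, x), "bytes")
    x = block(x.cuda()).cpu()  # this block's output feeds the next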

stale commented on May 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

tchaton commented on May 4, 2024

@SeanNaren, FairScale should partially provide this feature with OSS, right?

SeanNaren commented on May 4, 2024

@tchaton, not exactly! This is covered by #4443, which introduces the pipe accelerator (allowing you to split a model across GPUs). The self-balancing part isn't easy, but can be done via functions like this in fairscale:
https://github.com/facebookresearch/fairscale/blob/7c5203eb772d7c67e45ed6ff6b66579b8e5cbc6c/fairscale/nn/pipe/balance/__init__.py#L100
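
For example, a hedged sketch using the balance utilities at that link (ported from torchgpipe; the exact signature may vary between fairscale versions):

import torch
import torch.nn as nn
from fairscale.nn.pipe.balance import balance_by_time

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
sample = torch.randn(32, 1024)

# Time a few forward passes and split the layers into 2 partitions
# with roughly equal estimated runtimes.
balance = balance_by_time(2, model, sample)
print(balance)  # e.g. [2, 1]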

I've been looking into the pipe accelerator but there are a few nice changes coming up with this PR: facebookresearch/fairscale#156

Would be nice to get them in first before adding the plugin/accelerator for this :)

LithiumH commented on May 4, 2024

Has there been any progress on this feature? I see that there's a Beta section in the documentation here: https://pytorch-lightning.readthedocs.io/en/latest/multi_gpu.html#model-parallelism-beta but I don't know if this works with DDP.

tchaton commented on May 4, 2024

Any updates on this issue?

SeanNaren commented on May 4, 2024

Hey @tchaton, a small update :)

It's been a while, and transparent, frictionless self-balancing of architectures still hasn't been solved, primarily due to the difficulty of engineering such balancing.

In most cases this requires a lot of engineering effort, and even our pipe implementation is very specific and provides little flexibility in use.

The current roadmap leans towards Fully Sharded Data Parallel replacing the need for self-balancing, by allowing the user to annotate layers (or automating the annotation) with FSDP, signalling that these layers should be loaded into memory, do any necessary computation, and be de-allocated as soon as possible. This allows the model size to scale drastically at a trade-off in time. If anyone is interested, look at our initial integration, which we're working on with the FairScale team to prove out and to ensure we have rigid tests/benchmarks in place: #6152
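
Roughly, the annotation pattern looks like this (a sketch against fairscale's wrapping utilities; assumes a distributed process group is already initialized):

import torch.nn as nn
from fairscale.nn import auto_wrap, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Inside the context, submodules are annotated with FSDP: their parameters
# are gathered just-in-time for compute and freed again as soon as possible.
with enable_wrap(wrapper_cls=FSDP):
    sharded = auto_wrap(model)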

carmocca commented on May 4, 2024

Closing this super old issue. strategy="fsdp" is your friend.

You can find guides at https://lightning.ai/docs/pytorch/latest/advanced/model_parallel.html for the Trainer and https://lightning.ai/docs/fabric/latest/advanced/model_parallel/fsdp.html for Fabric
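
A minimal sketch with the current API (assuming 4 GPUs and any LightningModule named model):

import lightning as L

trainer = L.Trainer(accelerator="gpu", devices=4, strategy="fsdp")
trainer.fit(model)  # model is any LightningModule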
