
dask-pytorch-ddp is a Python package that makes it easy to train PyTorch models on dask clusters using distributed data parallel.

Home Page: https://saturncloud.io/docs/examples/python/pytorch/qs-03-pytorch-gpu-dask-single-model/

License: BSD 3-Clause "New" or "Revised" License

Languages: Python 96.36%, Makefile 3.64%
Topics: pytorch, machine-learning, deep-learning, distributed-computing, dask, nlp, computer-vision

dask-pytorch-ddp's Introduction

dask-pytorch-ddp

dask-pytorch-ddp is a Python package that makes it easy to train PyTorch models on Dask clusters using distributed data parallel. The intended scope of the project is:

  • bootstrapping PyTorch workers on top of a Dask cluster
  • using distributed data stores (e.g., S3) as normal PyTorch datasets
  • mechanisms for tracking and logging intermediate results, training statistics, and checkpoints

At this point, the library and the examples provided are tailored to computer vision tasks, but it is intended to be useful for any sort of PyTorch task. The only part that is really specific to image processing is the S3ImageFolder dataset class. Implementing a PyTorch dataset for other data types (assuming map-style random access) currently requires implementing __getitem__(self, idx: int) and __len__(self), as sketched below. We plan to add more varied examples for other use cases in the future, and welcome PRs extending functionality.
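
For reference, a minimal map-style dataset only needs those two methods. The class below is a purely illustrative sketch (the name ListDataset and the in-memory data layout are invented for the example) and is not part of the dask-pytorch-ddp API:

from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Hypothetical map-style dataset over (features, label) pairs held in memory."""

    def __init__(self, samples):
        # samples: a list of (feature_tensor, label) tuples
        self.samples = samples

    def __len__(self):
        # total number of samples; used by DataLoader and samplers
        return len(self.samples)

    def __getitem__(self, idx: int):
        # map-style random access by integer index
        return self.samples[idx]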

Typical non-dask workflow

A typical example of non-dask PyTorch usage is as follows:

Loading Data

Create a dataset (ImageFolder) and wrap it in a DataLoader:

# assumes torchvision and numpy are installed, and `path` points at a local
# directory of images arranged in one subfolder per class
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, SubsetRandomSampler

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(250),
    transforms.ToTensor()
])

whole_dataset = ImageFolder(path, transform=transform)

batch_size = 100
num_workers = 64

# shuffle the indices and carve out disjoint train/test subsets of `num` samples each
indices = list(range(len(whole_dataset)))
np.random.shuffle(indices)
num = len(indices) // 2
train_idx = indices[:num]
test_idx = indices[num:num + num]

train_sampler = SubsetRandomSampler(train_idx)
train_loader = DataLoader(whole_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers)

Training a Model

Loop over the dataset and train the model by stepping the optimizer:

# assumes a CUDA-capable machine and the train_loader defined above
import torch
from torch import nn, optim
from torchvision import models

device = torch.device(0)
net = models.resnet18(pretrained=False)
model = net.to(device)

criterion = nn.CrossEntropyLoss().cuda()
lr = 0.001
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

n_epochs = 10  # number of passes over the training data (example value)
count = 0
for epoch in range(n_epochs):
    model.train()  # Set model to training mode
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # zero the parameter gradients, backpropagate, and step the optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        count += 1

Now on Dask

With dask_pytorch_ddp and PyTorch Distributed Data Parallel, we can train on multiple workers as follows:

Loading Data

Load the dataset from S3 and explicitly set the multiprocessing context (Dask defaults to spawn, but PyTorch is generally configured to use fork):

import multiprocessing as mp
import torch
from dask_pytorch_ddp.data import S3ImageFolder

# transform, train_sampler, batch_size, and num_workers are the same as above
whole_dataset = S3ImageFolder(bucket, prefix, transform=transform)
train_loader = torch.utils.data.DataLoader(
    whole_dataset, sampler=train_sampler, batch_size=batch_size,
    num_workers=num_workers, multiprocessing_context=mp.get_context('fork')
)

Training in Parallel

Wrap the training loop in a function, and add metrics logging (not strictly necessary, but very useful). Convert the model into a PyTorch Distributed Data Parallel (DDP) model, which knows how to synchronize gradients across workers.

import datetime
import json
import logging
import multiprocessing as mp
import pickle
import uuid

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import models

from dask_pytorch_ddp.data import S3ImageFolder
from dask_pytorch_ddp.results import DaskResultsHandler

key = uuid.uuid4().hex
rh = DaskResultsHandler(key)

def run_transfer_learning(bucket, prefix, samplesize, n_epochs, batch_size, num_workers, train_sampler):
    worker_rank = int(dist.get_rank())
    device = torch.device(0)
    net = models.resnet18(pretrained=False)
    model = net.to(device)
    model = DDP(model, device_ids=[0])

    criterion = nn.CrossEntropyLoss().cuda()
    lr = 0.001
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    whole_dataset = S3ImageFolder(bucket, prefix, transform=transform)
    
    train_loader = torch.utils.data.DataLoader(
        whole_dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=mp.get_context('fork')
    )
    
    count = 0
    for epoch in range(n_epochs):
        # Each epoch has a training and validation phase
        model.train()  # Set model to training mode
        for inputs, labels in train_loader:
            dt = datetime.datetime.now().isoformat()
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # zero the parameter gradients
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1

            # statistics
            rh.submit_result(
                f"worker/{worker_rank}/data-{dt}.json",
                json.dumps({'loss': loss.item(), 'epoch': epoch, 'count': count, 'worker': worker_rank})
            )
            if (count % 100) == 0 and worker_rank == 0:
                rh.submit_result(f"checkpoint-{dt}.pkl", pickle.dumps(model.state_dict()))

How does it work?

dask-pytorch-ddp is largely a wrapper around existing PyTorch functionality. torch.distributed provides the infrastructure for Distributed Data Parallel (DDP).

In DDP, you create N workers; the 0th worker is the "master" and coordinates the synchronization of buffers and gradients. In SGD, gradients are normally averaged over all data points in a batch. By running batches on multiple workers and averaging the gradients, DDP enables you to run SGD with a much bigger effective batch size (N * batch_size); for example, 4 workers each running batches of 100 give an effective batch size of 400.

dask-pytorch-ddp sets some environment variables to configure the "master" host and port, calls init_process_group before training, and calls destroy_process_group after training. This is the same process the data scientist would normally do manually.
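
Roughly, the per-worker setup and teardown look like the sketch below. The environment variable names and torch.distributed calls are standard; the wrapper function itself is illustrative and not the library's actual code:

import os
import torch.distributed as dist

def run_with_ddp(rank, world_size, master_addr, master_port, train_fn):
    # point every worker at the rank-0 "master" process
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    try:
        train_fn()  # the user-supplied training function, e.g. run_transfer_learning
    finally:
        dist.destroy_process_group()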

Multi-GPU machines

dask_cuda_worker automatically rotates CUDA_VISIBLE_DEVICES for each worker it creates (typically one per GPU). As a result, your PyTorch code should always start with the 0th GPU.

For example, if I have an 8 GPU machine, the 3rd worker will have CUDA_VISIBLE_DEVICES set to 2,3,4,5,6,7,0,1. On that worker, if I call torch.device(0), I will get GPU 2.
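
As a brief illustration (the model here is arbitrary):

import torch
from torchvision import models

# On a worker whose CUDA_VISIBLE_DEVICES begins with "2", cuda:0 maps to
# physical GPU 2, so indexing from 0 always selects that worker's own GPU.
device = torch.device(0)
model = models.resnet18(pretrained=False).to(device)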

What else?

dask-pytorch-ddp also implements an S3-based ImageFolder, and more distributed-friendly datasets are planned. It also implements a basic results aggregation framework so that it is easy to collect training metrics across different workers. Currently only DaskResultsHandler, which leverages Dask pub-sub communication protocols, is implemented, but an S3-based results handler is planned.

Some Notes

Dask generally spawns processes; PyTorch generally forks. When using a multiprocessing-enabled data loader, it is a good idea to pass the fork multiprocessing context to force the use of forking in the data loader.

Some Dask deployments do not permit spawning processes. To override this, you can change the distributed.worker.daemon setting.

Environment variables are a convenient way to do this:

DASK_DISTRIBUTED__WORKER__DAEMON=False
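
The same setting can also be applied in code with dask.config before the cluster is created:

import dask

# equivalent to setting DASK_DISTRIBUTED__WORKER__DAEMON=False
dask.config.set({"distributed.worker.daemon": False})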

dask-pytorch-ddp's People

Contributors

hhuuggoo, jameslamb, skirmer

dask-pytorch-ddp's Issues

`dispatch.run` is not resilient to worker loss

dispatch.run uses worker restrictions to pin tasks to the workers they should be executed on. Should a worker be removed (or possibly restarted), the task will transition to the no-worker state and remain there indefinitely (see dask/distributed#7346). From what I can see, there is no mechanism implemented to prevent this.

To circumvent this, dask-pytorch-ddp would probably also benefit from dask/distributed#8624.

Missing file VERSION for source installation from PyPI

When doing an install from source from PyPI, the installation fails because the VERSION file is missing:

❯ pip install dask-pytorch --no-binary :all:
Collecting dask-pytorch
  Downloading dask-pytorch-0.1.0.tar.gz (8.1 kB)
    ERROR: Command errored out with exit status 1:
     command: /Users/jqi/miniconda3/bin/python3.8 -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/private/var/folders/nq/vp3dgt812jgb0q09rh5l706c0000gn/T/pip-install-lwhzmu1h/dask-pytorch/setup.py'"'"'; __file__='"'"'/private/var/folders/nq/vp3dgt812jgb0q09rh5l706c0000gn/T/pip-install-lwhzmu1h/dask-pytorch/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base /private/var/folders/nq/vp3dgt812jgb0q09rh5l706c0000gn/T/pip-pip-egg-info-_i2ttudx
         cwd: /private/var/folders/nq/vp3dgt812jgb0q09rh5l706c0000gn/T/pip-install-lwhzmu1h/dask-pytorch/
    Complete output (5 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/private/var/folders/nq/vp3dgt812jgb0q09rh5l706c0000gn/T/pip-install-lwhzmu1h/dask-pytorch/setup.py", line 7, in <module>
        with open("VERSION", "r") as f:
    FileNotFoundError: [Errno 2] No such file or directory: 'VERSION'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

This won't be a problem for most users because pip will install from the wheel distribution, but in some cases the wheel won't work and pip will fall back to unpacking the tar archive and installing from source. You can download the tar archive manually and see that indeed the VERSION file is missing.

VERSION is a nonstandard file for setup.py builds, so you'll need to create a MANIFEST.in file that explicitly includes it. https://packaging.python.org/guides/using-manifest-in/
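
A minimal MANIFEST.in that fixes this contains a single directive (MANIFEST.in syntax, as described in the linked guide):

include VERSION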

Assertion Error

Hi,

When I run this code (https://saturncloud.io/docs/examples/python/pytorch/qs-03-pytorch-gpu-dask-single-model/), I get this error:

daskcluster-worker-1     | 2022-11-13 17:01:17,386 - distributed.worker - WARNING - Compute Failed
daskcluster-worker-1     | Key:       dispatch_with_ddp-cbbbf432f092a3807b25cc40c48f7660
daskcluster-worker-1     | Function:  dispatch_with_ddp
daskcluster-worker-1     | args:      ()
daskcluster-worker-1     | kwargs:    {'pytorch_function': <function train at 0x7f06b9bba040>, 'master_addr': '172.23.0.4', 'master_port': 12345, 'rank': 1, 'world_size': 2, 'backend': 'nccl'}
daskcluster-worker-1     | Exception: 'AssertionError()'
daskcluster-worker-1     | 
daskcluster-worker-2     | 2022-11-13 17:01:17,387 - distributed.worker - WARNING - Compute Failed
daskcluster-worker-2     | Key:       dispatch_with_ddp-9ce4ce0b9f5f85ff8ead8f6f2e9a9bcf
daskcluster-worker-2     | Function:  dispatch_with_ddp
daskcluster-worker-2     | args:      ()
daskcluster-worker-2     | kwargs:    {'pytorch_function': <function train at 0x7fa21dd38940>, 'master_addr': '172.23.0.4', 'master_port': 12345, 'rank': 0, 'world_size': 2, 'backend': 'nccl'}
daskcluster-worker-2     | Exception: 'AssertionError()'

Why did I get this error? Can you help me?

Thank you.
