
ram-multiprocess-dataloader's Introduction

Demystify RAM Usage in Multi-Process Data Loaders

A typical PyTorch training program on 8 GPUs with 4 dataloader workers per GPU creates at least 8 * (4 + 1) = 40 processes. A naive use of the torch Dataset and DataLoader can easily multiply your dataset's RAM usage by 40. This issue has probably affected everyone who has done anything nontrivial with PyTorch.

This blog post explains why it happens, and how to avoid the 40x RAM usage.
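Here is a minimal, Linux-only sketch of the mechanism on a small scale. It is not the repo's code, and the helper names and toy dataset are invented for illustration. A forked child shares its parent's memory pages copy-on-write. However, even a read-only loop over a list of Python objects increments and decrements each object's reference count, and each of those updates is a write. The pages holding the objects are therefore copied into the child. The sketch measures the child's USS (Private_Clean + Private_Dirty, read from /proc/<pid>/smaps_rollup, falling back to /proc/<pid>/smaps) before and after such a loop.

```python
import multiprocessing as mp
import os


def uss_bytes(pid="self"):
    """USS = Private_Clean + Private_Dirty, summed from /proc (Linux only)."""
    path = f"/proc/{pid}/smaps_rollup"
    if not os.path.exists(path):  # kernels older than 4.14: sum over every mapping
        path = f"/proc/{pid}/smaps"
    total_kib = 0
    with open(path) as f:
        for line in f:
            if line.startswith(("Private_Clean:", "Private_Dirty:")):
                total_kib += int(line.split()[1])
    return total_kib * 1024


def read_everything(dataset, queue):
    before = uss_bytes()
    for _ in dataset:  # read-only loop, yet it updates every item's refcount
        pass
    queue.put(uss_bytes() - before)


def uss_growth_after_reading(num_items=200_000):
    """Fork one child, let it iterate over a list of dicts, return its USS growth in bytes."""
    # Toy stand-in for a dataset of per-image metadata dicts.
    dataset = [{"id": i, "file_name": f"{i:012d}.jpg"} for i in range(num_items)]
    ctx = mp.get_context("fork")  # the child inherits `dataset` without pickling
    queue = ctx.Queue()
    child = ctx.Process(target=read_everything, args=(dataset, queue))
    child.start()
    growth = queue.get()
    child.join()
    return growth


if __name__ == "__main__":
    print(f"child USS grew by {uss_growth_after_reading() / 2**20:.1f} MiB")
```

Every DataLoader worker that reads the whole dataset would pay this cost separately. That is how the per-process copies add up.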

This GitHub repo contains the code and results for the article above.
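The issues below refer to classes in the repo's serialize.py, such as NumpySerializedList, TorchSerializedList and TorchShmSerializedList. What follows is a rough, stdlib-only sketch of the idea those names suggest, not a copy of that file. It pickles every element into one contiguous bytes buffer and records element boundaries in a single array('q') of offsets, standing in for the NumPy/torch arrays the repo presumably uses. The dataset then lives in two large objects instead of many small ones. Reading an element therefore writes only to those two objects' refcounts, not to one object per element spread across shared pages.

```python
import pickle
from array import array


class SerializedList:
    """Many Python objects stored as one bytes buffer plus an int64 offsets array."""

    def __init__(self, items):
        blobs = [pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) for x in items]
        self._offsets = array("q", [0])  # offsets[i]..offsets[i+1] delimits element i
        for blob in blobs:
            self._offsets.append(self._offsets[-1] + len(blob))
        self._buffer = b"".join(blobs)

    def __len__(self):
        return len(self._offsets) - 1

    def __getitem__(self, idx):
        """Unpickle element `idx` (an int; negative values count from the end)."""
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        start, end = self._offsets[idx], self._offsets[idx + 1]
        # Only this element's bytes are copied and turned back into Python objects.
        return pickle.loads(self._buffer[start:end])
```

A map-style Dataset whose `__getitem__` returns `self.lst[idx]` can wrap this object. Fork-based DataLoader workers then keep sharing the buffer instead of slowly copying it.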

Dependencies

  • Python >= 3.7
  • Linux
  • PyTorch >= 1.10
  • python -m pip install psutil tabulate msgpack
  • Detectron2 and GPUs needed only for main-multigpu*.py: installation instructions

ram-multiprocess-dataloader's People

Contributors

ppwwyyxx


ram-multiprocess-dataloader's Issues

How to avoid issues with dictionaries?

Hello,

I found the repo and blog post very interesting and useful, especially the tensor serialization utility.

I've encountered similar problems with RAM usage, but in my case I have to use dictionaries to store data, or even dictionaries of lists.

Can you confirm that the issues you described can also happen with dictionaries? I also see many "too many open files" errors because of this.

Secondly, how can I serialize dictionaries into a torch Tensor instead of using lists?

I could work with lists too, but then it would be very complicated to retrieve the correct tensor in the main process (I find dictionaries more flexible in this respect).
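The thread does not give an answer, so here is just one possible sketch. A large dict has the same problem as a list, because reading a value increments that value's reference count. Assume the mapping is read-only and its keys are mutually comparable (for example, all str or all int). It can then be stored as two packed lists, sorted keys and matching values, and looked up by binary search over the packed keys with `bisect`. `_Packed` and `SerializedDict` are hypothetical names. The packing reuses the pickle-plus-offsets idea, with `array('q')` standing in for a tensor.

```python
import bisect
import pickle
from array import array
from collections.abc import Mapping


class _Packed:
    """Hypothetical helper: objects pickled into one buffer; non-negative int indices only."""

    def __init__(self, items):
        blobs = [pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) for x in items]
        self._offsets = array("q", [0])
        for blob in blobs:
            self._offsets.append(self._offsets[-1] + len(blob))
        self._buffer = b"".join(blobs)

    def __len__(self):
        return len(self._offsets) - 1

    def __getitem__(self, i):
        return pickle.loads(self._buffer[self._offsets[i]:self._offsets[i + 1]])


class SerializedDict(Mapping):
    """Read-only mapping stored as packed sorted keys plus packed values.

    Each lookup unpickles O(log n) keys while binary-searching, trading some
    CPU time for not keeping millions of small Python objects alive.
    """

    def __init__(self, mapping):
        keys = sorted(mapping)  # keys must be mutually comparable
        self._keys = _Packed(keys)
        self._values = _Packed([mapping[k] for k in keys])

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, key):
        i = bisect.bisect_left(self._keys, key)  # works on any sequence with __len__/__getitem__
        if i < len(self._keys) and self._keys[i] == key:
            return self._values[i]
        raise KeyError(key)

    def __iter__(self):
        for i in range(len(self._keys)):
            yield self._keys[i]
```

Subclassing `Mapping` provides `in`, `get`, `keys`, `items` and `values` on top of the three methods above. A dict of lists works unchanged, since each list is pickled as one value.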

Using torchrun launcher instead of detectron's raises AuthenticationError('digest received was wrong')

Hi Yuxin,
Thank you for your very inspiring blog. I can reproduce your results with the original code. I am trying to adapt the multi-GPU solution to PyTorch 2.0's own launcher (https://github.com/LeungTsang/RAM-multiprocess-dataloader). I kept almost everything the same, except that I read rank, local_rank and world_size from the environment variables provided by torchrun and initialize the process group with dist.init_process_group(backend='nccl'). However, when running torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 main-multigpu-sharedmem.py, I got this error:

Traceback (most recent call last):
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/resource_sharer.py", line 138, in _serve
    with self._listener.accept() as conn:
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/connection.py", line 465, in accept
    deliver_challenge(c, self._authkey)
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/connection.py", line 745, in deliver_challenge
    raise AuthenticationError('digest received was wrong')
multiprocessing.context.AuthenticationError: digest received was wrong

Traceback (most recent call last):
  File "/home/lzeng/testspace/RAM-multiprocess-dataloader/main-multigpu-sharedmem.py", line 59, in <module>
    main()
  File "/home/lzeng/testspace/RAM-multiprocess-dataloader/main-multigpu-sharedmem.py", line 26, in main
    ds = DatasetFromList(TorchShmSerializedList(
  File "/home/lzeng/testspace/RAM-multiprocess-dataloader/serialize.py", line 99, in __init__
    self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle)
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
    fd = df.detach()
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/connection.py", line 508, in Client
    answer_challenge(c, authkey)
  File "/home/lzeng/.conda/envs/torch2/lib/python3.10/multiprocessing/connection.py", line 759, in answer_challenge
    raise AuthenticationError('digest sent was rejected')
multiprocessing.context.AuthenticationError: digest sent was rejected

Strangely, the problem seems specific to the combination of the torchrun launcher, ForkingPickler and PyTorch tensors: broadcasting any dummy tensor among the workers triggers the error. Using TorchShmSerializedList(NumpySerializedList) instead of TorchShmSerializedList(TorchSerializedList) may avoid it, but unfortunately costs more memory. I cannot find anything related online, and I would highly appreciate it if you (or anybody) had more insight into the issue.
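Both ends of the failing handshake in the traceback use `multiprocessing.current_process().authkey`. The listener started by resource_sharer checks it in `deliver_challenge`, and the receiving rank presents its own key in `Client(..., authkey=...)`. Processes created through multiprocessing inherit their parent's key; detectron2's `launch` creates them this way, via `mp.start_processes`. torchrun appears to start each rank as an independent Python process with its own random key, which could explain why only the torchrun variant fails.

Below is a minimal sketch of one possible workaround, assuming a single default process group: copy rank 0's key to every rank before any tensor is shared. `sync_authkey` and `_broadcast_from_rank0` are hypothetical helpers. The `broadcast` parameter exists only so the logic can be exercised without torch. With the NCCL backend, call `torch.cuda.set_device(local_rank)` before `broadcast_object_list`.

Another option that may be worth testing is `torch.multiprocessing.set_sharing_strategy("file_system")`. It shares CPU tensors by file name instead of passing file descriptors through resource_sharer.

```python
import multiprocessing as mp


def _broadcast_from_rank0(key: bytes) -> bytes:
    """Default broadcaster: send rank 0's bytes to all ranks of the default group."""
    import torch.distributed as dist  # lazy import so the sketch loads without torch

    buf = [key]
    dist.broadcast_object_list(buf, src=0)  # collective: every rank must call it
    return buf[0]


def sync_authkey(broadcast=_broadcast_from_rank0) -> bytes:
    """Give every rank the same multiprocessing authkey (rank 0's).

    Call this right after dist.init_process_group() and before anything is
    pickled with ForkingPickler: the resource_sharer listener reads the key
    once, when it starts, and worker processes started later inherit it.
    """
    # AuthenticationString refuses to be pickled, so pass plain bytes around.
    key = broadcast(bytes(mp.current_process().authkey))
    mp.current_process().authkey = key
    return key
```

In the torchrun script, this would be called on every rank after `init_process_group` and before the shared dataset is constructed.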

EOFError when using num_gpus > 2

Yuxin, thank you very much for the awesome blog and code repo! main-multigpu-sharedmem.py works as expected with 2 GPUs. However, if I set num_gpus here to 3, it crashes with the following full error stack:

❯ python main-multigpu-sharedmem.py
Serializing 860001 elements to byte tensors and concatenating them all ...
Serialized dataset takes 505.05 MiB
Traceback (most recent call last):
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/multiprocessing/resource_sharer.py", line 147, in _serve
    send, close = self._cache.pop(key)
KeyError: 1
Worker 2 obtains a dataset of length=860001 from its local leader.
  time     PID  rss    pss    uss    shared    shared_file
------  ------  -----  -----  -----  --------  -------------
 77856  616200  1.5G   1.5G   1.4G   105.0M    104.9M
  time     PID  rss     pss     uss     shared    shared_file
------  ------  ------  ------  ------  --------  -------------
 77856  616202  229.4M  150.6M  124.7M  104.7M    104.7M
Traceback (most recent call last):
  File "main-multigpu-sharedmem.py", line 56, in <module>
    launch(main, num_gpus, dist_url="auto")
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/detectron2/engine/launch.py", line 69, in launch
    mp.start_processes(
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/detectron2/engine/launch.py", line 123, in _distributed_worker
    main_func(*args)
  File "/home/xiaot/Downloads/RAM-multiprocess-dataloader/main-multigpu-sharedmem.py", line 24, in main
    ds = DatasetFromList(TorchShmSerializedList(
  File "/home/xiaot/Downloads/RAM-multiprocess-dataloader/serialize.py", line 73, in __init__
    self._addr, self._lst = mp.reduction.ForkingPickler.loads(serialized)
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 305, in rebuild_storage_fd
    fd = df.detach()
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/multiprocessing/reduction.py", line 189, in recv_handle
    return recvfds(s, 1)[0]
  File "/home/xiaot/miniconda3/envs/ram/lib/python3.8/multiprocessing/reduction.py", line 159, in recvfds
    raise EOFError
EOFError

I wonder whether others can reproduce this, and whether there are any potential solutions. Many thanks.

How to work with pin_memory?

I really appreciate the great blog and clear demo.

I implemented my own version of a COCO loader based on your TorchSerializedList, and the default RAM usage decreased significantly. However, when I set pin_memory=True in the PyTorch DataLoader, USS increases considerably while shared memory stays the same, for both the default COCODetection and TorchSerializedList:

Dataset / pin_memory         RSS     USS       Shared
default / pin_memory=True    7.31GB  3.45GB    3.87GB
default / pin_memory=False   4.31GB  455.76MB  3.87GB
shm / pin_memory=True        4.72GB  3.65GB    1.22GB
shm / pin_memory=False       1.72GB  508.69MB  1.22GB

From a quick search, shared memory doesn't seem to work well with pin_memory, according to this thread. The Detectron2 repo doesn't seem to use pin_memory in its dataloader. Could you share your thoughts on using the serialized dataset with pin_memory=True?

Thanks!

[Question] Use `ForkingPickler` to pickle/unpickle tensors in shared memory

Hi, thank you for your great blog! I am new to Python and have what is probably a very basic question about ForkingPickler. I tried to manage the sharing of a tensor explicitly, as follows:

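import multiprocessing  # also makes multiprocessing.reduction available
import torch
import torch.distributed as dist

# local_group, num_local_process and local_scatter are defined elsewhere (not shown).
if dist.get_rank(group=local_group) == 0:
    x = torch.arange(5)
    # One pickled handle for each non-leader process in the local group.
    handles = [None] + [
        bytes(multiprocessing.reduction.ForkingPickler.dumps(x))
        for _ in range(num_local_process - 1)
    ]
else:
    handles = None

handle = local_scatter(handles)
if dist.get_rank(group=local_group) > 0:
    x = multiprocessing.reduction.ForkingPickler.loads(handle)
    print(x)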

I expected this to put tensor x in shared memory; however, it crashed with the following stack trace:

Traceback (most recent call last):
  File "ogbn_products_nc_mpi_shared_mem.py", line 314, in <module>
    x =  multiprocessing.reduction.ForkingPickler.loads(handle)
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.8/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/usr/lib/python3.8/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 508, in Client
    answer_challenge(c, authkey)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 757, in answer_challenge
    response = connection.recv_bytes(256)        # reject large message
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

I rewrote the authkey so that it is synchronized within each local group of processes when running with mpirun/torchrun, but I am puzzled about how to correctly use PyTorch's customized pickler to share memory among local processes.
Could you please help explain it?

Reproduction in PyTorch Lightning

Hi @ppwwyyxx,

I have been trying to reproduce your method of multi-GPU RAM sharing within PyTorch Lightning.
I can't get it to share memory between GPU workers, though, and thought I would ask whether you have any idea why.
I uploaded my attempt here: https://github.com/luzuku/ram-share-lightning

I replaced the dataset with random numbers and am only looking at total_pss, which unfortunately doubles when I double the number of GPUs. Maybe I overlooked something?

I would really appreciate it if you could find some time. You could also include a working Lightning implementation in your repository for anyone else interested.

Reproducing the results with a `torch.utils.data.DataLoader`?

Thanks for this wonderful repo. It's a pleasure to work with and very educational.

Is there a way to obtain the process IDs from a torch.utils.data.DataLoader instance? For example, with

ds = NaiveDatasetFromList(create_coco())
loader = torch.utils.data.DataLoader(ds, num_workers=4)

Could I infer the process IDs to replicate the memory usage table you created?
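One workaround that does not depend on DataLoader internals is to list the main process's descendants from /proc after the first batch has been fetched, then read rss/pss/uss for each from /proc/<pid>/smaps_rollup. Walking all descendants rather than only direct children also covers start methods such as forkserver, where workers are not direct children of the main process. The result may include helper processes that are not DataLoader workers. `descendant_pids`, `memory_kib` and `print_memory_table` are hypothetical helpers in this sketch, which assumes Linux. psutil (already a listed dependency) offers `Process.children(recursive=True)` and `Process.memory_full_info()` for the same job, and its numbers may differ slightly from these raw /proc sums.

```python
import os


def descendant_pids(root=None):
    """All descendants of `root` (default: this process), found by scanning /proc."""
    root = os.getpid() if root is None else root
    children = {}
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/stat") as f:
                stat = f.read()
        except OSError:  # the process exited while we were scanning
            continue
        # Format is "pid (comm) state ppid ..."; comm may contain spaces or ")",
        # so split after the last ")" before reading the PPID field.
        ppid = int(stat.rsplit(")", 1)[1].split()[1])
        children.setdefault(ppid, []).append(int(entry))
    found, stack = [], [root]
    while stack:
        for kid in children.get(stack.pop(), []):
            found.append(kid)
            stack.append(kid)
    return sorted(found)


def memory_kib(pid):
    """rss / pss / uss in KiB from /proc/<pid>/smaps_rollup (or smaps on old kernels)."""
    path = f"/proc/{pid}/smaps_rollup"
    if not os.path.exists(path):
        path = f"/proc/{pid}/smaps"  # sum the same fields over every mapping
    totals = {}
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2 and parts[0].endswith(":") and parts[1].isdigit():
                name = parts[0][:-1]
                totals[name] = totals.get(name, 0) + int(parts[1])
    return {
        "rss": totals.get("Rss", 0),
        "pss": totals.get("Pss", 0),
        "uss": totals.get("Private_Clean", 0) + totals.get("Private_Dirty", 0),
    }


def print_memory_table(pids):
    """Print one row per PID and return the summed PSS in KiB."""
    print(f"{'PID':>8} {'rss_kib':>10} {'pss_kib':>10} {'uss_kib':>10}")
    total_pss = 0
    for pid in pids:
        m = memory_kib(pid)
        total_pss += m["pss"]
        print(f"{pid:>8} {m['rss']:>10} {m['pss']:>10} {m['uss']:>10}")
    print(f"total_pss: {total_pss / 1024:.1f} MiB")
    return total_pss


# Hypothetical usage with the DataLoader from the question (needs torch):
#   loader = torch.utils.data.DataLoader(ds, num_workers=4)
#   it = iter(loader); next(it)      # workers stay alive while `it` exists
#   print_memory_table([os.getpid(), *descendant_pids()])
```

Summing PSS over the main process and all its descendants corresponds to the total_pss figure discussed in the Lightning issue above.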

Cannot reproduce the main-naive result?

Hi, your blog helped me a lot, thanks.

But I can't reproduce the main-naive result.
In my case, the USS reached ~226M and then stopped growing, instead of going up to 1.8G as in your log. Can you please tell me why?

  time    PID  rss    pss     uss     shared    shared_file
------  -----  -----  ------  ------  --------  -------------
 86396  14650  2.7G   793.4M  307.7M  2.4G      17.8M
 86396  14779  2.6G   684.6M  198.5M  2.4G      17.4M
 86396  14780  2.6G   684.6M  198.4M  2.4G      17.4M
 86396  14781  2.6G   685.0M  198.9M  2.4G      17.4M
 86396  14782  2.6G   686.0M  200.2M  2.4G      17.4M
 86397  14650  2.7G   797.2M  312.5M  2.4G      17.8M
 86397  14779  2.6G   688.5M  203.3M  2.4G      17.4M
 86397  14780  2.6G   688.5M  203.2M  2.4G      17.4M
 86397  14781  2.6G   688.8M  203.7M  2.4G      17.4M
 86397  14782  2.6G   689.8M  205.0M  2.4G      17.4M
 86398  14650  2.7G   801.2M  317.5M  2.4G      17.8M
 86398  14779  2.6G   692.5M  208.3M  2.4G      17.4M
 86398  14780  2.6G   692.4M  208.2M  2.4G      17.4M
 86398  14781  2.6G   692.8M  208.7M  2.4G      17.4M
 86398  14782  2.6G   693.8M  209.9M  2.4G      17.4M
 86399  14650  2.7G   805.1M  322.4M  2.4G      17.8M
 86399  14779  2.6G   696.3M  213.1M  2.4G      17.4M
 86399  14780  2.6G   696.4M  213.1M  2.4G      17.4M
 86399  14781  2.6G   696.7M  213.6M  2.4G      17.4M
 86399  14782  2.6G   697.7M  214.8M  2.4G      17.4M
 86400  14650  2.7G   809.0M  327.3M  2.4G      17.8M
 86400  14779  2.6G   700.3M  218.1M  2.4G      17.4M
 86400  14780  2.6G   700.3M  218.0M  2.4G      17.4M
 86400  14781  2.6G   700.7M  218.6M  2.4G      17.4M
 86400  14782  2.6G   701.6M  219.7M  2.4G      17.4M
 86402  14650  2.7G   813.0M  332.3M  2.3G      17.8M
 86402  14779  2.6G   704.2M  223.0M  2.4G      17.4M
 86402  14780  2.6G   704.3M  222.9M  2.4G      17.4M
 86402  14781  2.6G   704.6M  223.5M  2.4G      17.4M
 86402  14782  2.6G   705.5M  224.7M  2.3G      17.4M
 86403  14650  2.7G   815.7M  335.9M  2.3G      17.8M
 86403  14779  2.6G   707.0M  226.6M  2.3G      17.4M
 86403  14780  2.6G   706.9M  226.5M  2.3G      17.4M
 86403  14781  2.6G   706.9M  226.5M  2.3G      17.4M
 86403  14782  2.6G   706.9M  226.3M  2.3G      17.4M
 86404  14650  2.7G   815.7M  335.9M  2.3G      17.8M
 86404  14779  2.6G   707.0M  226.6M  2.3G      17.4M
 86404  14780  2.6G   706.9M  226.5M  2.3G      17.4M
 86404  14781  2.6G   706.9M  226.5M  2.3G      17.4M
 86404  14782  2.6G   706.9M  226.3M  2.3G      17.4M
 86405  14650  2.7G   815.7M  335.9M  2.3G      17.8M
 86405  14779  2.6G   707.0M  226.6M  2.3G      17.4M
 86405  14780  2.6G   706.9M  226.5M  2.3G      17.4M
 86405  14781  2.6G   706.9M  226.5M  2.3G      17.4M
 86405  14782  2.6G   706.9M  226.3M  2.3G      17.4M
 86406  14650  2.7G   815.7M  335.9M  2.3G      17.8M
 86406  14779  2.6G   707.0M  226.6M  2.3G      17.4M
 86406  14780  2.6G   706.9M  226.5M  2.3G      17.4M
 86406  14781  2.6G   706.9M  226.5M  2.3G      17.4M
 86406  14782  2.6G   706.9M  226.3M  2.3G      17.4M
