uniem issue (closed)

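wangyuxinwhy commented on July 29, 2024
Great results! A question: while reproducing this, I found that an 80 GB A100 cannot fit a batch_size of 80. How did you get to 80? Also, the train_m3e.py script uses PairInBatch and does not use add_swap_loss, and the data produced by process_zh_datasets.py is also in Pair format.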

from uniem.

Comments (11)

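hjq133 commented on July 29, 2024

Correction to what I said earlier: I used 32 V100s with 32 GB each, not 40 GB A100s, and the batch size went up to 34. I used DeepSpeed ZeRO-2/3 optimization plus fp16 mixed precision. So it looks like an 80 GB A100 should be able to reach a batch size of 80. Here are my DeepSpeed configs for reference: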
deepspeed ZERO2
{
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  },
  "train_micro_batch_size_per_gpu": "auto"
}

deepspeed ZERO3
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param": { "device": "cpu", "pin_memory": true },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto"
  },
  "train_micro_batch_size_per_gpu": "auto"
}
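The ZeRO-2 config above can be filled in with concrete numbers when it is not driven by the Hugging Face Trainer or accelerate. As far as I know, those integrations are what resolve the "auto" placeholders. The sketch below does this. DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The helper name make_zero2_config and the added fp16 block are assumptions for illustration, not part of the config above.

```python
import json


def make_zero2_config(micro_batch_per_gpu: int, world_size: int,
                      grad_accum_steps: int = 1, fp16: bool = True) -> dict:
    """Hypothetical helper: the ZeRO-2 config above with concrete batch sizes.

    DeepSpeed checks that
    train_batch_size == micro_batch_per_gpu * grad_accum_steps * world_size.
    """
    return {
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        # Assumed addition: fp16 mixed precision, which the comment mentions
        # but the config above does not set explicitly.
        "fp16": {"enabled": fp16},
        "zero_optimization": {
            "stage": 2,
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True,
        },
    }


if __name__ == "__main__":
    # 32 GPUs x 34 samples per GPU, no gradient accumulation.
    config = make_zero2_config(micro_batch_per_gpu=34, world_size=32)
    print(config["train_batch_size"])  # 1088
    print(json.dumps(config, indent=2))
```

The global batch here (32 × 34 = 1088) is not the number of in-batch negatives. Unless the loss gathers embeddings across GPUs, each query only competes against its own per-GPU micro-batch.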

hjq133 commented on July 29, 2024

You probably need ZeRO-2/3 optimization + mixed precision + gradient checkpointing to scale that far. On my 40 GB A100s, fp16 + ZeRO-2 gets to 34, though that still seems a bit short.

ARSblithe212 commented on July 29, 2024

If you turn on gradient accumulation, you lose the whole point of increasing batch_size, which is to get more negative samples.
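To make the point concrete: with the usual in-batch softmax contrastive loss, each query treats every other positive in the same forward batch as a negative. The negative count is therefore tied to the per-step batch, not the accumulated one. The pure-Python sketch below uses a generic formulation (cosine similarity divided by a temperature, then cross-entropy), not uniem's actual implementation. The function names and the 0.05 temperature are assumptions.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cross_entropy(logits, target):
    # Numerically stable -log softmax(logits)[target].
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def pair_in_batch_loss(queries, positives, temperature=0.05):
    """Query i's positive is positives[i]; positives[j] for j != i are negatives."""
    n = len(queries)
    return sum(
        cross_entropy([cosine(queries[i], p) / temperature for p in positives], i)
        for i in range(n)
    ) / n


def negatives_per_query(batch_size, accum_steps=1):
    # With accumulation the loss only sees batch_size // accum_steps pairs
    # at a time, so each query gets fewer negatives.
    return batch_size // accum_steps - 1


if __name__ == "__main__":
    print(negatives_per_query(80))     # 79
    print(negatives_per_query(80, 4))  # 19
```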

wangyuxinwhy commented on July 29, 2024

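Did you enable mixed precision? With a large batch size, mixed precision does reduce memory usage.

The m3e data format is PairRecord, so add_swap_loss isn't used. When reproducing instructor I used the Medi dataset, which is in TripletRecord format, and added the swap loss.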
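The sketch below shows how the two record types differ in training. A TripletRecord adds one hard negative per query, and those negatives can be appended to the in-batch candidates. What uniem's add_swap_loss actually computes is not stated in this thread. This version assumes it is an extra symmetric term in which each positive must pick out its own query. Treat that as a hypothetical reading and check uniem's loss code for the real definition.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cross_entropy(logits, target):
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def triplet_in_batch_loss(text, text_pos, text_neg, temperature=0.05,
                          add_swap_loss=False):
    """Embeddings for (text, text_pos, text_neg) triplets, one row per record.

    Query i must rank text_pos[i] above every other positive and every hard
    negative in the batch (2 * n - 1 negatives per query).
    """
    n = len(text)
    candidates = text_pos + text_neg
    loss = sum(
        cross_entropy([cosine(text[i], c) / temperature for c in candidates], i)
        for i in range(n)
    ) / n
    if add_swap_loss:
        # Assumed swap term: positives act as queries against all texts.
        loss += sum(
            cross_entropy([cosine(text_pos[i], t) / temperature for t in text], i)
            for i in range(n)
        ) / n
    return loss
```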

wangyuxinwhy commented on July 29, 2024

> If you turn on gradient accumulation, you lose the whole point of increasing batch_size, which is to get more negative samples.

Right, he meant gradient checkpointing.
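The two techniques affect the batch differently. Gradient accumulation shrinks the batch that the loss sees. Gradient checkpointing keeps the batch intact and instead discards intermediate activations during the forward pass, recomputing them during backward. The toy below shows the idea on a scalar chain of tanh layers in plain Python. It illustrates the technique and is not how PyTorch implements it. In Hugging Face transformers models, checkpointing is usually switched on with model.gradient_checkpointing_enable(). Whether uniem's training script exposes a flag for it is not covered in this thread.

```python
import math


def layer(w, x):
    # One toy "layer": y = tanh(w * x).
    return math.tanh(w * x)


def layer_grad(w, x):
    # dy/dx of the layer above, evaluated at input x.
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_full(weights, x0):
    """Plain backprop: store every activation. Returns (d out/d x0, stored)."""
    acts = [x0]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for i in reversed(range(len(weights))):
        g *= layer_grad(weights[i], acts[i])
    return g, len(acts)


def grad_checkpointed(weights, x0, segment):
    """Keep only segment-boundary activations and recompute the rest in backward.
    Returns (d out/d x0, peak number of activations held at once)."""
    n = len(weights)
    checkpoints = {0: x0}
    x = x0
    for i, w in enumerate(weights):
        x = layer(w, x)
        if (i + 1) % segment == 0:
            checkpoints[i + 1] = x
    g, peak = 1.0, len(checkpoints)
    for start in sorted((k for k in checkpoints if k < n), reverse=True):
        end = min(start + segment, n)
        acts = [checkpoints[start]]  # recompute this segment's layer inputs
        for i in range(start, end - 1):
            acts.append(layer(weights[i], acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts))
        for i in reversed(range(start, end)):
            g *= layer_grad(weights[i], acts[i - start])
    return g, peak


if __name__ == "__main__":
    ws = [0.9] * 16
    print(grad_full(ws, 0.5))             # 17 activations stored
    print(grad_checkpointed(ws, 0.5, 4))  # same gradient, peak of 9
```

Both functions return the same gradient. The checkpointed version pays with extra forward computation instead of memory, and the batch size seen by the loss never changes.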

wangyuxinwhy commented on July 29, 2024

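> You probably need ZeRO-2/3 optimization + mixed precision + gradient checkpointing to scale that far. On my 40 GB A100s, fp16 + ZeRO-2 gets to 34, though that still seems a bit short.

How many GPUs did you run ZeRO-2 on?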
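The GPU count matters because ZeRO shards model states across data-parallel ranks. The ZeRO paper's accounting for mixed-precision Adam charges 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer states. The sketch below uses that accounting for a rough per-GPU estimate. It deliberately ignores activations. The 110M parameter count is an assumed BERT-base-sized encoder, not a measured figure for m3e.

```python
def model_state_bytes(num_params, world_size, stage):
    """Approximate per-GPU bytes for weights, gradients and Adam states.

    stage 0: nothing sharded; 1: optimizer states sharded;
    2: gradients sharded too; 3: parameters sharded too.
    Activations (which grow with batch size and max length) are not included.
    """
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:
        return params + grads + optim
    if stage == 1:
        return params + grads + optim / world_size
    if stage == 2:
        return params + (grads + optim) / world_size
    if stage == 3:
        return (params + grads + optim) / world_size
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    n = 110_000_000  # assumed BERT-base-sized encoder
    for gpus in (1, 8, 32):
        gb = model_state_bytes(n, gpus, stage=2) / 1e9
        print(f"ZeRO-2 on {gpus:>2} GPUs: {gb:.2f} GB of model states per GPU")
```

Under these assumptions, the model states stay below 2 GB even unsharded. That suggests most of an 80 GB card goes to activations, which is where fp16, gradient checkpointing and max-length have the larger effect.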

ARSblithe212 commented on July 29, 2024

> > If you turn on gradient accumulation, you lose the whole point of increasing batch_size, which is to get more negative samples.

> Right, he meant gradient checkpointing.

OK, I'll try gradient checkpointing and mixed precision.

ARSblithe212 commented on July 29, 2024

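> Did you enable mixed precision? With a large batch size, mixed precision does reduce memory usage.

> The m3e data format is PairRecord, so add_swap_loss isn't used. When reproducing instructor I used the Medi dataset, which is in TripletRecord format, and added the swap loss.

May I ask how you reached a batch size of 80?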

wangyuxinwhy commented on July 29, 2024

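> > Did you enable mixed precision? With a large batch size, mixed precision does reduce memory usage.
> > The m3e data format is PairRecord, so add_swap_loss isn't used. When reproducing instructor I used the Medi dataset, which is in TripletRecord format, and added the swap loss.

> May I ask how you reached a batch size of 80?

No other optimizations. However, I may have mixed up the max-length for instructor... the max-length should be 400: m3e used 512, and the instructor reproduction used 400. Try 512 first, and if it still doesn't fit, switch to 400. Sorry, I didn't record the experiment details properly.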
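There is a rough way to see why max-length and the record type both matter for fitting batch 80. The number of sequences encoded per step is batch size × texts per record: 2 for a pair and 3 for a triplet, assuming each text is encoded separately. Per-token activations scale roughly linearly with length. Naively materialized attention scores scale with its square. The helpers below are illustrative only, and fused attention kernels change the quadratic part.

```python
def sequences_per_step(batch_size, texts_per_record):
    # Pair records encode 2 texts per example, triplet records 3.
    return batch_size * texts_per_record


def relative_activation_cost(max_len, ref_len=512):
    """(linear_ratio, quadratic_ratio) of max_len relative to ref_len."""
    r = max_len / ref_len
    return r, r * r


if __name__ == "__main__":
    print(sequences_per_step(80, 2), sequences_per_step(80, 3))  # 160 240
    print(relative_activation_cost(400))  # (0.78125, 0.6103515625)
```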

ARSblithe212 commented on July 29, 2024

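> > > Did you enable mixed precision? With a large batch size, mixed precision does reduce memory usage.
> > > The m3e data format is PairRecord, so add_swap_loss isn't used. When reproducing instructor I used the Medi dataset, which is in TripletRecord format, and added the swap loss.

> > May I ask how you reached a batch size of 80?

> No other optimizations. However, I may have mixed up the max-length for instructor... the max-length should be 400: m3e used 512, and the instructor reproduction used 400. Try 512 first, and if it still doesn't fit, switch to 400. Sorry, I didn't record the experiment details properly.

OK, thanks!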

ARSblithe212 commented on July 29, 2024

> Correction to what I said earlier: I used 32 V100s with 32 GB each, not 40 GB A100s, and the batch size went up to 34. I used DeepSpeed ZeRO-2/3 optimization plus fp16 mixed precision. So it looks like an 80 GB A100 should be able to reach a batch size of 80. Here are my DeepSpeed configs for reference: [ZeRO-2 and ZeRO-3 configs]

Could you share your code for reference? I get errors when I enable DeepSpeed through accelerate. [email protected] Thanks!
