Comments (6)
1. 关于如何增加 batch_size?
非常感谢您能提供一个具体实现的 demo,我会测试一下,如果一切正常,我会把这个方法尽快加到库里。
2. hard negative mining 更为重要?
是的,我们发现 m3e 的模型在 reranking 上面的表现并不好,究其原因是缺少 hard negative sample ,导致模型对于文本细节的区分能力较差。
3. 如何构建样本池
我觉得来自同一个 task 就可以 。我能想到比较好的方案是,先通过 m3e 来做召回,在使用下面的模型来做精排。
- openai
- m3e-base
- sgpt 的 cross 方法,使用 chatglm2,百川之类的
用上面多个模型投票(其他 ensemble 方法也可以)一起来判断 top k
你现在做的方向其实也是 m3e 想要优化的方向,如果有任何进展,欢迎交流~
from uniem.
microsoft/unilm#1120 (comment)
但是从E5作者的issue来看,似乎hard negative mining更为重要。
对于M3E的数据集的话,做hard negative mining。
我的理解是先建立一个样本池,对于每一个样本,都把它和样本池里的其他样本用模型作比较,然后筛出top k score的,作为hard negative。
不过这个样本池应该怎样建立比较好呢,首先它们都应该来自同一个task,同一种domain。emmmmm,暂时也没想到更多了.......
from uniem.
突然想起来,简单地update下:
上面那个demo我这里实际跑起来效果几乎没有提升,而由于需要在forward的时候对tensor进行gather,会增加GPU的显存负载,导致batch size参数还需要设低一些。
我后面的实验也几乎没用这种gather方式了,不过具体不work的原因的话,目前还没深入研究。
以后有时间我会再看看。
from uniem.
目前我能想到最简单直接的方法就是使用 FSDP,把模型的权重,优化器和梯度的状态分配给不同的机器,从而增加 batch_size。
稍微复杂一点的就是把最终的结果汇聚到一张卡上,计算 loss 后再广播给所有的节点,但我没写过这部分的代码,所以不确定容不容易实现,理论上应该挺容易的。
from uniem.
我写了个简单的实现,在我自己的机器上能够跑,后面我也看看这样做的效果。
这里分享下:
import os
from accelerate import Accelerator
class AllGather_multi(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, accelerator: Accelerator):
ctx.rank = int(os.environ['RANK'])
ctx.batch_size = tensor.shape[0]
tensor = accelerator.gather(tensor)
return tensor
@staticmethod
def backward(ctx, grad_output):
return (
grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
None, None,
)
class EmbedderForPairAllNegTrain(EmbedderForPairInBatchNegTrain):
def __init__(
self,
accelerator: Accelerator,
*args, **kwargs,
):
super().__init__(*args, **kwargs)
self.accelrator = accelerator
def forward(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor) -> dict[str, torch.Tensor]:
text_embeddings = self.embedder(text_ids)
text_pos_embeddings = self.embedder(text_pos_ids)
text_embeddings = AllGather_multi.apply(text_embeddings, self.accelrator)
text_pos_embeddings = AllGather_multi.apply(text_pos_embeddings, self.accelrator)
loss = self.criterion(text_embeddings, text_pos_embeddings)
return {'loss': loss}
初始化EmbedderForPairAllNegTrain时候传入一下accelerator就可以了
from uniem.
收到
from uniem.
Related Issues (20)
- 关于huggingface方法调用 HOT 1
- sentence-transformer调用huggingface模型 HOT 1
- 负采样 HOT 3
- 请教贴:文本最大长度 HOT 5
- 进行评测时会报错,分叉可能会导致死锁. HOT 1
- 求一份评测数据集 HOT 1
- 微调后模型保存和load的问题 HOT 3
- m3e-large数据集的相关问题 HOT 1
- m3e训练的时候使用的数据集是hugg上面列出的数据集,训练和测试集和验证集一起用来训练了吗? HOT 1
- 请问微调之后的模型如何支持C_MTEB数据集上的评测呢 HOT 2
- 实际测试 PairInBatchNegSoftmaxContrastLoss和PairInBatchNegCoSentLoss的值是一样的 HOT 1
- 转onnx问题 HOT 1
- 能不能说明一下显卡要求啊? HOT 3
- fintuner如何使用gpu? HOT 1
- 问题 HOT 3
- 代码跑着跑着就挂了,CUDA out of memory HOT 1
- Loss固定不变或者上升 HOT 2
- M3E如何指定为GPU训练,而不是CPU训练? HOT 1
- 模型训练后生成的内容不对 HOT 1
- 微调后模型没有保存2_Dense HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from uniem.