bojone / simcse
Simple experiments with SimCSE on Chinese tasks
Resolved
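When doing unsupervised MLM+CL training, is the CL loss computed with dropout applied directly to the randomly masked sentence?
For example: first randomly mask sentence A = [a, b, c, d, e, f] to get B = [a, [MASK], c, [MASK], e, f], then feed B into the BERT model twice to get two dropout views that form a sentence pair, and finally compute both the CL loss and the MLM loss.
Is my understanding correct?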
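The procedure described in the question can be sketched in plain Python as follows. This only illustrates the asker's description, not the repository's implementation: `mask_tokens`, `build_mlm_cl_batch` and the 15% default masking rate are hypothetical, and real BERT-style MLM also replaces some selected tokens with random or unchanged tokens instead of always using `[MASK]`. The two copies of B are placed in adjacent rows; the dropout that makes their vectors differ happens inside the encoder, not in this step.

```python
import random


def mask_tokens(tokens, mask_rate, rng):
    """Replace each token with [MASK] with probability mask_rate.

    Returns the masked sequence and the MLM targets: the original token at
    masked positions, None elsewhere (no MLM loss at that position).
    """
    masked, targets = [], []
    for tok in tokens:
        if rng.random() < mask_rate:
            masked.append("[MASK]")
            targets.append(tok)
        else:
            masked.append(tok)
            targets.append(None)
    return masked, targets


def build_mlm_cl_batch(sentences, mask_rate=0.15, seed=0):
    """Hypothetical helper: mask each sentence once, then put two copies of the
    SAME masked sentence in adjacent rows (2k, 2k+1). Encoder dropout is what
    makes the two copies produce different vectors; nothing here adds noise."""
    rng = random.Random(seed)
    batch, mlm_targets = [], []
    for sent in sentences:
        masked, targets = mask_tokens(sent, mask_rate, rng)
        batch.extend([list(masked), list(masked)])
        mlm_targets.extend([list(targets), list(targets)])
    return batch, mlm_targets


A = ["a", "b", "c", "d", "e", "f"]
batch, targets = build_mlm_cl_batch([A], mask_rate=0.3, seed=1)
print(batch[0])
print(batch[0] == batch[1])  # True: both rows are the same masked sentence B
```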
if pooling == 'first-last-avg':
    outputs = [
        keras.layers.GlobalAveragePooling1D()(outputs[0]),
        keras.layers.GlobalAveragePooling1D()(outputs[-1])
    ]
    output = keras.layers.Average()(outputs)
elif pooling == 'last-avg':
    output = keras.layers.GlobalAveragePooling1D()(outputs[-1])
elif pooling == 'cls':
    output = keras.layers.Lambda(lambda x: x[:, 0])(outputs[-1])
elif pooling == 'pooler':
    output = bert.output
I don't normally use Keras, but I looked it up and GlobalAveragePooling1D accepts a mask argument. Isn't it a problem that attention_mask is not passed in here?
https://keras.io/api/layers/pooling_layers/global_average_pooling1d/
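The concern is that padding positions would be averaged into the sentence vector. The following NumPy sketch shows what a masked average computes compared with a plain one; `masked_mean_pool` is a hypothetical helper, and the mask convention (shape `(batch, steps)`, 1 for real tokens, 0 for padding) mirrors the `mask` argument described on the Keras page above. Keras can also pass a mask to this layer automatically when upstream layers support masking; whether that happens in this model is not settled in the thread.

```python
import numpy as np


def masked_mean_pool(hidden, mask):
    """Average token vectors over non-padding positions only.

    hidden: (batch, seq_len, dim) token vectors
    mask:   (batch, seq_len), 1 = real token, 0 = padding
    """
    mask = mask.astype(hidden.dtype)[:, :, None]    # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)             # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # real-token count, avoids division by 0
    return summed / counts


hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last step is padding
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(hidden, mask))  # [[2. 3.]]
print(hidden.mean(axis=1))             # unmasked: pulled far off by the padding vector
```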
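Hi, I've run into a question. I fine-tuned this pretrained Chinese model https://huggingface.co/cyclone/simcse-chinese-roberta-wwm-ext on my own Chinese data and use it to produce sentence vectors, matching sentences by their similarity. In practice I found that when two sentences differ only in words near the beginning, their vectors differ substantially and the similarity is low. But when two sentences differ only in words in the middle or near the end, the vectors are identical and the similarity is 1.
What is the underlying cause of this? Has anyone run into this problem, and what adjustments should I make?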
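One thing worth ruling out (a guess, not something the thread confirms) is input truncation: if the encoding step cuts inputs to a maximum length, tokens past the cutoff never reach the model, so two sentences that differ only there get identical inputs and, with dropout off at inference, identical vectors. The sketch below checks this; `head_truncate`, `inputs_differ`, the character-level split and `max_len=32` are all hypothetical stand-ins for whatever tokenizer and length limit the actual pipeline uses.

```python
def head_truncate(tokens, max_len):
    """Keep [CLS], the first max_len - 2 tokens and [SEP]; everything later is dropped."""
    return ["[CLS]"] + tokens[: max_len - 2] + ["[SEP]"]


def inputs_differ(sent_a, sent_b, max_len):
    """Split into characters (an approximation of how Chinese BERT vocabularies
    tokenize Chinese text) and compare the truncated model inputs."""
    return head_truncate(list(sent_a), max_len) != head_truncate(list(sent_b), max_len)


body = "x" * 50                                   # hypothetical 50-character sentence body
print(inputs_differ("A" + body, "B" + body, 32))  # True: a difference at the start survives
print(inputs_differ(body + "A", body + "B", 32))  # False: the difference is past the cutoff
```

If the second case matches what you see, raising the length limit (within the model's maximum position count) or checking the tokenizer's truncation settings would be the first thing to try.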
import tensorflow as tf
from tensorflow.keras import backend as K  # the original code may import K from bert4keras instead


def simcse_loss(y_true, y_pred):
    """Loss for SimCSE training.
    """
    # Build the labels
    idxs = K.arange(0, K.shape(y_pred)[0])
    idxs_1 = idxs[None, :]
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
    y_true = K.equal(idxs_1, idxs_2)
    y_true = K.cast(y_true, K.floatx())
    # Compute the similarities
    y_pred = K.l2_normalize(y_pred, axis=1)
    similarities = K.dot(y_pred, K.transpose(y_pred))
    similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
    similarities = similarities * 20
    loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
    return K.mean(loss)
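To inspect what this loss does, here is a NumPy re-implementation of the same computation (for reading, not for training; `simcse_labels` and `simcse_loss_np` are hypothetical names). The expression `idxs + 1 - idxs % 2 * 2` maps 0→1, 1→0, 2→3, 3→2, so rows 2k and 2k+1 (the two dropout views of one sentence) are each other's positive and every other row in the batch is a negative. Subtracting 1e12 on the diagonal removes each row's similarity with itself, and multiplying by 20 is the same as dividing by a temperature of 0.05. The sketch assumes an even batch size with paired rows adjacent.

```python
import numpy as np


def simcse_labels(batch_size):
    """One-hot targets: row i's positive is its neighbour (0<->1, 2<->3, ...)."""
    idxs = np.arange(batch_size)
    partner = idxs + 1 - idxs % 2 * 2
    return (idxs[None, :] == partner[:, None]).astype(float)


def simcse_loss_np(y_pred, scale=20.0):
    """Mean cross-entropy over the scaled cosine-similarity matrix."""
    y_true = simcse_labels(y_pred.shape[0])
    y_pred = y_pred / np.linalg.norm(y_pred, axis=1, keepdims=True)   # L2-normalize rows
    sims = y_pred @ y_pred.T - np.eye(len(y_pred)) * 1e12             # mask self-similarity
    logits = sims * scale
    logits = logits - logits.max(axis=1, keepdims=True)               # numerically stable softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-(y_true * log_probs).sum(axis=1).mean())


print(simcse_labels(4))
good = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # each pair agrees
bad = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])   # pairs disagree
print(simcse_loss_np(good), simcse_loss_np(bad))
```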
all_corrcoefs = []
for (a_vecs, b_vecs), labels in zip(all_vecs, all_labels):
    a_vecs = l2_normalize(a_vecs)
    b_vecs = l2_normalize(b_vecs)
    sims = (a_vecs * b_vecs).sum(axis=1)
    corrcoef = compute_corrcoef(labels, sims)
    all_corrcoefs.append(corrcoef)
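sims and labels are both one-dimensional vectors of size 1, so the variance is 0. Why can a correlation coefficient be computed? Have I misunderstood something?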
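In the loop above, each row of `a_vecs` and `b_vecs` is one sentence vector, so `(a_vecs * b_vecs).sum(axis=1)` yields one cosine similarity per sentence pair: for N pairs, `sims` and `labels` both have shape `(N,)`. They are one-dimensional but hold N values, so their variance over the dataset is generally nonzero. The sketch below demonstrates this with NumPy; `l2_normalize` and `spearman` are stand-ins written for illustration (the repository's own `compute_corrcoef` may be implemented differently), and this rank computation does not handle ties.

```python
import numpy as np


def l2_normalize(vecs):
    """Scale each row to unit length."""
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs / np.clip(norms, 1e-8, None)


def spearman(x, y):
    """Spearman correlation as Pearson correlation of ranks (ties not handled)."""
    rx = np.argsort(np.argsort(x))
    ry = np.argsort(np.argsort(y))
    return float(np.corrcoef(rx, ry)[0, 1])


# Hypothetical data: 4 sentence pairs with 2-dimensional vectors.
a_vecs = np.array([[1.0, 0.0]] * 4)
b_vecs = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0]])
labels = np.array([5, 3, 1, 0])

sims = (l2_normalize(a_vecs) * l2_normalize(b_vecs)).sum(axis=1)
print(sims.shape)  # (4,): one similarity per pair
print(spearman(labels, sims))
```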
Hi,
This may be the reason you are getting worse results.
You can refer to the source code of BertPooler in transformers.
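For reference, Hugging Face transformers' `BertPooler` takes the hidden state of the first ([CLS]) token, applies a dense layer, then a tanh. The NumPy sketch below mirrors that computation; `bert_pooler` is a hypothetical function name and the weights are placeholders. It shows how the `'pooler'` output differs from plain `'cls'` pooling, which returns the raw first-token vector without the dense layer and the tanh.

```python
import numpy as np


def bert_pooler(hidden_states, weight, bias):
    """hidden_states: (batch, seq_len, hidden); weight: (hidden, hidden); bias: (hidden,)."""
    first_token = hidden_states[:, 0]                 # [CLS] vector, same as 'cls' pooling
    return np.tanh(first_token @ weight.T + bias)     # dense (x @ W^T + b) followed by tanh


hidden = np.array([[[0.5, -0.5], [9.0, 9.0], [9.0, 9.0]]])
print(hidden[:, 0])                                   # 'cls' pooling: raw first-token vector
print(bert_pooler(hidden, np.eye(2), np.zeros(2)))    # 'pooler': squashed into (-1, 1)
```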
Why does this error occur? Also, how much compute is needed to run this model?
After setting TF_KERAS=1 to switch to tf.keras, the training script compiles the model normally, but training then fails.
The error message is as follows:
Traceback (most recent call last):
  File "train.py", line 94, in <module>
    train_generator.forfit(), steps_per_epoch=len(train_generator), epochs=1
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 224, in fit
    distribution_strategy=strategy)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 547, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 606, in _process_inputs
    use_multiprocessing=use_multiprocessing)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/data_adapter.py", line 566, in __init__
    reassemble, nested_dtypes, output_shapes=nested_shape)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 540, in from_generator
    output_types, tensor_shape.as_shape, output_shapes)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/data/util/nest.py", line 471, in map_structure_up_to
    results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/data/util/nest.py", line 471, in <listcomp>
    results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py", line 1216, in as_shape
    return TensorShape(shape)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py", line 776, in __init__
    self._dims = [as_dimension(d) for d in dims_iter]
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py", line 776, in <listcomp>
    self._dims = [as_dimension(d) for d in dims_iter]
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py", line 718, in as_dimension
    return Dimension(value)
  File "/Users/yuxi/opt/anaconda3/envs/TrainingRobot/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py", line 193, in __init__
    self._value = int(value)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'tuple'
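Training works fine with keras. The reason I want to train with tf.keras is that I want to move from a model saved in HDF5 format to SavedModel, and a keras model cannot be converted directly.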